Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| import torchvision | |
| import torchvision.transforms.functional as TF | |
| import json | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import yaml | |
| import random | |
| import gc | |
| from utils import * | |
| from models import instructir | |
| from text.models import LanguageModel, LMHead | |
| from test import test_model | |
| def seed_everything(SEED=42): | |
| random.seed(SEED) | |
| np.random.seed(SEED) | |
| torch.manual_seed(SEED) | |
| torch.cuda.manual_seed(SEED) | |
| torch.cuda.manual_seed_all(SEED) | |
| torch.backends.cudnn.benchmark = True | |
| if __name__=="__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--config', type=str, default='configs/eval5d.yml', help='Path to config file') | |
| parser.add_argument('--model', type=str, default="models/im_instructir-7d.pt", help='Path to the image model weights') | |
| parser.add_argument('--lm', type=str, default="models/lm_instructir-7d.pt", help='Path to the language model weights') | |
| parser.add_argument('--promptify', type=str, default="simple_augment") | |
| parser.add_argument('--device', type=int, default=0, help="GPU device") | |
| parser.add_argument('--debug', action='store_true', help="Debug mode") | |
| parser.add_argument('--save', type=str, default='results/', help="Path to save the resultant images") | |
| args = parser.parse_args() | |
| SEED=42 | |
| seed_everything(SEED=SEED) | |
| torch.backends.cudnn.deterministic = True | |
| GPU = args.device | |
| DEBUG = args.debug | |
| MODEL_NAME = args.model | |
| CONFIG = args.config | |
| LM_MODEL = args.lm | |
| SAVE_PATH = args.save | |
| print ('CUDA GPU available: ', torch.cuda.is_available()) | |
| torch.cuda.set_device(f'cuda:{GPU}') | |
| device = torch.device(f'cuda:{GPU}' if torch.cuda.is_available() else "cpu") | |
| print ('CUDA visible devices: ' + str(torch.cuda.device_count())) | |
| print ('CUDA current device: ', torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device())) | |
| # parse config file | |
| with open(os.path.join(CONFIG), "r") as f: | |
| config = yaml.safe_load(f) | |
| cfg = dict2namespace(config) | |
| print (20*"****") | |
| print ("EVALUATION") | |
| print (MODEL_NAME, LM_MODEL, device, DEBUG, CONFIG, args.promptify) | |
| print (20*"****") | |
| ################### TESTING DATASET | |
| TESTSETS = [] | |
| dn_testsets = [] | |
| rain_testsets = [] | |
| # Denoising | |
| try: | |
| for testset in cfg.test.dn_datasets: | |
| for sigma in cfg.test.dn_sigmas: | |
| noisy_testpath = os.path.join(cfg.test.dn_datapath, testset+ f"_{sigma}") | |
| clean_testpath = os.path.join(cfg.test.dn_datapath, testset) | |
| #print (clean_testpath, noisy_testpath) | |
| dn_testsets.append([clean_testpath, noisy_testpath]) | |
| except: | |
| dn_testsets = [] | |
| # RAIN | |
| try: | |
| for noisy_testpath, clean_testpath in zip(cfg.test.rain_inputs, cfg.test.rain_targets): | |
| rain_testsets.append([clean_testpath, noisy_testpath]) | |
| except: | |
| rain_testsets = [] | |
| # HAZE | |
| try: | |
| haze_testsets = [[cfg.test.haze_targets, cfg.test.haze_inputs]] | |
| except: | |
| haze_testsets = [] | |
| # BLUR | |
| try: | |
| blur_testsets = [[cfg.test.gopro_targets, cfg.test.gopro_inputs]] | |
| except: | |
| blur_testsets = [] | |
| # LOL | |
| try: | |
| lol_testsets = [[cfg.test.lol_targets, cfg.test.lol_inputs]] | |
| except: | |
| lol_testsets = [] | |
| # MIT5K | |
| try: | |
| mit_testsets = [[cfg.test.mit_targets, cfg.test.mit_inputs]] | |
| except: | |
| mit_testsets = [] | |
| TESTSETS += dn_testsets | |
| TESTSETS += rain_testsets | |
| TESTSETS += haze_testsets | |
| TESTSETS += blur_testsets | |
| TESTSETS += lol_testsets | |
| TESTSETS += mit_testsets | |
| # print ("Tests:", TESTSETS) | |
| print ("TOTAL TESTSET:", len(TESTSETS)) | |
| print (20 * "----") | |
| ################### RESTORATION MODEL | |
| print ("Creating InstructIR") | |
| model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks, | |
| middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim) | |
| ################### LOAD IMAGE MODEL | |
| assert MODEL_NAME, "Model weights required for evaluation" | |
| print ("IMAGE MODEL CKPT:", MODEL_NAME) | |
| model.load_state_dict(torch.load(MODEL_NAME), strict=True) | |
| model = model.to(device) | |
| nparams = count_params (model) | |
| print ("Loaded weights!", nparams / 1e6) | |
| ################### LANGUAGE MODEL | |
| try: | |
| PROMPT_DB = cfg.llm.text_db | |
| except: | |
| PROMPT_DB = None | |
| if cfg.model.use_text: | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Initialize the LanguageModel class | |
| LMODEL = cfg.llm.model | |
| language_model = LanguageModel(model=LMODEL) | |
| lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses) | |
| lm_head = lm_head.to(device) | |
| lm_nparams = count_params (lm_head) | |
| print ("LMHEAD MODEL CKPT:", LM_MODEL) | |
| lm_head.load_state_dict(torch.load(LM_MODEL), strict=True) | |
| print ("Loaded weights!") | |
| else: | |
| LMODEL = None | |
| language_model = None | |
| lm_head = None | |
| lm_nparams = 0 | |
| print (20 * "----") | |
| ################### TESTING !! | |
| from datasets import RefDegImage, augment_prompt, create_testsets | |
| if args.promptify == "simple_augment": | |
| promptify = augment_prompt | |
| elif args.promptify == "chatgpt": | |
| prompts = json.load(open(cfg.llm.text_db)) | |
| def promptify(deg): | |
| return np.random.choice(prompts[deg]) | |
| else: | |
| def promptify(deg): | |
| return args.promptify | |
| torch.cuda.empty_cache() | |
| torch.cuda.reset_peak_memory_stats() | |
| test_datasets = create_testsets(TESTSETS, debug=True) | |
| test_model (model, language_model, lm_head, test_datasets, device, promptify, savepath=SAVE_PATH) | |