from envs.env_dataloader import create_dataloaders
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torchmetrics.image import StructuralSimilarityIndexMeasure
from envs.new_edit_photo import PhotoEditor
from sac.sac_inference import InferenceAgent
import yaml
from envs.photo_env import PhotoEnhancementEnvTest
import numpy as np
import argparse
import logging
import os
from pathlib import Path
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import torch


class Config(object):
    """Attribute-style access wrapper around a plain config dict.

    Every key of ``dictionary`` becomes an instance attribute, so YAML
    configs can be read as ``config.some_key`` instead of ``d['some_key']``.
    """

    def __init__(self, dictionary):
        self.__dict__.update(dictionary)


def load_preprocessor_agent(preprocessor_agent_path, device):
    """Load a pre-trained preprocessor SAC agent and its photo editor.

    Reads the experiment's SAC/env YAML configs from
    ``<preprocessor_agent_path>/configs``, builds a throwaway test env
    (only used to size the networks / detect txt-feature usage), then
    restores the backbone, actor and critic weights from
    ``<preprocessor_agent_path>/models``.

    Args:
        preprocessor_agent_path: Experiment folder containing ``configs/``
            and ``models/`` subdirectories.
        device: torch device the agent should run on.

    Returns:
        Tuple ``(preprocessor_agent, preprocessor_photo_editor)``.
    """
    current_dir = Path(__file__).parent.absolute()
    # NOTE: FullLoader is used on local, trusted experiment configs only.
    with open(os.path.join(preprocessor_agent_path, "configs/sac_config.yaml")) as f:
        sac_config_dict = yaml.load(f, Loader=yaml.FullLoader)
    with open(os.path.join(preprocessor_agent_path, "configs/env_config.yaml")) as f:
        env_config_dict = yaml.load(f, Loader=yaml.FullLoader)
    with open(os.path.join(current_dir, "../configs/inference_config.yaml")) as f:
        inf_config_dict = yaml.load(f, Loader=yaml.FullLoader)

    inference_config = Config(inf_config_dict)
    # sac_config is loaded for fail-fast validation of the experiment folder;
    # only its file presence matters here.
    sac_config = Config(sac_config_dict)
    env_config = Config(env_config_dict)

    # Dummy env: only used so the networks get the right action-space size
    # and know whether txt features are in play.
    inference_env = PhotoEnhancementEnvTest(
        batch_size=env_config.train_batch_size,
        imsize=env_config.imsize,
        training_mode=False,
        done_threshold=env_config.threshold_psnr,
        edit_sliders=env_config.sliders_to_use,
        features_size=env_config.features_size,
        discretize=env_config.discretize,
        discretize_step=env_config.discretize_step,
        use_txt_features=env_config.use_txt_features if hasattr(env_config, 'use_txt_features') else False,
        augment_data=False,
        pre_encoding_device=device,
        pre_load_images=False,
        logger=None)

    preprocessor_photo_editor = PhotoEditor(env_config.sliders_to_use)
    inference_config.device = device
    preprocessor_agent = InferenceAgent(inference_env, inference_config)
    preprocessor_agent.device = device
    preprocessor_agent.load_backbone(os.path.join(preprocessor_agent_path, 'models', 'backbone.pth'))
    preprocessor_agent.load_actor_weights(os.path.join(preprocessor_agent_path, 'models', 'actor_head.pth'))
    preprocessor_agent.load_critics_weights(os.path.join(preprocessor_agent_path, 'models', 'qf1_head.pth'),
                                            os.path.join(preprocessor_agent_path, 'models', 'qf2_head.pth'))
    return preprocessor_agent, preprocessor_photo_editor


def str2bool(v):
    """Parse a CLI flag value into a bool.

    Accepts common truthy/falsy spellings; raises
    ``argparse.ArgumentTypeError`` on anything else.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    """Evaluate a trained enhancement agent on the MIT-Adobe 5K test set.

    Loads the experiment's configs and weights, computes slider parameters
    on resized images in one batch, applies them to the full-resolution
    (512) images, reports mean PSNR/SSIM, and saves a sample plot as SVG
    into the experiment folder.
    """
    current_dir = Path(__file__).parent.absolute()
    parser = argparse.ArgumentParser()
    parser.add_argument('experiment_path', help='folder containing the experiment models')
    parser.add_argument('--deterministic', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--logger_level', type=int, default=logging.INFO)
    parser.add_argument('--device', nargs='?', type=str, default='cuda:0')
    parser.add_argument('--plt_samples', nargs='?', type=int, default=3)
    args = parser.parse_args()

    logger = logging.getLogger("test")
    args.device = torch.device(args.device) if torch.cuda.is_available() else torch.device('cpu')

    # Configure logging to console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(args.logger_level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.setLevel(args.logger_level)

    with open(os.path.join(current_dir, "configs/inference_config.yaml")) as f:
        inf_config_dict = yaml.load(f, Loader=yaml.FullLoader)
    with open(os.path.join(args.experiment_path, "configs/sac_config.yaml")) as f:
        sac_config_dict = yaml.load(f, Loader=yaml.FullLoader)
    with open(os.path.join(args.experiment_path, "configs/env_config.yaml")) as f:
        env_config_dict = yaml.load(f, Loader=yaml.FullLoader)

    inference_config = Config(inf_config_dict)
    sac_config = Config(sac_config_dict)
    env_config = Config(env_config_dict)
    if not hasattr(env_config, 'preprocessor_agent_path'):
        env_config.preprocessor_agent_path = None

    # Seed everything for reproducible (optionally deterministic) evaluation.
    SEED = sac_config.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = sac_config.torch_deterministic
    # NOTE(review): anomaly detection is a debug flag and slows execution;
    # kept to preserve original behavior — confirm whether it is still wanted.
    torch.autograd.set_detect_anomaly(True)

    inference_config.device = args.device
    photo_editor = PhotoEditor(env_config.sliders_to_use)

    # Dummy env: only used so the networks get the right action-space size
    # and know whether txt features are in play.
    inference_env = PhotoEnhancementEnvTest(
        batch_size=env_config.train_batch_size,
        imsize=env_config.imsize,
        training_mode=False,
        done_threshold=env_config.threshold_psnr,
        edit_sliders=env_config.sliders_to_use,
        features_size=env_config.features_size,
        discretize=env_config.discretize,
        discretize_step=env_config.discretize_step,
        use_txt_features=env_config.use_txt_features if hasattr(env_config, 'use_txt_features') else False,
        augment_data=env_config.augment_data if hasattr(env_config, 'augment_data') else False,
        pre_encoding_device=args.device,
        pre_load_images=False,
        preprocessor_agent_path=None,
        logger=None
    )

    inf_agent = InferenceAgent(inference_env, inference_config)
    inf_agent.load_backbone(os.path.join(args.experiment_path, 'models', 'backbone.pth'))
    inf_agent.load_actor_weights(os.path.join(args.experiment_path, 'models', 'actor_head.pth'))
    inf_agent.load_critics_weights(os.path.join(args.experiment_path, 'models', 'qf1_head.pth'),
                                   os.path.join(args.experiment_path, 'models', 'qf2_head.pth'))

    if env_config.preprocessor_agent_path is not None:
        preprocessor_agent, preprocessor_photo_editor = load_preprocessor_agent(
            env_config.preprocessor_agent_path, args.device)

    ssim_metric = StructuralSimilarityIndexMeasure().to(args.device)

    # Full-resolution loader (batch of 1) for metric computation, and a
    # resized loader (single large batch) for the agent's slider prediction.
    test_512 = create_dataloaders(batch_size=1, image_size=env_config.imsize, use_txt_features=False,
                                  train=False, augment_data=False, shuffle=False, resize=False,
                                  pre_encoding_device=args.device, pre_load_images=False)
    test_resized = create_dataloaders(batch_size=500, image_size=env_config.imsize,
                                      use_txt_features=env_config.use_txt_features if hasattr(env_config, 'use_txt_features') else False,
                                      train=False, augment_data=False, shuffle=False, resize=True,
                                      pre_encoding_device=args.device, pre_load_images=True)

    PSNRS = []
    SSIM = []
    logger.info('Testing ...')
    logger.info(f'Using device {args.device}')

    inference_env.dataloader = test_resized
    inference_env.iter_dataloader = iter(test_resized)
    inference_env.batch_size = 500
    batch_images = inference_env.reset()

    logger.info(f'Computing optimal enhancement sliders values, DETERMINISTIC:{args.deterministic}')
    if env_config.preprocessor_agent_path is not None:
        pre_parameters = preprocessor_agent.act(batch_images, deterministic=args.deterministic)
        preprocessed_images = preprocessor_photo_editor(batch_images.permute(0, 2, 3, 1), pre_parameters)
        preprocessed_images = preprocessed_images.permute(0, 3, 1, 2)
    else:
        preprocessed_images = batch_images
    parameters = inf_agent.act(preprocessed_images, deterministic=args.deterministic)
    logger.info('Done')

    parameter_counter = 0
    logger.info('Enhancing images and computing metrics')
    plot_data = []
    # Pick a few random images to visualize alongside the metrics.
    random_indices = random.sample(range(len(test_512)), args.plt_samples)
    for i, t in tqdm(test_512, position=0, leave=True):
        source = i / 255.0
        target = t / 255.0
        # NOTE(review): when a preprocessor agent is configured, only its
        # sliders (pre_parameters) are applied here — the main agent's
        # `parameters` are unused in that branch. Confirm this is intended.
        if env_config.preprocessor_agent_path is not None:
            enhanced_image = source.permute(0, 2, 3, 1)
            enhanced_image = preprocessor_photo_editor(
                enhanced_image.to(args.device),
                pre_parameters[parameter_counter].unsqueeze(0).to(args.device))
        else:
            enhanced_image = source.permute(0, 2, 3, 1)
            enhanced_image = photo_editor(
                enhanced_image.to(args.device),
                parameters[parameter_counter].unsqueeze(0).to(args.device))
        enhanced_image = enhanced_image.permute(0, 3, 1, 2)  # B,C,H,W
        # +50 offsets the env's reward back into a PSNR value —
        # presumably the env subtracts 50 internally; confirm against env code.
        psnr = inference_env.compute_rewards(enhanced_image.to(args.device), target.to(args.device)).item() + 50
        ssim = ssim_metric(enhanced_image.to(args.device), target.to(args.device)).item()
        PSNRS.append(psnr)
        SSIM.append(ssim)
        if parameter_counter in random_indices:
            enhanced_image = enhanced_image.permute(0, 2, 3, 1)  # B,H,W,C
            plot_data.append((source.cpu(), enhanced_image.cpu(), target.cpu(), psnr, ssim))
        parameter_counter += 1

    mean_PSNRS = round(np.mean(PSNRS), 2)
    mean_SSIM = round(np.mean(SSIM), 3)
    logger.info(f'Mean PSNR on MIT 5K Dataset {mean_PSNRS}')
    logger.info(f'Mean SSIM on MIT 5K Dataset {mean_SSIM}')

    # Plotting: row 0 = source, row 1 = ours (with metrics), row 2 = target.
    # squeeze=False keeps axes 2-D even when plt_samples == 1.
    fig, axes = plt.subplots(3, args.plt_samples, figsize=(15, args.plt_samples * 5), squeeze=False)
    logger.info('Plotting samples')
    for index in range(args.plt_samples):
        axes[0][index].imshow(plot_data[index][0][0].permute(1, 2, 0))
        axes[0][index].axis('off')
        axes[1][index].imshow(plot_data[index][1][0])
        axes[1][index].axis('off')
        axes[1][index].text(0.5, -0.04,
                            f'PSNR:{round(plot_data[index][3],2)}, SSIM:{round(plot_data[index][4],2)}',
                            size=10, ha='center', transform=axes[1][index].transAxes)
        axes[2][index].imshow(plot_data[index][2][0].permute(1, 2, 0))
        axes[2][index].axis('off')
    plt.tight_layout()
    logger.info(f'Saving plot in {os.path.join(args.experiment_path,"samples_plot.svg")}')
    fig.savefig(os.path.join(args.experiment_path, "samples_plot.svg"), format='svg')


if __name__ == "__main__":
    main()