Spaces:
Running on Zero
Running on Zero
| import argparse | |
| from models.model import NamedCurves | |
| import torch | |
| import os | |
| from omegaconf import OmegaConf | |
| from glob import glob | |
| from PIL import Image | |
| from torchvision.transforms import functional as TF | |
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests or
            other scripts reuse the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_path``, ``output_path``, ``model_path``
        and ``config_path`` string attributes.
    """
    parser = argparse.ArgumentParser(
        description='Enhance images with a pretrained NamedCurves model.')
    parser.add_argument('--input_path', type=str, default='assets/a4957-input.png',
                        help='Input image file, or a folder of images.')
    parser.add_argument('--output_path', type=str, default='output/',
                        help='Folder the enhanced images are written to.')
    # NOTE(review): this default is an absolute path on the original author's
    # machine; on any other machine --model_path must be given explicitly.
    parser.add_argument('--model_path', type=str,
                        default='/home/dserrano/Workspace/Color-Naming-Image-Enhancement/pretrained/mit5k_uegan_psnr_25.59.pth',
                        help="Checkpoint file containing a 'model_state_dict' entry.")
    parser.add_argument('--config_path', type=str, default='configs/mit5k_dpe_config.yaml',
                        help="OmegaConf config file with a 'model' section.")
    return parser.parse_args(argv)
def main():
    """Enhance one image, or every file in a folder, with a pretrained NamedCurves model.

    Reads the CLI options via ``parse_args()`` and builds the model from the
    ``model`` section of the OmegaConf config. It then loads the checkpoint's
    ``"model_state_dict"`` and writes each enhanced image into
    ``--output_path`` under the same file name as its input. Requires CUDA.
    """
    args = parse_args()
    config = OmegaConf.load(args.config_path)

    model = NamedCurves(config.model).cuda()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location='cpu' lets a checkpoint saved on any device load; load_state_dict
    # then copies the tensors into the model's CUDA parameters.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint["model_state_dict"])
    # Switch to inference mode (matters for dropout / batch-norm layers, if the
    # model contains any).
    model.eval()

    os.makedirs(args.output_path, exist_ok=True)

    if os.path.isdir(args.input_path):
        # Sort the matched paths (not the pattern string) for a deterministic
        # processing order, and skip sub-directories so Image.open only sees files.
        input_paths = sorted(
            path for path in glob(os.path.join(args.input_path, '*'))
            if os.path.isfile(path)
        )
    else:
        input_paths = [args.input_path]

    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for input_path in input_paths:
            # Close the file promptly; convert('RGB') so grayscale / RGBA /
            # palette images still produce a 3-channel tensor.
            with Image.open(input_path) as image:
                input_tensor = TF.to_tensor(image.convert('RGB')).unsqueeze(0)
            output = model(input_tensor.cuda())
            # to_pil_image scales floats by 255 and casts to uint8; clamp first so
            # out-of-range values saturate instead of wrapping around.
            # NOTE(review): assumes the model targets the [0, 1] range -- confirm.
            output_image = TF.to_pil_image(output[0].clamp(0, 1).cpu())
            output_image.save(os.path.join(args.output_path, os.path.basename(input_path)))
if __name__ == "__main__":
    # Script entry point: importing this module does not run inference.
    main()