Joypop commited on
Commit
9d3cc59
·
verified ·
1 Parent(s): e518cbe

Update GDPOSR/inferences/test.py

Browse files
Files changed (1) hide show
  1. GDPOSR/inferences/test.py +7 -7
GDPOSR/inferences/test.py CHANGED
@@ -34,10 +34,10 @@ def get_validation_prompt(args, image, model, device='cuda'):
34
 
35
  if __name__ == "__main__":
36
  parser = argparse.ArgumentParser()
37
- parser.add_argument('--input_image', type=str, default="", help='path to the input image')
38
  parser.add_argument('--model_name', type=str, default='realsr', help='name of the pretrained model to be used')
39
  parser.add_argument('--pretrained_path', type=str, default='', help='path to a model state dict to be used')
40
- parser.add_argument('--output_dir', type=str, default='', help='the directory to save the output')
41
  parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
42
  parser.add_argument("--process_size", type=int, default=512)
43
  parser.add_argument("--upscale", type=int, default=4)
@@ -54,10 +54,10 @@ if __name__ == "__main__":
54
  model = GDPOSRTest(args)
55
  model.set_eval()
56
 
57
- if os.path.isdir(args.input_image):
58
- image_names = sorted(glob.glob(f'{args.input_image}/*.png'))
59
  else:
60
- image_names = [args.input_image]
61
 
62
  print("=== use ram ===")
63
  model_vlm = ram(pretrained='./ckp/ram_swin_large_14m.pth',
@@ -68,7 +68,7 @@ if __name__ == "__main__":
68
  model_vlm.to("cuda")
69
 
70
  # make the output dir
71
- os.makedirs(args.output_dir, exist_ok=True)
72
  print(f'There are {len(image_names)} images.')
73
  for image_name in image_names:
74
 
@@ -104,4 +104,4 @@ if __name__ == "__main__":
104
  if resize_flag:
105
  output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
106
 
107
- output_pil.save(os.path.join(args.output_dir, bname))
 
34
 
35
  if __name__ == "__main__":
36
  parser = argparse.ArgumentParser()
37
+ parser.add_argument('--input_path', type=str, default="", help='path to the input image')
38
  parser.add_argument('--model_name', type=str, default='realsr', help='name of the pretrained model to be used')
39
  parser.add_argument('--pretrained_path', type=str, default='', help='path to a model state dict to be used')
40
+ parser.add_argument('--output_path', type=str, default='', help='the directory to save the output')
41
  parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
42
  parser.add_argument("--process_size", type=int, default=512)
43
  parser.add_argument("--upscale", type=int, default=4)
 
54
  model = GDPOSRTest(args)
55
  model.set_eval()
56
 
57
+ if os.path.isdir(args.input_path):
58
+ image_names = sorted(glob.glob(f'{args.input_path}/*.png'))
59
  else:
60
+ image_names = [args.input_path]
61
 
62
  print("=== use ram ===")
63
  model_vlm = ram(pretrained='./ckp/ram_swin_large_14m.pth',
 
68
  model_vlm.to("cuda")
69
 
70
  # make the output dir
71
+ os.makedirs(args.output_path, exist_ok=True)
72
  print(f'There are {len(image_names)} images.')
73
  for image_name in image_names:
74
 
 
104
  if resize_flag:
105
  output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
106
 
107
+ output_pil.save(os.path.join(args.output_path, bname))