Update GDPOSR/inferences/test.py
Browse files
GDPOSR/inferences/test.py
CHANGED
|
@@ -34,10 +34,10 @@ def get_validation_prompt(args, image, model, device='cuda'):
|
|
| 34 |
|
| 35 |
if __name__ == "__main__":
|
| 36 |
parser = argparse.ArgumentParser()
|
| 37 |
-
parser.add_argument('--
|
| 38 |
parser.add_argument('--model_name', type=str, default='realsr', help='name of the pretrained model to be used')
|
| 39 |
parser.add_argument('--pretrained_path', type=str, default='', help='path to a model state dict to be used')
|
| 40 |
-
parser.add_argument('--
|
| 41 |
parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
|
| 42 |
parser.add_argument("--process_size", type=int, default=512)
|
| 43 |
parser.add_argument("--upscale", type=int, default=4)
|
|
@@ -54,10 +54,10 @@ if __name__ == "__main__":
|
|
| 54 |
model = GDPOSRTest(args)
|
| 55 |
model.set_eval()
|
| 56 |
|
| 57 |
-
if os.path.isdir(args.
|
| 58 |
-
image_names = sorted(glob.glob(f'{args.
|
| 59 |
else:
|
| 60 |
-
image_names = [args.
|
| 61 |
|
| 62 |
print("=== use ram ===")
|
| 63 |
model_vlm = ram(pretrained='./ckp/ram_swin_large_14m.pth',
|
|
@@ -68,7 +68,7 @@ if __name__ == "__main__":
|
|
| 68 |
model_vlm.to("cuda")
|
| 69 |
|
| 70 |
# make the output dir
|
| 71 |
-
os.makedirs(args.
|
| 72 |
print(f'There are {len(image_names)} images.')
|
| 73 |
for image_name in image_names:
|
| 74 |
|
|
@@ -104,4 +104,4 @@ if __name__ == "__main__":
|
|
| 104 |
if resize_flag:
|
| 105 |
output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
|
| 106 |
|
| 107 |
-
output_pil.save(os.path.join(args.
|
|
|
|
| 34 |
|
| 35 |
if __name__ == "__main__":
|
| 36 |
parser = argparse.ArgumentParser()
|
| 37 |
+
parser.add_argument('--input_path', type=str, default="", help='path to the input image')
|
| 38 |
parser.add_argument('--model_name', type=str, default='realsr', help='name of the pretrained model to be used')
|
| 39 |
parser.add_argument('--pretrained_path', type=str, default='', help='path to a model state dict to be used')
|
| 40 |
+
parser.add_argument('--output_path', type=str, default='', help='the directory to save the output')
|
| 41 |
parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
|
| 42 |
parser.add_argument("--process_size", type=int, default=512)
|
| 43 |
parser.add_argument("--upscale", type=int, default=4)
|
|
|
|
| 54 |
model = GDPOSRTest(args)
|
| 55 |
model.set_eval()
|
| 56 |
|
| 57 |
+
if os.path.isdir(args.input_path):
|
| 58 |
+
image_names = sorted(glob.glob(f'{args.input_path}/*.png'))
|
| 59 |
else:
|
| 60 |
+
image_names = [args.input_path]
|
| 61 |
|
| 62 |
print("=== use ram ===")
|
| 63 |
model_vlm = ram(pretrained='./ckp/ram_swin_large_14m.pth',
|
|
|
|
| 68 |
model_vlm.to("cuda")
|
| 69 |
|
| 70 |
# make the output dir
|
| 71 |
+
os.makedirs(args.output_path, exist_ok=True)
|
| 72 |
print(f'There are {len(image_names)} images.')
|
| 73 |
for image_name in image_names:
|
| 74 |
|
|
|
|
| 104 |
if resize_flag:
|
| 105 |
output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
|
| 106 |
|
| 107 |
+
output_pil.save(os.path.join(args.output_path, bname))
|