Arrcttacsrks committed on
Commit
6c29300
·
verified ·
1 Parent(s): 3303b0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -87
app.py CHANGED
@@ -13,105 +13,91 @@ sys.path.append(root_path)
13
  from test_code.inference import super_resolve_img
14
  from test_code.test_utils import load_grl, load_rrdb, load_dat
15
 
16
-
17
  def auto_download_if_needed(weight_path):
18
  if os.path.exists(weight_path):
19
  return
20
-
21
  if not os.path.exists("pretrained"):
22
  os.makedirs("pretrained")
23
-
24
- # Download pretrained weights based on the model type
25
- model_weights = {
26
- "pretrained/4x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
27
- "pretrained/4x_APISR_GRL_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
28
- "pretrained/2x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
29
- "pretrained/4x_APISR_DAT_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth"
30
  }
31
-
32
- if weight_path in model_weights:
33
- os.system(f"wget {model_weights[weight_path]} -P pretrained")
34
-
35
-
36
- # Define functions to load models into CPU if no GPU is available
37
- def load_grl_cpu(weight_path, scale=4):
38
- state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
39
- generator = load_grl(generator_weight_PATH=weight_path, scale=scale)
40
- generator.load_state_dict(state_dict)
41
- return generator
42
-
43
- def load_rrdb_cpu(weight_path, scale=4):
44
- state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
45
- generator = load_rrdb(generator_weight_PATH=weight_path, scale=scale)
46
- generator.load_state_dict(state_dict)
47
- return generator
48
-
49
- def load_dat_cpu(weight_path, scale=4):
50
- state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
51
- generator = load_dat(generator_weight_PATH=weight_path, scale=scale)
52
- generator.load_state_dict(state_dict)
53
- return generator
54
-
55
 
56
  def inference(img_path, model_name):
57
  try:
 
 
58
  weight_dtype = torch.float32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Load the model based on the selected model_name
61
- if model_name == "4xGRL":
62
- weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
63
- auto_download_if_needed(weight_path)
64
- generator = load_grl_cpu(weight_path, scale=4)
65
-
66
- elif model_name == "4xRRDB":
67
- weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
68
- auto_download_if_needed(weight_path)
69
- generator = load_rrdb_cpu(weight_path, scale=4)
70
-
71
- elif model_name == "2xRRDB":
72
- weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
73
- auto_download_if_needed(weight_path)
74
- generator = load_rrdb_cpu(weight_path, scale=2)
75
-
76
- elif model_name == "4xDAT":
77
- weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
78
- auto_download_if_needed(weight_path)
79
- generator = load_dat_cpu(weight_path, scale=4)
80
-
81
- else:
82
- raise gr.Error("We don't support such Model")
83
-
84
- generator = generator.to(dtype=weight_dtype)
85
-
86
- print("We are processing ", img_path)
87
- print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
88
 
89
- # Super-resolve the image
90
  super_resolved_img = super_resolve_img(
91
- generator, img_path, output_path=None,
92
- weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True
 
 
 
 
93
  )
94
-
95
- # Save and display the output
96
- store_name = str(time.time()) + ".png"
97
  save_image(super_resolved_img, store_name)
98
  outputs = cv2.imread(store_name)
99
  outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
100
  os.remove(store_name)
101
 
102
  return outputs
103
-
104
  except Exception as error:
105
- raise gr.Error(f"global exception: {error}")
106
-
107
 
108
  if __name__ == '__main__':
109
  MARKDOWN = """
110
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
 
111
  [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
 
112
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
113
- ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio.
114
- ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks!
 
 
 
115
  """
116
 
117
  block = gr.Blocks().queue(max_size=10)
@@ -125,7 +111,7 @@ if __name__ == '__main__':
125
  ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
126
  type="value",
127
  value="4xGRL",
128
- label="model"
129
  )
130
  run_btn = gr.Button(value="Submit")
131
 
@@ -133,20 +119,17 @@ if __name__ == '__main__':
133
  output_image = gr.Image(type="numpy", label="Output image")
134
 
135
  with gr.Row(elem_classes=["container"]):
136
- gr.Examples(
137
- [
138
- ["__assets__/lr_inputs/image-00277.png"],
139
- ["__assets__/lr_inputs/image-00542.png"],
140
- ["__assets__/lr_inputs/41.png"],
141
- ["__assets__/lr_inputs/f91.jpg"],
142
- ["__assets__/lr_inputs/image-00440.png"],
143
- ["__assets__/lr_inputs/image-00164.jpg"],
144
- ["__assets__/lr_inputs/img_eva.jpeg"],
145
- ["__assets__/lr_inputs/naruto.jpg"],
146
- ],
147
- [input_image],
148
- )
149
 
150
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
151
 
152
- block.launch()
 
13
  from test_code.inference import super_resolve_img
14
  from test_code.test_utils import load_grl, load_rrdb, load_dat
15
 
 
16
def auto_download_if_needed(weight_path):
    """Download a pretrained generator weight into ``pretrained/`` if missing.

    Args:
        weight_path: Local path of the weight file, e.g.
            ``"pretrained/4x_APISR_GRL_GAN_generator.pth"``.

    No-op when the file already exists, or when ``weight_path`` is not one of
    the known APISR release weights.
    """
    if os.path.exists(weight_path):
        return

    if not os.path.exists("pretrained"):
        os.makedirs("pretrained")

    # GitHub release tag + filename for each supported weight file.
    weight_mappings = {
        "pretrained/4x_APISR_RRDB_GAN_generator.pth": "v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
        "pretrained/4x_APISR_GRL_GAN_generator.pth": "v0.1.0/4x_APISR_GRL_GAN_generator.pth",
        "pretrained/2x_APISR_RRDB_GAN_generator.pth": "v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
        "pretrained/4x_APISR_DAT_GAN_generator.pth": "v0.3.0/4x_APISR_DAT_GAN_generator.pth",
    }

    if weight_path in weight_mappings:
        url = f"https://github.com/Kiteretsu77/APISR/releases/download/{weight_mappings[weight_path]}"
        # Fix: download directly into pretrained/ via `wget -P`.  The previous
        # version downloaded into the CWD and then ran
        # `os.system("mv (unknown) pretrained")`, which tried to move a literal
        # file named "(unknown)" — the weight never reached pretrained/.
        # NOTE(review): os.system + wget kept for consistency with the original
        # code; subprocess.run(["wget", url, "-P", "pretrained"], check=True)
        # would be safer.
        os.system(f"wget {url} -P pretrained")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
def inference(img_path, model_name):
    """Super-resolve a single anime image with the selected APISR generator.

    Args:
        img_path: Path to the input low-resolution image.
        model_name: One of "4xGRL", "4xRRDB", "2xRRDB", "4xDAT".

    Returns:
        The super-resolved image as a numpy array (re-read through cv2).

    Raises:
        gr.Error: For an unsupported model name or any processing failure.
    """
    try:
        # Use GPU if available, otherwise fall back to CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        weight_dtype = torch.float32

        # model name -> (weight file, loader function, upscale factor)
        model_configs = {
            "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth", load_grl, 4),
            "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth", load_rrdb, 4),
            "2xRRDB": ("pretrained/2x_APISR_RRDB_GAN_generator.pth", load_rrdb, 2),
            "4xDAT": ("pretrained/4x_APISR_DAT_GAN_generator.pth", load_dat, 4),
        }

        if model_name not in model_configs:
            raise gr.Error("Unsupported model selected")

        weight_path, loader_func, scale = model_configs[model_name]
        auto_download_if_needed(weight_path)

        # Load model with explicit device mapping.
        # NOTE(review): assumes the load_* helpers accept a `map_location`
        # keyword — confirm against test_code/test_utils.py; the pre-commit
        # code called them as load_*(generator_weight_PATH=..., scale=...).
        generator = loader_func(
            weight_path,
            scale=scale,
            map_location=device,
        )
        generator = generator.to(device=device, dtype=weight_dtype)

        print(f"Processing {img_path} on {device}")
        print(f"Current time: {datetime.datetime.now(pytz.timezone('US/Eastern'))}")

        # Run super-resolution; per the app notes, inputs whose short side
        # exceeds 720 px are downsampled first (downsample_threshold).
        super_resolved_img = super_resolve_img(
            generator,
            img_path,
            output_path=None,
            weight_dtype=weight_dtype,
            downsample_threshold=720,
            crop_for_4x=True,
        )

        # Round-trip through a uniquely named temp PNG so the returned array
        # matches what cv2 decodes; the file is removed afterwards.
        store_name = f"output_{time.time()}.png"
        save_image(super_resolved_img, store_name)
        outputs = cv2.imread(store_name)
        outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
        os.remove(store_name)

        return outputs

    except gr.Error:
        # Fix: let deliberate gr.Error messages (e.g. the unsupported-model
        # check above) reach the UI unwrapped instead of being re-wrapped
        # as "Error during processing: ..." by the handler below.
        raise
    except Exception as error:
        # Chain the cause so the original traceback survives in the logs.
        raise gr.Error(f"Error during processing: {error}") from error
 
88
 
89
  if __name__ == '__main__':
90
  MARKDOWN = """
91
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
92
+
93
  [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
94
+
95
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
96
+
97
+ ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
98
+ ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
99
+
100
+ ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
101
  """
102
 
103
  block = gr.Blocks().queue(max_size=10)
 
111
  ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
112
  type="value",
113
  value="4xGRL",
114
+ label="Model"
115
  )
116
  run_btn = gr.Button(value="Submit")
117
 
 
119
  output_image = gr.Image(type="numpy", label="Output image")
120
 
121
  with gr.Row(elem_classes=["container"]):
122
+ gr.Examples([
123
+ ["__assets__/lr_inputs/image-00277.png"],
124
+ ["__assets__/lr_inputs/image-00542.png"],
125
+ ["__assets__/lr_inputs/41.png"],
126
+ ["__assets__/lr_inputs/f91.jpg"],
127
+ ["__assets__/lr_inputs/image-00440.png"],
128
+ ["__assets__/lr_inputs/image-00164.jpg"],
129
+ ["__assets__/lr_inputs/img_eva.jpeg"],
130
+ ["__assets__/lr_inputs/naruto.jpg"],
131
+ ], [input_image])
 
 
 
132
 
133
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
134
 
135
+ block.launch()