Arrcttacsrks commited on
Commit
3303b0f
·
verified ·
1 Parent(s): d6d0d80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -51
app.py CHANGED
@@ -17,46 +17,38 @@ from test_code.test_utils import load_grl, load_rrdb, load_dat
17
  def auto_download_if_needed(weight_path):
18
  if os.path.exists(weight_path):
19
  return
20
-
21
  if not os.path.exists("pretrained"):
22
  os.makedirs("pretrained")
23
-
24
- # Tải các hình vào CPU
25
- if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
26
- os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
27
- os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")
28
-
29
- if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
30
- os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
31
- os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
32
-
33
- if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
34
- os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
35
- os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")
36
-
37
- if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
38
- os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
39
- os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
40
 
41
 
 
42
  def load_grl_cpu(weight_path, scale=4):
43
- # Tải mô hình GRL vào CPU
44
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
45
- generator = load_grl(scale=scale) # Khởi tạo mô hình GRL
46
  generator.load_state_dict(state_dict)
47
  return generator
48
 
49
  def load_rrdb_cpu(weight_path, scale=4):
50
- # Tải mô hình RRDB vào CPU
51
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
52
- generator = load_rrdb(scale=scale) # Khởi tạo mô hình RRDB
53
  generator.load_state_dict(state_dict)
54
  return generator
55
 
56
  def load_dat_cpu(weight_path, scale=4):
57
- # Tải mô hình DAT vào CPU
58
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
59
- generator = load_dat(scale=scale) # Khởi tạo mô hình DAT
60
  generator.load_state_dict(state_dict)
61
  return generator
62
 
@@ -64,38 +56,43 @@ def load_dat_cpu(weight_path, scale=4):
64
  def inference(img_path, model_name):
65
  try:
66
  weight_dtype = torch.float32
67
-
68
- # Load the model based on user selection
69
  if model_name == "4xGRL":
70
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
71
  auto_download_if_needed(weight_path)
72
  generator = load_grl_cpu(weight_path, scale=4)
73
-
74
  elif model_name == "4xRRDB":
75
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
76
  auto_download_if_needed(weight_path)
77
  generator = load_rrdb_cpu(weight_path, scale=4)
78
-
79
  elif model_name == "2xRRDB":
80
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
81
  auto_download_if_needed(weight_path)
82
  generator = load_rrdb_cpu(weight_path, scale=2)
83
-
84
  elif model_name == "4xDAT":
85
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
86
  auto_download_if_needed(weight_path)
87
  generator = load_dat_cpu(weight_path, scale=4)
88
-
89
  else:
90
  raise gr.Error("We don't support such Model")
91
-
92
  generator = generator.to(dtype=weight_dtype)
93
 
94
  print("We are processing ", img_path)
95
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
96
 
97
- # Run super-resolution and save result
98
- super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
 
 
 
 
 
99
  store_name = str(time.time()) + ".png"
100
  save_image(super_resolved_img, store_name)
101
  outputs = cv2.imread(store_name)
@@ -103,25 +100,18 @@ def inference(img_path, model_name):
103
  os.remove(store_name)
104
 
105
  return outputs
106
-
107
  except Exception as error:
108
  raise gr.Error(f"global exception: {error}")
109
 
110
 
111
  if __name__ == '__main__':
112
-
113
- MARKDOWN = \
114
- """
115
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
116
-
117
  [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
118
-
119
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
120
-
121
- ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
122
- ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
123
-
124
- ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
125
  """
126
 
127
  block = gr.Blocks().queue(max_size=10)
@@ -132,15 +122,10 @@ if __name__ == '__main__':
132
  with gr.Column(scale=2):
133
  input_image = gr.Image(type="filepath", label="Input")
134
  model_name = gr.Dropdown(
135
- [
136
- "2xRRDB",
137
- "4xRRDB",
138
- "4xGRL",
139
- "4xDAT",
140
- ],
141
  type="value",
142
  value="4xGRL",
143
- label="model",
144
  )
145
  run_btn = gr.Button(value="Submit")
146
 
 
17
  def auto_download_if_needed(weight_path):
18
  if os.path.exists(weight_path):
19
  return
20
+
21
  if not os.path.exists("pretrained"):
22
  os.makedirs("pretrained")
23
+
24
+ # Download pretrained weights based on the model type
25
+ model_weights = {
26
+ "pretrained/4x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
27
+ "pretrained/4x_APISR_GRL_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
28
+ "pretrained/2x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
29
+ "pretrained/4x_APISR_DAT_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth"
30
+ }
31
+
32
+ if weight_path in model_weights:
33
+ os.system(f"wget {model_weights[weight_path]} -P pretrained")
 
 
 
 
 
 
34
 
35
 
36
+ # Define functions to load models into CPU if no GPU is available
37
  def load_grl_cpu(weight_path, scale=4):
 
38
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
39
+ generator = load_grl(generator_weight_PATH=weight_path, scale=scale)
40
  generator.load_state_dict(state_dict)
41
  return generator
42
 
43
  def load_rrdb_cpu(weight_path, scale=4):
 
44
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
45
+ generator = load_rrdb(generator_weight_PATH=weight_path, scale=scale)
46
  generator.load_state_dict(state_dict)
47
  return generator
48
 
49
  def load_dat_cpu(weight_path, scale=4):
 
50
  state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
51
+ generator = load_dat(generator_weight_PATH=weight_path, scale=scale)
52
  generator.load_state_dict(state_dict)
53
  return generator
54
 
 
56
  def inference(img_path, model_name):
57
  try:
58
  weight_dtype = torch.float32
59
+
60
+ # Load the model based on the selected model_name
61
  if model_name == "4xGRL":
62
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
63
  auto_download_if_needed(weight_path)
64
  generator = load_grl_cpu(weight_path, scale=4)
65
+
66
  elif model_name == "4xRRDB":
67
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
68
  auto_download_if_needed(weight_path)
69
  generator = load_rrdb_cpu(weight_path, scale=4)
70
+
71
  elif model_name == "2xRRDB":
72
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
73
  auto_download_if_needed(weight_path)
74
  generator = load_rrdb_cpu(weight_path, scale=2)
75
+
76
  elif model_name == "4xDAT":
77
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
78
  auto_download_if_needed(weight_path)
79
  generator = load_dat_cpu(weight_path, scale=4)
80
+
81
  else:
82
  raise gr.Error("We don't support such Model")
83
+
84
  generator = generator.to(dtype=weight_dtype)
85
 
86
  print("We are processing ", img_path)
87
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
88
 
89
+ # Super-resolve the image
90
+ super_resolved_img = super_resolve_img(
91
+ generator, img_path, output_path=None,
92
+ weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True
93
+ )
94
+
95
+ # Save and display the output
96
  store_name = str(time.time()) + ".png"
97
  save_image(super_resolved_img, store_name)
98
  outputs = cv2.imread(store_name)
 
100
  os.remove(store_name)
101
 
102
  return outputs
103
+
104
  except Exception as error:
105
  raise gr.Error(f"global exception: {error}")
106
 
107
 
108
  if __name__ == '__main__':
109
+ MARKDOWN = """
 
 
110
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
 
111
  [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
 
112
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
113
+ ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio.
114
+ ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks!
 
 
 
115
  """
116
 
117
  block = gr.Blocks().queue(max_size=10)
 
122
  with gr.Column(scale=2):
123
  input_image = gr.Image(type="filepath", label="Input")
124
  model_name = gr.Dropdown(
125
+ ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
 
 
 
 
 
126
  type="value",
127
  value="4xGRL",
128
+ label="model"
129
  )
130
  run_btn = gr.Button(value="Submit")
131