Arrcttacsrks commited on
Commit
d6d0d80
·
verified ·
1 Parent(s): 5fd62d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -15
app.py CHANGED
@@ -1,6 +1,3 @@
1
- '''
2
- Gradio demo (almost the same code as the one used in Huggingface space)
3
- '''
4
  import os, sys
5
  import cv2
6
  import time
@@ -10,7 +7,6 @@ import torch
10
  import numpy as np
11
  from torchvision.utils import save_image
12
 
13
-
14
  # Import files from the local folder
15
  root_path = os.path.abspath('.')
16
  sys.path.append(root_path)
@@ -25,6 +21,7 @@ def auto_download_if_needed(weight_path):
25
  if not os.path.exists("pretrained"):
26
  os.makedirs("pretrained")
27
 
 
28
  if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
29
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
30
  os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")
@@ -40,45 +37,64 @@ def auto_download_if_needed(weight_path):
40
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
41
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
42
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
43
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def inference(img_path, model_name):
47
-
48
  try:
49
  weight_dtype = torch.float32
50
 
51
- # Load the model
52
  if model_name == "4xGRL":
53
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
54
  auto_download_if_needed(weight_path)
55
- generator = load_grl(weight_path, scale=4) # Directly use default way now
56
 
57
  elif model_name == "4xRRDB":
58
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
59
  auto_download_if_needed(weight_path)
60
- generator = load_rrdb(weight_path, scale=4) # Directly use default way now
61
 
62
  elif model_name == "2xRRDB":
63
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
64
  auto_download_if_needed(weight_path)
65
- generator = load_rrdb(weight_path, scale=2) # Directly use default way now
66
 
67
  elif model_name == "4xDAT":
68
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
69
  auto_download_if_needed(weight_path)
70
- generator = load_dat(weight_path, scale=4) # Directly use default way now
71
 
72
  else:
73
  raise gr.Error("We don't support such Model")
74
 
75
  generator = generator.to(dtype=weight_dtype)
76
 
77
-
78
  print("We are processing ", img_path)
79
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
80
 
81
- # In default, we will automatically use crop to match 4x size
82
  super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
83
  store_name = str(time.time()) + ".png"
84
  save_image(super_resolved_img, store_name)
@@ -88,12 +104,10 @@ def inference(img_path, model_name):
88
 
89
  return outputs
90
 
91
-
92
  except Exception as error:
93
  raise gr.Error(f"global exception: {error}")
94
 
95
 
96
-
97
  if __name__ == '__main__':
98
 
99
  MARKDOWN = \
 
 
 
 
1
  import os, sys
2
  import cv2
3
  import time
 
7
  import numpy as np
8
  from torchvision.utils import save_image
9
 
 
10
  # Import files from the local folder
11
  root_path = os.path.abspath('.')
12
  sys.path.append(root_path)
 
21
  if not os.path.exists("pretrained"):
22
  os.makedirs("pretrained")
23
 
24
+ # Tải các mô hình vào CPU
25
  if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
26
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
27
  os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")
 
37
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
38
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
39
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
40
+
41
+
42
+ def load_grl_cpu(weight_path, scale=4):
43
+ # Tải mô hình GRL vào CPU
44
+ state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
45
+ generator = load_grl(scale=scale) # Khởi tạo mô hình GRL
46
+ generator.load_state_dict(state_dict)
47
+ return generator
48
+
49
+ def load_rrdb_cpu(weight_path, scale=4):
50
+ # Tải mô hình RRDB vào CPU
51
+ state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
52
+ generator = load_rrdb(scale=scale) # Khởi tạo mô hình RRDB
53
+ generator.load_state_dict(state_dict)
54
+ return generator
55
+
56
+ def load_dat_cpu(weight_path, scale=4):
57
+ # Tải mô hình DAT vào CPU
58
+ state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
59
+ generator = load_dat(scale=scale) # Khởi tạo mô hình DAT
60
+ generator.load_state_dict(state_dict)
61
+ return generator
62
 
63
 
64
  def inference(img_path, model_name):
 
65
  try:
66
  weight_dtype = torch.float32
67
 
68
+ # Load the model based on user selection
69
  if model_name == "4xGRL":
70
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
71
  auto_download_if_needed(weight_path)
72
+ generator = load_grl_cpu(weight_path, scale=4)
73
 
74
  elif model_name == "4xRRDB":
75
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
76
  auto_download_if_needed(weight_path)
77
+ generator = load_rrdb_cpu(weight_path, scale=4)
78
 
79
  elif model_name == "2xRRDB":
80
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
81
  auto_download_if_needed(weight_path)
82
+ generator = load_rrdb_cpu(weight_path, scale=2)
83
 
84
  elif model_name == "4xDAT":
85
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
86
  auto_download_if_needed(weight_path)
87
+ generator = load_dat_cpu(weight_path, scale=4)
88
 
89
  else:
90
  raise gr.Error("We don't support such Model")
91
 
92
  generator = generator.to(dtype=weight_dtype)
93
 
 
94
  print("We are processing ", img_path)
95
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
96
 
97
+ # Run super-resolution and save result
98
  super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
99
  store_name = str(time.time()) + ".png"
100
  save_image(super_resolved_img, store_name)
 
104
 
105
  return outputs
106
 
 
107
  except Exception as error:
108
  raise gr.Error(f"global exception: {error}")
109
 
110
 
 
111
  if __name__ == '__main__':
112
 
113
  MARKDOWN = \