Arrcttacsrks commited on
Commit
c8b34b9
·
verified ·
1 Parent(s): 965d342

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -1,6 +1,3 @@
1
- '''
2
- Gradio demo (almost the same code as the one used in Huggingface space)
3
- '''
4
  import os, sys
5
  import cv2
6
  import time
@@ -40,7 +37,6 @@ def auto_download_if_needed(weight_path):
40
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
41
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
42
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
43
-
44
 
45
 
46
  def inference(img_path, model_name):
@@ -52,27 +48,28 @@ def inference(img_path, model_name):
52
  if model_name == "4xGRL":
53
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
54
  auto_download_if_needed(weight_path)
55
- generator = load_grl(weight_path, scale=4) # Directly use default way now
56
 
57
  elif model_name == "4xRRDB":
58
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
59
  auto_download_if_needed(weight_path)
60
- generator = load_rrdb(weight_path, scale=4) # Directly use default way now
61
 
62
  elif model_name == "2xRRDB":
63
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
64
  auto_download_if_needed(weight_path)
65
- generator = load_rrdb(weight_path, scale=2) # Directly use default way now
66
 
67
  elif model_name == "4xDAT":
68
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
69
  auto_download_if_needed(weight_path)
70
- generator = load_dat(weight_path, scale=4) # Directly use default way now
71
 
72
  else:
73
  raise gr.Error("We don't support such Model")
74
 
75
- generator = generator.to(dtype=weight_dtype)
 
76
 
77
 
78
  print("We are processing ", img_path)
@@ -150,4 +147,4 @@ if __name__ == '__main__':
150
 
151
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
152
 
153
- block.launch()
 
 
 
 
1
  import os, sys
2
  import cv2
3
  import time
 
37
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
38
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
39
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
 
40
 
41
 
42
  def inference(img_path, model_name):
 
48
  if model_name == "4xGRL":
49
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
50
  auto_download_if_needed(weight_path)
51
+ generator = load_grl(weight_path, scale=4)
52
 
53
  elif model_name == "4xRRDB":
54
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
55
  auto_download_if_needed(weight_path)
56
+ generator = load_rrdb(weight_path, scale=4)
57
 
58
  elif model_name == "2xRRDB":
59
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
60
  auto_download_if_needed(weight_path)
61
+ generator = load_rrdb(weight_path, scale=2)
62
 
63
  elif model_name == "4xDAT":
64
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
65
  auto_download_if_needed(weight_path)
66
+ generator = load_dat(weight_path, scale=4)
67
 
68
  else:
69
  raise gr.Error("We don't support such Model")
70
 
71
+ # Move the model to the CPU
72
+ generator = generator.to(device='cpu')
73
 
74
 
75
  print("We are processing ", img_path)
 
147
 
148
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
149
 
150
+ block.launch()