Arrcttacsrks commited on
Commit
f197031
·
verified ·
1 Parent(s): c8b34b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -42,41 +42,36 @@ def auto_download_if_needed(weight_path):
42
  def inference(img_path, model_name):
43
 
44
  try:
45
- weight_dtype = torch.float32
46
-
47
  # Load the model
48
  if model_name == "4xGRL":
49
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
50
  auto_download_if_needed(weight_path)
51
- generator = load_grl(weight_path, scale=4)
52
 
53
  elif model_name == "4xRRDB":
54
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
55
  auto_download_if_needed(weight_path)
56
- generator = load_rrdb(weight_path, scale=4)
57
 
58
  elif model_name == "2xRRDB":
59
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
60
  auto_download_if_needed(weight_path)
61
- generator = load_rrdb(weight_path, scale=2)
62
 
63
  elif model_name == "4xDAT":
64
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
65
  auto_download_if_needed(weight_path)
66
- generator = load_dat(weight_path, scale=4)
67
 
68
  else:
69
  raise gr.Error("We don't support such Model")
70
-
71
- # Move the model to the CPU
72
- generator = generator.to(device='cpu')
73
 
74
 
75
  print("We are processing ", img_path)
76
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
77
 
78
  # In default, we will automatically use crop to match 4x size
79
- super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
80
  store_name = str(time.time()) + ".png"
81
  save_image(super_resolved_img, store_name)
82
  outputs = cv2.imread(store_name)
 
42
  def inference(img_path, model_name):
43
 
44
  try:
 
 
45
  # Load the model
46
  if model_name == "4xGRL":
47
  weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
48
  auto_download_if_needed(weight_path)
49
+ generator = load_grl(weight_path, scale=4, device='cpu')
50
 
51
  elif model_name == "4xRRDB":
52
  weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
53
  auto_download_if_needed(weight_path)
54
+ generator = load_rrdb(weight_path, scale=4, device='cpu')
55
 
56
  elif model_name == "2xRRDB":
57
  weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
58
  auto_download_if_needed(weight_path)
59
+ generator = load_rrdb(weight_path, scale=2, device='cpu')
60
 
61
  elif model_name == "4xDAT":
62
  weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
63
  auto_download_if_needed(weight_path)
64
+ generator = load_dat(weight_path, scale=4, device='cpu')
65
 
66
  else:
67
  raise gr.Error("We don't support such Model")
 
 
 
68
 
69
 
70
  print("We are processing ", img_path)
71
  print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
72
 
73
  # In default, we will automatically use crop to match 4x size
74
+ super_resolved_img = super_resolve_img(generator, img_path, output_path=None, downsample_threshold=720, crop_for_4x=True)
75
  store_name = str(time.time()) + ".png"
76
  save_image(super_resolved_img, store_name)
77
  outputs = cv2.imread(store_name)