rastof9 commited on
Commit
1644102
·
verified ·
1 Parent(s): 721f91b

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +20 -26
generate.py CHANGED
@@ -1,7 +1,7 @@
1
  # generate.py
2
- # --- VERSION 12 (Final Fix) ---
3
 
4
- print("--- RUNNING GENERATE.PY VERSION 12 (Final Fix) ---")
5
 
6
  # --- MONKEY-PATCH FOR OLD TORCHVISION ---
7
  try:
@@ -26,9 +26,9 @@ from insightface.app import FaceAnalysis
26
  from insightface.utils import face_align
27
  from huggingface_hub import hf_hub_download
28
  from storage3.utils import StorageException
29
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
 
30
  from gfpgan import GFPGANer
31
- from basicsr.utils.download_util import load_file_from_url
32
 
33
  import config
34
  import utils
@@ -70,19 +70,20 @@ class GenerationService:
70
  vae=vae, feature_extractor=None, safety_checker=None
71
  ).to(self.device)
72
 
73
- # --- Upscaler Model ---
74
  logger.info("Loading Real-ESRGAN upscaler model...")
75
- # FIX: The downloaded file has uppercase letters. Match the filename exactly.
76
- model_name = 'RealESRGAN_x4plus.pth'
77
- model_path = os.path.join('weights', model_name)
78
-
79
- if not os.path.exists(model_path):
80
- model_url = f'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/{model_name}'
81
- load_file_from_url(url=model_url, model_dir=os.path.join('weights'), progress=True)
82
-
83
- self.upsampler = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
84
- self.upsampler.load_state_dict(torch.load(model_path)['params_ema'])
85
- self.upsampler.to(self.device)
 
86
  logger.info("Upscaler model loaded.")
87
 
88
  logger.info("All models loaded successfully.")
@@ -95,16 +96,9 @@ class GenerationService:
95
  """Upscales an image using Real-ESRGAN."""
96
  try:
97
  img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
98
- img = img.astype('float32') / 255.
99
- img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
100
-
101
- with torch.no_grad():
102
- output = self.upsampler(img)
103
-
104
- output_img = output.squeeze().permute(1, 2, 0).cpu().numpy()
105
- output_img = (output_img * 255.0).round().astype('uint8')
106
-
107
- cv2.imwrite(image_path, output_img)
108
  logger.info(f"Successfully upscaled image: {image_path}")
109
  return image_path
110
  except Exception as e:
 
1
  # generate.py
2
+ # --- VERSION 13 (Correct Upscaler Loading) ---
3
 
4
+ print("--- RUNNING GENERATE.PY VERSION 13 (Correct Upscaler Loading) ---")
5
 
6
  # --- MONKEY-PATCH FOR OLD TORCHVISION ---
7
  try:
 
26
  from insightface.utils import face_align
27
  from huggingface_hub import hf_hub_download
28
  from storage3.utils import StorageException
29
+ from realesrgan import RealESRGANer # <-- IMPORT THE CORRECT CLASS
30
+ from basicsr.archs.rrdbnet_arch import RRDBNet
31
  from gfpgan import GFPGANer
 
32
 
33
  import config
34
  import utils
 
70
  vae=vae, feature_extractor=None, safety_checker=None
71
  ).to(self.device)
72
 
73
+ # --- CORRECTED UPSCALER LOADING ---
74
  logger.info("Loading Real-ESRGAN upscaler model...")
75
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
76
+ self.upsampler = RealESRGANer(
77
+ scale=4,
78
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
79
+ dni_weight=None,
80
+ model=model,
81
+ tile=0,
82
+ tile_pad=10,
83
+ pre_pad=0,
84
+ half=True if self.torch_dtype == torch.float16 else False,
85
+ gpu_id=0 if self.device == "cuda" else None
86
+ )
87
  logger.info("Upscaler model loaded.")
88
 
89
  logger.info("All models loaded successfully.")
 
96
  """Upscales an image using Real-ESRGAN."""
97
  try:
98
  img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
99
+ # The enhance method returns the upscaled image and its type
100
+ output, _ = self.upsampler.enhance(img, outscale=4)
101
+ cv2.imwrite(image_path, output)
 
 
 
 
 
 
 
102
  logger.info(f"Successfully upscaled image: {image_path}")
103
  return image_path
104
  except Exception as e: