mastari commited on
Commit
3e3c736
·
1 Parent(s): 34e750a
Files changed (2) hide show
  1. handler.py +18 -17
  2. requirements.txt +2 -2
handler.py CHANGED
@@ -15,24 +15,28 @@ logger = logging.getLogger(__name__)
15
 
16
  class EndpointHandler:
17
  def __init__(self, path="."):
18
- logger.info("🚀 [INIT] Starting GFPGAN + Real-ESRGAN handler...")
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
  self.half = self.device == "cuda"
21
  self.path = path
22
 
23
- # Model download URLs
24
- self.gfpgan_model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
25
- self.realesr_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
 
 
 
 
26
 
27
  # Local paths
28
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
29
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
30
 
31
- # Lazy init placeholders
32
  self.bg_upsampler = None
33
  self.restorer = None
34
 
35
- # Ensure models exist
36
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
37
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
38
 
@@ -48,11 +52,12 @@ class EndpointHandler:
48
  logger.info(f"✅ Saved to {local_path}")
49
 
50
  def _init_models(self):
51
- """Lazy-load GFPGAN and RealESRGAN models only when needed."""
52
  if self.bg_upsampler is None:
53
- logger.info("🧩 Initializing Real-ESRGAN background upsampler...")
54
  model = SRVGGNetCompact(
55
- num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
 
56
  )
57
  self.bg_upsampler = RealESRGANer(
58
  scale=4,
@@ -77,14 +82,13 @@ class EndpointHandler:
77
  logger.info("✅ Models ready!")
78
 
79
  def __call__(self, data):
80
- """Run restoration on an input image (bytes or URL)."""
81
  self._init_models()
82
 
83
- image = None
84
  if isinstance(data, dict) and "inputs" in data:
85
  data = data["inputs"]
86
 
87
- # Accept bytes, PIL, URL, or base64 (Gradio style)
88
  if isinstance(data, (bytes, bytearray)):
89
  image = Image.open(io.BytesIO(data)).convert("RGB")
90
  elif isinstance(data, str) and data.startswith("http"):
@@ -96,14 +100,11 @@ class EndpointHandler:
96
  raise ValueError("Unsupported input type")
97
 
98
  input_img = np.array(image, dtype=np.uint8)
 
99
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
100
- input_img,
101
- has_aligned=False,
102
- only_center_face=False,
103
- paste_back=True
104
  )
105
 
106
- # Convert restored image to bytes for output
107
  _, buffer = cv2.imencode(".jpg", restored_img)
108
  output_bytes = io.BytesIO(buffer.tobytes())
109
 
 
15
 
16
  class EndpointHandler:
17
  def __init__(self, path="."):
18
+ logger.info("🚀 [INIT] GFPGAN + Real-ESRGAN handler starting...")
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
  self.half = self.device == "cuda"
21
  self.path = path
22
 
23
+ # Model URLs
24
+ self.gfpgan_model_url = (
25
+ "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
26
+ )
27
+ self.realesr_model_url = (
28
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
29
+ )
30
 
31
  # Local paths
32
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
33
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
34
 
35
+ # Lazy init
36
  self.bg_upsampler = None
37
  self.restorer = None
38
 
39
+ # Ensure model files exist
40
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
41
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
42
 
 
52
  logger.info(f"✅ Saved to {local_path}")
53
 
54
  def _init_models(self):
55
+ """Lazy-load GFPGAN and Real-ESRGAN models."""
56
  if self.bg_upsampler is None:
57
+ logger.info("🧩 Initializing Real-ESRGAN upsampler...")
58
  model = SRVGGNetCompact(
59
+ num_in_ch=3, num_out_ch=3, num_feat=64,
60
+ num_conv=32, upscale=4, act_type="prelu"
61
  )
62
  self.bg_upsampler = RealESRGANer(
63
  scale=4,
 
82
  logger.info("✅ Models ready!")
83
 
84
  def __call__(self, data):
85
+ """Restore a face photo."""
86
  self._init_models()
87
 
 
88
  if isinstance(data, dict) and "inputs" in data:
89
  data = data["inputs"]
90
 
91
+ # Load image
92
  if isinstance(data, (bytes, bytearray)):
93
  image = Image.open(io.BytesIO(data)).convert("RGB")
94
  elif isinstance(data, str) and data.startswith("http"):
 
100
  raise ValueError("Unsupported input type")
101
 
102
  input_img = np.array(image, dtype=np.uint8)
103
+
104
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
105
+ input_img, has_aligned=False, only_center_face=False, paste_back=True
 
 
 
106
  )
107
 
 
108
  _, buffer = cv2.imencode(".jpg", restored_img)
109
  output_bytes = io.BytesIO(buffer.tobytes())
110
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- torch>=2.1
2
- torchvision>=0.16
3
  gfpgan==1.3.8
4
  realesrgan==0.3.0
5
  basicsr==1.4.2
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
  gfpgan==1.3.8
4
  realesrgan==0.3.0
5
  basicsr==1.4.2