mastari commited on
Commit
7df3ea4
Β·
1 Parent(s): e26840f
Files changed (1) hide show
  1. handler.py +163 -112
handler.py CHANGED
@@ -1,138 +1,189 @@
1
- import os
2
- import io
3
- import torch
4
  import base64
 
 
 
 
 
 
 
5
  import requests
6
  from PIL import Image
 
7
  from gfpgan import GFPGANer
8
  from realesrgan import RealESRGANer
9
  from basicsr.archs.rrdbnet_arch import RRDBNet
10
 
11
-
12
- class EndpointHandler:
13
- def __init__(self, path="."):
14
- print("πŸš€ [INIT] Starting GFPGAN + RealESRGAN hybrid handler initialization...")
15
- print(f"πŸ“‚ Working directory: {os.getcwd()}")
16
- print(f"πŸ“ Handler path argument: {path}")
17
-
18
- # ------------------------------
19
- # Download GFPGAN v1.4 weights
20
- # ------------------------------
21
- self.model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
22
- self.model_path = os.path.join(path, "GFPGANv1.4.pth")
23
- if not os.path.exists(self.model_path):
24
- print(f"πŸ“₯ [DOWNLOAD] Fetching GFPGAN v1.4 weights...")
25
- r = requests.get(self.model_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  r.raise_for_status()
27
- with open(self.model_path, "wb") as f:
28
  f.write(r.content)
29
- print(f"βœ… [MODEL] Downloaded GFPGAN model to {self.model_path}")
30
- else:
31
- print("βœ… [MODEL] GFPGAN weights already exist locally.")
32
-
33
- # ------------------------------
34
- # Setup background upsampler (Real-ESRGAN)
35
- # ------------------------------
36
- print("🧠 [INIT] Setting up Real-ESRGAN background upsampler...")
37
- rrdbnet = RRDBNet(
38
- num_in_ch=3, num_out_ch=3,
39
- num_feat=64, num_block=23,
40
- num_grow_ch=32, scale=4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
 
43
  self.bg_upsampler = RealESRGANer(
44
- scale=2,
45
- model_path=None, # auto-download model weights
46
- model=rrdbnet,
47
- tile=400,
48
  tile_pad=10,
49
  pre_pad=0,
50
- half=False,
51
- device="cuda" if torch.cuda.is_available() else "cpu",
52
  )
53
- print("βœ… [INIT] Real-ESRGAN background upsampler ready.")
54
 
55
- # ------------------------------
56
- # Setup GFPGANer
57
- # ------------------------------
58
- print("🧠 [INIT] Setting up GFPGANer (v1.4)...")
59
  self.restorer = GFPGANer(
60
- model_path=self.model_path,
61
  upscale=2,
62
  arch="clean",
63
  channel_multiplier=2,
64
  bg_upsampler=self.bg_upsampler,
65
- device="cuda" if torch.cuda.is_available() else "cpu",
66
  )
67
- print("βœ… [INIT DONE] GFPGAN + RealESRGAN hybrid handler ready.")
68
-
69
- # ----------------------------------------------------------
70
- # Main inference entry point
71
- # ----------------------------------------------------------
72
- def __call__(self, data):
73
- print("πŸ›°οΈ [CALL] Endpoint invoked!")
74
- print(f"πŸ“¦ [CALL] Raw input type: {type(data)}")
75
 
 
 
 
 
76
  try:
77
- image = self.preprocess(data)
78
- print("🧩 [STEP] Image preprocessed successfully.")
79
- restored = self.inference(image)
80
- print("🎨 [STEP] Inference completed successfully.")
81
- return self.postprocess(restored)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  except Exception as e:
83
- print(f"πŸ’₯ [ERROR] Exception during call: {str(e)}")
84
- return {"error": str(e)}
85
-
86
- # ----------------------------------------------------------
87
- # Preprocessing
88
- # ----------------------------------------------------------
89
- def preprocess(self, data):
90
- print("πŸ”§ [PREPROCESS] Starting...")
91
- if isinstance(data, (bytes, bytearray)):
92
- print("πŸ–ΌοΈ [PREPROCESS] Raw bytes detected.")
93
- return Image.open(io.BytesIO(data)).convert("RGB")
94
-
95
- if isinstance(data, dict):
96
- img_field = data.get("inputs") or data.get("image")
97
- if isinstance(img_field, str):
98
- print("🧬 [PREPROCESS] Base64 string detected.")
99
- decoded = base64.b64decode(img_field)
100
- return Image.open(io.BytesIO(decoded)).convert("RGB")
101
- if isinstance(img_field, (bytes, bytearray)):
102
- print("🧩 [PREPROCESS] Byte array detected.")
103
- return Image.open(io.BytesIO(img_field)).convert("RGB")
104
-
105
- raise ValueError("Unsupported input format β€” expected bytes or base64 data.")
106
-
107
- # ----------------------------------------------------------
108
- # Inference
109
- # ----------------------------------------------------------
110
- def inference(self, image):
111
- print("βš™οΈ [INFERENCE] Running GFPGAN + RealESRGAN enhancement...")
112
- cropped_faces, restored_faces, restored_img = self.restorer.enhance(
113
- image,
114
- has_aligned=False,
115
- only_center_face=False,
116
- paste_back=True,
117
- )
118
- print(f"βœ… [INFERENCE] Restored image size: {restored_img.shape}")
119
- return restored_img
120
-
121
- # ----------------------------------------------------------
122
- # Postprocess
123
- # ----------------------------------------------------------
124
- def postprocess(self, restored_img):
125
- print("πŸ“€ [POSTPROCESS] Encoding restored image...")
126
- if isinstance(restored_img, torch.Tensor):
127
- restored_img = restored_img.detach().cpu().numpy()
128
-
129
- # Convert numpy to PIL if needed
130
- if not isinstance(restored_img, Image.Image):
131
- restored_img = Image.fromarray(restored_img[..., ::-1]) # BGR -> RGB
132
-
133
- buf = io.BytesIO()
134
- restored_img.save(buf, format="PNG")
135
- encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
136
- print("βœ… [POSTPROCESS] Image encoding complete.")
137
- return {"image": encoded}
138
 
 
 
 
 
1
  import base64
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Any, Dict
7
+
8
+ import numpy as np
9
  import requests
10
  from PIL import Image
11
+
12
  from gfpgan import GFPGANer
13
  from realesrgan import RealESRGANer
14
  from basicsr.archs.rrdbnet_arch import RRDBNet
15
 
16
+ # -----------------------------------------------------------------------------
17
+ # Logging setup
18
+ # -----------------------------------------------------------------------------
19
+ logging.basicConfig(level=logging.DEBUG)
20
+ logger = logging.getLogger(__name__)
21
+ logger.setLevel(logging.DEBUG)
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Model paths and URLs
25
+ # -----------------------------------------------------------------------------
26
+ REPO_DIR = os.environ.get("HF_HOME", "/repository")
27
+
28
+ GFPGAN_WEIGHTS_PATH = os.path.join(REPO_DIR, "GFPGANv1.4.pth")
29
+ REAL_ESRGAN_WEIGHTS_PATH = os.path.join(REPO_DIR, "realesr-general-x4v3.pth")
30
+
31
+ GFPGAN_URLS = [
32
+ # βœ… working file (GFPGANv1.4.pth is hosted under v1.3.0 release tag)
33
+ "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
34
+ "https://github.com/TencentARC/GFPGAN/releases/download/v1.4.0/GFPGANv1.4.pth",
35
+ ]
36
+
37
+ # βœ… updated Real-ESRGAN v0.2.5.0 URLs
38
+ REAL_ESRGAN_URLS = [
39
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
40
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
41
+ ]
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Helpers
45
+ # -----------------------------------------------------------------------------
46
+ def _ensure_file(path: str, urls) -> None:
47
+ if os.path.exists(path) and os.path.getsize(path) > 0:
48
+ logger.debug(f"βœ… File exists: {path}")
49
+ return
50
+ os.makedirs(os.path.dirname(path), exist_ok=True)
51
+ last_err = None
52
+ for u in urls:
53
+ try:
54
+ logger.debug(f"⬇️ Downloading {u}")
55
+ r = requests.get(u, timeout=60)
56
  r.raise_for_status()
57
+ with open(path, "wb") as f:
58
  f.write(r.content)
59
+ logger.debug(f"βœ… Saved to {path}")
60
+ return
61
+ except Exception as e:
62
+ last_err = e
63
+ logger.warning(f"⚠️ Failed from {u}: {e}")
64
+ raise RuntimeError(f"❌ Could not download required file: {last_err}")
65
+
66
+ def _to_bgr(image_bytes: bytes) -> np.ndarray:
67
+ pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
68
+ return np.array(pil)[:, :, ::-1].copy()
69
+
70
+ def _encode_bgr(bgr_img: np.ndarray) -> str:
71
+ import cv2
72
+ success, buf = cv2.imencode(".png", bgr_img)
73
+ if not success:
74
+ raise RuntimeError("Failed to encode image.")
75
+ return base64.b64encode(buf.tobytes()).decode("utf-8")
76
+
77
+ # -----------------------------------------------------------------------------
78
+ # EndpointHandler
79
+ # -----------------------------------------------------------------------------
80
+ class EndpointHandler:
81
+ """
82
+ Custom handler for GFPGAN v1.4 + Real-ESRGAN (realesr-general-x4v3.pth)
83
+ Emulates the behavior of the official Gradio demo.
84
+ """
85
+
86
+ def __init__(self, path: str = REPO_DIR):
87
+ logger.debug("πŸš€ [INIT] Starting GFPGAN + Real-ESRGAN handler...")
88
+ logger.debug(f"πŸ“‚ Repository path: {path}")
89
+
90
+ # 1️⃣ Ensure model weights
91
+ _ensure_file(GFPGAN_WEIGHTS_PATH, GFPGAN_URLS)
92
+ _ensure_file(REAL_ESRGAN_WEIGHTS_PATH, REAL_ESRGAN_URLS)
93
+
94
+ # 2️⃣ Device setup
95
+ import torch
96
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
97
+ self.half = torch.cuda.is_available()
98
+ logger.debug(f"🧠 Device: {self.device}, half precision: {self.half}")
99
+
100
+ # 3️⃣ Build Real-ESRGAN upsampler (x4)
101
+ logger.debug("🧩 Initializing Real-ESRGAN background upsampler...")
102
+ rrdb = RRDBNet(
103
+ num_in_ch=3,
104
+ num_out_ch=3,
105
+ num_feat=64,
106
+ num_block=23,
107
+ num_grow_ch=32,
108
+ scale=4,
109
  )
110
 
111
  self.bg_upsampler = RealESRGANer(
112
+ scale=4,
113
+ model_path=REAL_ESRGAN_WEIGHTS_PATH,
114
+ model=rrdb,
115
+ tile=0,
116
  tile_pad=10,
117
  pre_pad=0,
118
+ half=self.half,
119
+ device=self.device,
120
  )
121
+ logger.debug("βœ… Real-ESRGAN upsampler ready (x4).")
122
 
123
+ # 4️⃣ Build GFPGAN restorer (v1.4)
124
+ logger.debug("🧩 Initializing GFPGAN v1.4 restorer...")
 
 
125
  self.restorer = GFPGANer(
126
+ model_path=GFPGAN_WEIGHTS_PATH,
127
  upscale=2,
128
  arch="clean",
129
  channel_multiplier=2,
130
  bg_upsampler=self.bg_upsampler,
131
+ device=self.device,
132
  )
133
+ logger.debug("βœ… GFPGAN v1.4 initialized.")
 
 
 
 
 
 
 
134
 
135
+ # -------------------------------------------------------------------------
136
+ # Inference
137
+ # -------------------------------------------------------------------------
138
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
139
  try:
140
+ logger.debug(f"πŸŒ€ Received data type: {type(data)}")
141
+
142
+ # --- Parse input ---
143
+ image_bytes = None
144
+ parameters = {}
145
+
146
+ if isinstance(data, (bytes, bytearray)):
147
+ image_bytes = bytes(data)
148
+ elif isinstance(data, dict):
149
+ b64 = data.get("inputs") or data.get("image")
150
+ if b64:
151
+ image_bytes = base64.b64decode(b64)
152
+ parameters = data.get("parameters") or {}
153
+ elif isinstance(data, str):
154
+ try:
155
+ parsed = json.loads(data)
156
+ b64 = parsed.get("inputs") or parsed.get("image")
157
+ if b64:
158
+ image_bytes = base64.b64decode(b64)
159
+ parameters = parsed.get("parameters") or {}
160
+ except Exception as e:
161
+ logger.warning(f"⚠️ JSON parse error: {e}")
162
+
163
+ if not image_bytes:
164
+ return {"error": "Unsupported input format β€” expected bytes or base64 data"}
165
+
166
+ scale = int(parameters.get("scale", 2))
167
+ logger.debug(f"πŸ”§ Using scale factor: {scale}")
168
+
169
+ # Convert to BGR
170
+ bgr_input = _to_bgr(image_bytes)
171
+ logger.debug(f"πŸ“ Input image shape: {bgr_input.shape}")
172
+
173
+ # Enhance
174
+ logger.debug("✨ Running GFPGAN restoration...")
175
+ _, _, restored_img = self.restorer.enhance(
176
+ bgr_input, has_aligned=False, only_center_face=False, paste_back=True
177
+ )
178
+
179
+ if restored_img is None:
180
+ return {"error": "Restoration failed β€” no output."}
181
+
182
+ b64_img = _encode_bgr(restored_img)
183
+ logger.debug("βœ… Restoration complete.")
184
+ return {"image": b64_img}
185
+
186
  except Exception as e:
187
+ logger.exception("πŸ”₯ Inference error")
188
+ return {"error": f"{type(e).__name__}: {e}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189