mastari commited on
Commit
eb396d8
Β·
1 Parent(s): dd6cc98

Add GFPGAN + RealESRGAN hybrid handler

Browse files
Files changed (2) hide show
  1. handler.py +119 -115
  2. requirements.txt +1 -0
handler.py CHANGED
@@ -1,134 +1,138 @@
1
- import io
2
  import os
3
- import cv2
4
  import torch
5
  import base64
6
- import logging
7
  import requests
8
- import numpy as np
9
  from PIL import Image
10
  from gfpgan import GFPGANer
 
 
11
 
12
- # ======================================================
13
- # LOGGING CONFIGURATION
14
- # ======================================================
15
- logging.basicConfig(level=logging.DEBUG)
16
- logger = logging.getLogger(__name__)
17
- logger.setLevel(logging.DEBUG)
18
- logger.debug("πŸ“¦ [INIT] Importing GFPGAN handler module...")
19
-
20
- # ======================================================
21
- # GFPGAN MODEL URL
22
- # ======================================================
23
- MODEL_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
24
- MODEL_NAME = "GFPGANv1.4.pth"
25
 
26
-
27
- # ======================================================
28
- # ENDPOINT HANDLER
29
- # ======================================================
30
  class EndpointHandler:
31
  def __init__(self, path="."):
32
- logger.debug("πŸš€ [INIT] Starting GFPGAN EndpointHandler initialization...")
33
- logger.debug(f"πŸ“ Working directory: {os.getcwd()}")
34
- logger.debug(f"πŸ“‚ Handler path argument: {path}")
35
-
36
- model_path = os.path.join(path, MODEL_NAME)
37
- logger.debug(f"πŸ”— [MODEL] Expected model path: {model_path}")
38
-
39
- # ------------------------------------------------------
40
- # Download model if missing
41
- # ------------------------------------------------------
42
- if not os.path.exists(model_path):
43
- try:
44
- logger.debug(f"πŸ“₯ [DOWNLOAD] Model not found locally β€” fetching from {MODEL_URL}")
45
- r = requests.get(MODEL_URL, stream=True)
46
- r.raise_for_status()
47
- with open(model_path, "wb") as f:
48
- for chunk in r.iter_content(chunk_size=8192):
49
- if chunk:
50
- f.write(chunk)
51
- logger.debug("βœ… [MODEL] Downloaded GFPGAN weights successfully.")
52
- except Exception as e:
53
- logger.error(f"πŸ’₯ [ERROR] Failed to download GFPGAN weights: {e}")
54
- raise
55
-
56
- # ------------------------------------------------------
57
- # Initialize GFPGANer (same as official Gradio demo)
58
- # ------------------------------------------------------
59
- try:
60
- logger.debug("🧠 [MODEL] Initializing GFPGANer (upscale=2, arch='clean')...")
61
- self.restorer = GFPGANer(
62
- model_path=model_path,
63
- upscale=2, # Rescaling factor = 2
64
- arch="clean",
65
- channel_multiplier=2,
66
- bg_upsampler=None
67
- )
68
- logger.debug("βœ… [MODEL] GFPGAN model initialized successfully.")
69
- except Exception as e:
70
- logger.error(f"πŸ’₯ [ERROR] Model initialization failed: {e}")
71
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # ======================================================
74
- # INFERENCE CALL
75
- # ======================================================
76
  def __call__(self, data):
77
- logger.debug("βš™οΈ [INFER] Starting inference...")
78
- logger.debug(f"πŸ“₯ Incoming data type: {type(data)}")
79
 
80
- # ------------------------------------------------------
81
- # Handle both JSON base64 and raw bytes
82
- # ------------------------------------------------------
83
  try:
84
- if isinstance(data, dict) and "inputs" in data:
85
- logger.debug("πŸ“¦ Detected JSON base64 input")
86
- image_bytes = base64.b64decode(data["inputs"])
87
- elif isinstance(data, (bytes, bytearray)):
88
- logger.debug("πŸ“¦ Detected raw bytes input")
89
- image_bytes = data
90
- else:
91
- raise ValueError("Unsupported input format β€” expected bytes or base64 data")
92
-
93
- logger.debug(f"🧾 [BYTES] Received {len(image_bytes)} bytes")
94
  except Exception as e:
95
- logger.error(f"πŸ’₯ [ERROR] Input parsing failed: {e}")
96
- return {"error": f"Invalid input: {e}"}
97
 
98
- # ------------------------------------------------------
99
- # Decode image
100
- # ------------------------------------------------------
101
- try:
102
- img_np = np.array(Image.open(io.BytesIO(image_bytes)).convert("RGB"))
103
- logger.debug(f"πŸ–ΌοΈ [IMAGE] Loaded image of shape: {img_np.shape}")
104
- except Exception as e:
105
- logger.error(f"πŸ’₯ [ERROR] Failed to load image: {e}")
106
- return {"error": f"Image loading failed: {e}"}
107
 
108
- # ------------------------------------------------------
109
- # Run GFPGAN restoration
110
- # ------------------------------------------------------
111
- try:
112
- cropped_faces, restored_faces, restored_img = self.restorer.enhance(
113
- img_np,
114
- has_aligned=False,
115
- only_center_face=False,
116
- paste_back=True # Matches GFPGAN web demo
117
- )
118
- logger.debug("βœ… [RESTORE] Face restoration completed successfully.")
119
- except Exception as e:
120
- logger.error(f"πŸ’₯ [ERROR] GFPGAN enhancement failed: {e}")
121
- return {"error": f"Enhancement failed: {e}"}
122
 
123
- # ------------------------------------------------------
124
- # Encode result as base64 PNG
125
- # ------------------------------------------------------
126
- try:
127
- _, buffer = cv2.imencode(".png", restored_img[:, :, ::-1]) # BGR→RGB
128
- img_base64 = base64.b64encode(buffer).decode("utf-8")
129
- logger.debug("πŸ“€ [ENCODE] Encoded restored image successfully.")
130
- return {"image": img_base64}
131
- except Exception as e:
132
- logger.error(f"πŸ’₯ [ERROR] Failed to encode image: {e}")
133
- return {"error": f"Encoding failed: {e}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
 
 
1
  import os
2
+ import io
3
  import torch
4
  import base64
 
5
  import requests
 
6
  from PIL import Image
7
  from gfpgan import GFPGANer
8
+ from realesrgan import RealESRGANer
9
+ from basicsr.archs.rrdbnet_arch import RRDBNet
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
12
  class EndpointHandler:
13
  def __init__(self, path="."):
14
+ print("πŸš€ [INIT] Starting GFPGAN + RealESRGAN hybrid handler initialization...")
15
+ print(f"πŸ“‚ Working directory: {os.getcwd()}")
16
+ print(f"πŸ“ Handler path argument: {path}")
17
+
18
+ # ------------------------------
19
+ # Download GFPGAN v1.4 weights
20
+ # ------------------------------
21
+ self.model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.4.0/GFPGANv1.4.pth"
22
+ self.model_path = os.path.join(path, "GFPGANv1.4.pth")
23
+ if not os.path.exists(self.model_path):
24
+ print(f"πŸ“₯ [DOWNLOAD] Fetching GFPGAN v1.4 weights...")
25
+ r = requests.get(self.model_url)
26
+ r.raise_for_status()
27
+ with open(self.model_path, "wb") as f:
28
+ f.write(r.content)
29
+ print(f"βœ… [MODEL] Downloaded GFPGAN model to {self.model_path}")
30
+ else:
31
+ print("βœ… [MODEL] GFPGAN weights already exist locally.")
32
+
33
+ # ------------------------------
34
+ # Setup background upsampler (Real-ESRGAN)
35
+ # ------------------------------
36
+ print("🧠 [INIT] Setting up Real-ESRGAN background upsampler...")
37
+ rrdbnet = RRDBNet(
38
+ num_in_ch=3, num_out_ch=3,
39
+ num_feat=64, num_block=23,
40
+ num_grow_ch=32, scale=4
41
+ )
42
+
43
+ self.bg_upsampler = RealESRGANer(
44
+ scale=2,
45
+ model_path=None, # auto-download model weights
46
+ model=rrdbnet,
47
+ tile=400,
48
+ tile_pad=10,
49
+ pre_pad=0,
50
+ half=False,
51
+ device="cuda" if torch.cuda.is_available() else "cpu",
52
+ )
53
+ print("βœ… [INIT] Real-ESRGAN background upsampler ready.")
54
+
55
+ # ------------------------------
56
+ # Setup GFPGANer
57
+ # ------------------------------
58
+ print("🧠 [INIT] Setting up GFPGANer (v1.4)...")
59
+ self.restorer = GFPGANer(
60
+ model_path=self.model_path,
61
+ upscale=2,
62
+ arch="clean",
63
+ channel_multiplier=2,
64
+ bg_upsampler=self.bg_upsampler,
65
+ device="cuda" if torch.cuda.is_available() else "cpu",
66
+ )
67
+ print("βœ… [INIT DONE] GFPGAN + RealESRGAN hybrid handler ready.")
68
 
69
+ # ----------------------------------------------------------
70
+ # Main inference entry point
71
+ # ----------------------------------------------------------
72
  def __call__(self, data):
73
+ print("πŸ›°οΈ [CALL] Endpoint invoked!")
74
+ print(f"πŸ“¦ [CALL] Raw input type: {type(data)}")
75
 
 
 
 
76
  try:
77
+ image = self.preprocess(data)
78
+ print("🧩 [STEP] Image preprocessed successfully.")
79
+ restored = self.inference(image)
80
+ print("🎨 [STEP] Inference completed successfully.")
81
+ return self.postprocess(restored)
 
 
 
 
 
82
  except Exception as e:
83
+ print(f"πŸ’₯ [ERROR] Exception during call: {str(e)}")
84
+ return {"error": str(e)}
85
 
86
+ # ----------------------------------------------------------
87
+ # Preprocessing
88
+ # ----------------------------------------------------------
89
+ def preprocess(self, data):
90
+ print("πŸ”§ [PREPROCESS] Starting...")
91
+ if isinstance(data, (bytes, bytearray)):
92
+ print("πŸ–ΌοΈ [PREPROCESS] Raw bytes detected.")
93
+ return Image.open(io.BytesIO(data)).convert("RGB")
 
94
 
95
+ if isinstance(data, dict):
96
+ img_field = data.get("inputs") or data.get("image")
97
+ if isinstance(img_field, str):
98
+ print("🧬 [PREPROCESS] Base64 string detected.")
99
+ decoded = base64.b64decode(img_field)
100
+ return Image.open(io.BytesIO(decoded)).convert("RGB")
101
+ if isinstance(img_field, (bytes, bytearray)):
102
+ print("🧩 [PREPROCESS] Byte array detected.")
103
+ return Image.open(io.BytesIO(img_field)).convert("RGB")
 
 
 
 
 
104
 
105
+ raise ValueError("Unsupported input format β€” expected bytes or base64 data.")
106
+
107
+ # ----------------------------------------------------------
108
+ # Inference
109
+ # ----------------------------------------------------------
110
+ def inference(self, image):
111
+ print("βš™οΈ [INFERENCE] Running GFPGAN + RealESRGAN enhancement...")
112
+ cropped_faces, restored_faces, restored_img = self.restorer.enhance(
113
+ image,
114
+ has_aligned=False,
115
+ only_center_face=False,
116
+ paste_back=True,
117
+ )
118
+ print(f"βœ… [INFERENCE] Restored image size: {restored_img.shape}")
119
+ return restored_img
120
+
121
+ # ----------------------------------------------------------
122
+ # Postprocess
123
+ # ----------------------------------------------------------
124
+ def postprocess(self, restored_img):
125
+ print("πŸ“€ [POSTPROCESS] Encoding restored image...")
126
+ if isinstance(restored_img, torch.Tensor):
127
+ restored_img = restored_img.detach().cpu().numpy()
128
+
129
+ # Convert numpy to PIL if needed
130
+ if not isinstance(restored_img, Image.Image):
131
+ restored_img = Image.fromarray(restored_img[..., ::-1]) # BGR -> RGB
132
+
133
+ buf = io.BytesIO()
134
+ restored_img.save(buf, format="PNG")
135
+ encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
136
+ print("βœ… [POSTPROCESS] Image encoding complete.")
137
+ return {"image": encoded}
138
 
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  torch==2.1.0
2
  torchvision==0.16.0
3
  gfpgan==1.3.8
 
4
  basicsr==1.4.2
5
  facexlib==0.3.0
6
  opencv-python
 
1
  torch==2.1.0
2
  torchvision==0.16.0
3
  gfpgan==1.3.8
4
+ realesrgan==0.3.0
5
  basicsr==1.4.2
6
  facexlib==0.3.0
7
  opencv-python