mastari commited on
Commit
4984743
Β·
1 Parent(s): 3e3c736
Files changed (1) hide show
  1. handler.py +47 -26
handler.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import io
3
  import torch
4
  import logging
 
5
  import requests
6
  import numpy as np
7
  import cv2
@@ -20,7 +21,7 @@ class EndpointHandler:
20
  self.half = self.device == "cuda"
21
  self.path = path
22
 
23
- # Model URLs
24
  self.gfpgan_model_url = (
25
  "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
26
  )
@@ -28,31 +29,30 @@ class EndpointHandler:
28
  "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
29
  )
30
 
31
- # Local paths
32
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
33
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
34
 
35
- # Lazy init
36
  self.bg_upsampler = None
37
  self.restorer = None
38
 
39
- # Ensure model files exist
40
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
41
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
42
-
43
  logger.info(f"🧠 Device: {self.device}, half precision: {self.half}")
44
 
45
- def _ensure_model(self, url, local_path):
46
- if not os.path.exists(local_path):
47
- logger.info(f"⬇️ Downloading {url}")
48
  r = requests.get(url, timeout=60)
49
  r.raise_for_status()
50
- with open(local_path, "wb") as f:
51
  f.write(r.content)
52
- logger.info(f"βœ… Saved to {local_path}")
 
 
53
 
54
  def _init_models(self):
55
- """Lazy-load GFPGAN and Real-ESRGAN models."""
56
  if self.bg_upsampler is None:
57
  logger.info("🧩 Initializing Real-ESRGAN upsampler...")
58
  model = SRVGGNetCompact(
@@ -81,35 +81,56 @@ class EndpointHandler:
81
  )
82
  logger.info("βœ… Models ready!")
83
 
84
- def __call__(self, data):
85
- """Restore a face photo."""
86
- self._init_models()
87
-
88
  if isinstance(data, dict) and "inputs" in data:
89
  data = data["inputs"]
90
 
91
- # Load image
92
  if isinstance(data, (bytes, bytearray)):
93
- image = Image.open(io.BytesIO(data)).convert("RGB")
94
- elif isinstance(data, str) and data.startswith("http"):
95
- resp = requests.get(data)
96
- image = Image.open(io.BytesIO(resp.content)).convert("RGB")
97
- elif isinstance(data, Image.Image):
98
- image = data
99
- else:
100
- raise ValueError("Unsupported input type")
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
 
 
 
 
102
  input_img = np.array(image, dtype=np.uint8)
103
 
 
 
104
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
105
  input_img, has_aligned=False, only_center_face=False, paste_back=True
106
  )
107
 
 
 
 
108
  _, buffer = cv2.imencode(".jpg", restored_img)
109
- output_bytes = io.BytesIO(buffer.tobytes())
110
 
 
111
  return {
112
- "output": output_bytes,
 
113
  "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3"
114
  }
115
 
 
2
  import io
3
  import torch
4
  import logging
5
+ import base64
6
  import requests
7
  import numpy as np
8
  import cv2
 
21
  self.half = self.device == "cuda"
22
  self.path = path
23
 
24
+ # URLs
25
  self.gfpgan_model_url = (
26
  "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
27
  )
 
29
  "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
30
  )
31
 
32
+ # Local model paths
33
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
34
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
35
 
 
36
  self.bg_upsampler = None
37
  self.restorer = None
38
 
 
39
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
40
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
 
41
  logger.info(f"🧠 Device: {self.device}, half precision: {self.half}")
42
 
43
+ def _ensure_model(self, url, path):
44
+ if not os.path.exists(path):
45
+ logger.info(f"⬇️ Downloading model from {url}")
46
  r = requests.get(url, timeout=60)
47
  r.raise_for_status()
48
+ with open(path, "wb") as f:
49
  f.write(r.content)
50
+ logger.info(f"βœ… Model saved to {path}")
51
+ else:
52
+ logger.info(f"πŸ“ Found cached model: {path}")
53
 
54
  def _init_models(self):
55
+ """Lazy-load models"""
56
  if self.bg_upsampler is None:
57
  logger.info("🧩 Initializing Real-ESRGAN upsampler...")
58
  model = SRVGGNetCompact(
 
81
  )
82
  logger.info("βœ… Models ready!")
83
 
84
+ def _load_image(self, data):
85
+ """Handle different input formats."""
 
 
86
  if isinstance(data, dict) and "inputs" in data:
87
  data = data["inputs"]
88
 
 
89
  if isinstance(data, (bytes, bytearray)):
90
+ logger.info("πŸ“¦ Received raw bytes input")
91
+ return Image.open(io.BytesIO(data)).convert("RGB")
92
+
93
+ if isinstance(data, str):
94
+ if data.startswith("http"):
95
+ logger.info(f"🌐 Downloading image from URL: {data}")
96
+ resp = requests.get(data)
97
+ return Image.open(io.BytesIO(resp.content)).convert("RGB")
98
+ else:
99
+ # assume base64
100
+ logger.info("🧬 Decoding base64 image input")
101
+ try:
102
+ decoded = base64.b64decode(data)
103
+ return Image.open(io.BytesIO(decoded)).convert("RGB")
104
+ except Exception as e:
105
+ logger.error(f"❌ Failed to decode base64: {e}")
106
+ raise ValueError("Invalid base64 image input")
107
+
108
+ raise ValueError("Unsupported input type")
109
 
110
+ def __call__(self, data):
111
+ self._init_models()
112
+ logger.info("βš™οΈ Starting inference...")
113
+
114
+ # Load input
115
+ image = self._load_image(data)
116
  input_img = np.array(image, dtype=np.uint8)
117
 
118
+ logger.info(f"πŸ“ Input image shape: {input_img.shape}")
119
+
120
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
121
  input_img, has_aligned=False, only_center_face=False, paste_back=True
122
  )
123
 
124
+ logger.info("πŸ–ΌοΈ Restoration complete, encoding output...")
125
+
126
+ # Encode result as base64
127
  _, buffer = cv2.imencode(".jpg", restored_img)
128
+ b64_output = base64.b64encode(buffer).decode("utf-8")
129
 
130
+ logger.info("βœ… Returning base64 image JSON")
131
  return {
132
+ "image": b64_output,
133
+ "status": "success",
134
  "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3"
135
  }
136