mastari committed on
Commit
c8867e7
·
1 Parent(s): b51e8ba

Fix color hue and add RGB output conversion

Browse files
Files changed (1) hide show
  1. handler.py +19 -12
handler.py CHANGED
@@ -21,7 +21,7 @@ class EndpointHandler:
21
  self.half = self.device == "cuda"
22
  self.path = path
23
 
24
- # URLs
25
  self.gfpgan_model_url = (
26
  "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
27
  )
@@ -29,18 +29,21 @@ class EndpointHandler:
29
  "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
30
  )
31
 
32
- # Local model paths
33
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
34
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
35
 
36
  self.bg_upsampler = None
37
  self.restorer = None
38
 
 
39
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
40
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
 
41
  logger.info(f"🧠 Device: {self.device}, half precision: {self.half}")
42
 
43
  def _ensure_model(self, url, path):
 
44
  if not os.path.exists(path):
45
  logger.info(f"⬇️ Downloading model from {url}")
46
  r = requests.get(url, timeout=60)
@@ -52,7 +55,7 @@ class EndpointHandler:
52
  logger.info(f"πŸ“ Found cached model: {path}")
53
 
54
  def _init_models(self):
55
- """Lazy-load models"""
56
  if self.bg_upsampler is None:
57
  logger.info("🧩 Initializing Real-ESRGAN upsampler...")
58
  model = SRVGGNetCompact(
@@ -82,7 +85,7 @@ class EndpointHandler:
82
  logger.info("βœ… Models ready!")
83
 
84
  def _load_image(self, data):
85
- """Handle different input formats."""
86
  if isinstance(data, dict) and "inputs" in data:
87
  data = data["inputs"]
88
 
@@ -96,7 +99,7 @@ class EndpointHandler:
96
  resp = requests.get(data)
97
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
98
  else:
99
- # assume base64
100
  logger.info("🧬 Decoding base64 image input")
101
  try:
102
  decoded = base64.b64decode(data)
@@ -108,30 +111,34 @@ class EndpointHandler:
108
  raise ValueError("Unsupported input type")
109
 
110
  def __call__(self, data):
 
111
  self._init_models()
112
- logger.info("βš™οΈ Starting inference...")
113
 
114
  # Load input
115
  image = self._load_image(data)
116
  input_img = np.array(image, dtype=np.uint8)
117
-
118
  logger.info(f"πŸ“ Input image shape: {input_img.shape}")
119
 
 
120
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
121
  input_img, has_aligned=False, only_center_face=False, paste_back=True
122
  )
123
 
124
- logger.info("πŸ–ΌοΈ Restoration complete, encoding output...")
125
 
126
- # Encode result as base64
127
  restored_img_rgb = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
128
- _, buffer = cv2.imencode(".jpg", restored_img)
 
 
 
129
  b64_output = base64.b64encode(buffer).decode("utf-8")
130
 
131
- logger.info("βœ… Returning base64 image JSON")
 
132
  return {
133
  "image": b64_output,
134
  "status": "success",
135
- "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3"
136
  }
137
 
 
21
  self.half = self.device == "cuda"
22
  self.path = path
23
 
24
+ # Model URLs (GFPGAN + RealESRGAN)
25
  self.gfpgan_model_url = (
26
  "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
27
  )
 
29
  "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
30
  )
31
 
32
+ # Local cache paths
33
  self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
34
  self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")
35
 
36
  self.bg_upsampler = None
37
  self.restorer = None
38
 
39
+ # Ensure model weights exist
40
  self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
41
  self._ensure_model(self.realesr_model_url, self.realesr_model_path)
42
+
43
  logger.info(f"🧠 Device: {self.device}, half precision: {self.half}")
44
 
45
  def _ensure_model(self, url, path):
46
+ """Download model if missing."""
47
  if not os.path.exists(path):
48
  logger.info(f"⬇️ Downloading model from {url}")
49
  r = requests.get(url, timeout=60)
 
55
  logger.info(f"πŸ“ Found cached model: {path}")
56
 
57
  def _init_models(self):
58
+ """Lazy-load ESRGAN + GFPGAN models."""
59
  if self.bg_upsampler is None:
60
  logger.info("🧩 Initializing Real-ESRGAN upsampler...")
61
  model = SRVGGNetCompact(
 
85
  logger.info("βœ… Models ready!")
86
 
87
  def _load_image(self, data):
88
+ """Accept base64, raw bytes, or URL and return PIL image."""
89
  if isinstance(data, dict) and "inputs" in data:
90
  data = data["inputs"]
91
 
 
99
  resp = requests.get(data)
100
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
101
  else:
102
+ # Base64
103
  logger.info("🧬 Decoding base64 image input")
104
  try:
105
  decoded = base64.b64decode(data)
 
111
  raise ValueError("Unsupported input type")
112
 
113
  def __call__(self, data):
114
+ logger.info("βš™οΈ Starting GFPGAN inference pipeline...")
115
  self._init_models()
 
116
 
117
  # Load input
118
  image = self._load_image(data)
119
  input_img = np.array(image, dtype=np.uint8)
 
120
  logger.info(f"πŸ“ Input image shape: {input_img.shape}")
121
 
122
+ # Restore face(s)
123
  cropped_faces, restored_faces, restored_img = self.restorer.enhance(
124
  input_img, has_aligned=False, only_center_face=False, paste_back=True
125
  )
126
 
127
+ logger.info("πŸ–ΌοΈ Restoration complete, preparing output...")
128
 
129
+ # ✅ Convert color from BGR → RGB (fix hue issue)
130
  restored_img_rgb = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
131
+ restored_img_rgb = np.clip(restored_img_rgb, 0, 255).astype(np.uint8)
132
+
133
+ # ✅ Encode output as base64 string for JSON
134
+ _, buffer = cv2.imencode(".jpg", restored_img_rgb)
135
  b64_output = base64.b64encode(buffer).decode("utf-8")
136
 
137
+ logger.info("βœ… Returning base64-encoded image JSON response")
138
+
139
  return {
140
  "image": b64_output,
141
  "status": "success",
142
+ "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)"
143
  }
144