Fix: convert PIL to NumPy before ESRGAN inference
Browse files — handler.py (+14, -11)
handler.py
CHANGED
|
@@ -3,6 +3,7 @@ import io
|
|
| 3 |
import torch
|
| 4 |
import base64
|
| 5 |
import requests
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
from realesrgan import RealESRGANer
|
| 8 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
@@ -85,35 +86,28 @@ class EndpointHandler:
|
|
| 85 |
def preprocess(self, data):
|
| 86 |
print(f"🔧 [PREPROCESS] Type received: {type(data)}")
|
| 87 |
|
| 88 |
-
# 1️⃣ Hugging Face JSON-wrapped dict {"inputs": ...}
|
| 89 |
if isinstance(data, dict):
|
| 90 |
print("🧩 [PREPROCESS] Detected dict input.")
|
| 91 |
if "inputs" in data:
|
| 92 |
data = data["inputs"]
|
| 93 |
print(f"📨 [PREPROCESS] Found 'inputs' key: {type(data)}")
|
| 94 |
|
| 95 |
-
# 2️⃣ Direct PIL image object (the missing case!)
|
| 96 |
if isinstance(data, Image.Image):
|
| 97 |
-
print("🖼️ [PREPROCESS] Got PIL.Image.Image directly")
|
| 98 |
return data.convert("RGB")
|
| 99 |
|
| 100 |
-
# 3️⃣ Raw bytes
|
| 101 |
if isinstance(data, (bytes, bytearray)):
|
| 102 |
print("🧾 [PREPROCESS] Treating input as raw bytes.")
|
| 103 |
return Image.open(io.BytesIO(data)).convert("RGB")
|
| 104 |
|
| 105 |
-
# 4️⃣ Base64 string
|
| 106 |
if isinstance(data, str):
|
| 107 |
print(f"🧾 [PREPROCESS] Treating input as base64 string, len={len(data)}")
|
| 108 |
decoded = base64.b64decode(data)
|
| 109 |
return Image.open(io.BytesIO(decoded)).convert("RGB")
|
| 110 |
|
| 111 |
-
# 5️⃣ List (rare HF wrapper case)
|
| 112 |
if isinstance(data, list) and len(data) > 0:
|
| 113 |
-
print("📚 [PREPROCESS] List input detected.")
|
| 114 |
item = data[0]
|
| 115 |
if isinstance(item, Image.Image):
|
| 116 |
-
print("📷 [PREPROCESS] List contains a PIL.Image.Image.")
|
| 117 |
return item.convert("RGB")
|
| 118 |
if isinstance(item, (bytes, bytearray)):
|
| 119 |
return Image.open(io.BytesIO(item)).convert("RGB")
|
|
@@ -128,9 +122,18 @@ class EndpointHandler:
|
|
| 128 |
def inference(self, image):
|
| 129 |
print("🎯 [INFERENCE] Running ESRGAN upscaling...")
|
| 130 |
print(f"📐 [INFERENCE] Input image size: {image.size}")
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# ==========================================================
|
| 136 |
# POSTPROCESS
|
|
|
|
| 3 |
import torch
|
| 4 |
import base64
|
| 5 |
import requests
|
| 6 |
+
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
from realesrgan import RealESRGANer
|
| 9 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
|
|
| 86 |
def preprocess(self, data):
|
| 87 |
print(f"🔧 [PREPROCESS] Type received: {type(data)}")
|
| 88 |
|
|
|
|
| 89 |
if isinstance(data, dict):
|
| 90 |
print("🧩 [PREPROCESS] Detected dict input.")
|
| 91 |
if "inputs" in data:
|
| 92 |
data = data["inputs"]
|
| 93 |
print(f"📨 [PREPROCESS] Found 'inputs' key: {type(data)}")
|
| 94 |
|
|
|
|
| 95 |
if isinstance(data, Image.Image):
|
| 96 |
+
print("🖼️ [PREPROCESS] Got PIL.Image.Image directly.")
|
| 97 |
return data.convert("RGB")
|
| 98 |
|
|
|
|
| 99 |
if isinstance(data, (bytes, bytearray)):
|
| 100 |
print("🧾 [PREPROCESS] Treating input as raw bytes.")
|
| 101 |
return Image.open(io.BytesIO(data)).convert("RGB")
|
| 102 |
|
|
|
|
| 103 |
if isinstance(data, str):
|
| 104 |
print(f"🧾 [PREPROCESS] Treating input as base64 string, len={len(data)}")
|
| 105 |
decoded = base64.b64decode(data)
|
| 106 |
return Image.open(io.BytesIO(decoded)).convert("RGB")
|
| 107 |
|
|
|
|
| 108 |
if isinstance(data, list) and len(data) > 0:
|
|
|
|
| 109 |
item = data[0]
|
| 110 |
if isinstance(item, Image.Image):
|
|
|
|
| 111 |
return item.convert("RGB")
|
| 112 |
if isinstance(item, (bytes, bytearray)):
|
| 113 |
return Image.open(io.BytesIO(item)).convert("RGB")
|
|
|
|
| 122 |
def inference(self, image):
|
| 123 |
print("🎯 [INFERENCE] Running ESRGAN upscaling...")
|
| 124 |
print(f"📐 [INFERENCE] Input image size: {image.size}")
|
| 125 |
+
|
| 126 |
+
# Convert PIL -> NumPy BGR for RealESRGAN
|
| 127 |
+
img_np = np.array(image)[:, :, ::-1] # RGB -> BGR
|
| 128 |
+
print(f"🔍 [INFERENCE] Converted to NumPy: shape={img_np.shape}, dtype={img_np.dtype}")
|
| 129 |
+
|
| 130 |
+
output, _ = self.upsampler.enhance(img_np, outscale=4)
|
| 131 |
+
print(f"✅ [INFERENCE] Output NumPy shape: {output.shape}")
|
| 132 |
+
|
| 133 |
+
# Convert back to PIL RGB
|
| 134 |
+
output_rgb = Image.fromarray(output[:, :, ::-1])
|
| 135 |
+
print(f"✅ [INFERENCE] Converted back to PIL: size={output_rgb.size}")
|
| 136 |
+
return output_rgb
|
| 137 |
|
| 138 |
# ==========================================================
|
| 139 |
# POSTPROCESS
|