# swin2sr-x2-handler / handler.py
# Last commit 427d785 ("fix") by mastari.
from transformers import pipeline
from PIL import Image
import io, base64, requests, sys, traceback
class EndpointHandler:
    """Inference handler that upscales an image 2x with a Swin2SR pipeline.

    The ``transformers`` image-to-image pipeline is built once in ``__init__``.
    Each call decodes one input image, runs the pipeline on it and returns the
    result as a base64-encoded PNG (or an ``{"error": ...}`` dict on failure).
    """

    # Hub id of the checkpoint loaded when no ``model_id`` is given.
    DEFAULT_MODEL_ID = "sergeipetrov/swin2SR-classical-sr-x2-64"
    # Timeout (seconds) for fetching URL inputs; without one, a slow or dead
    # host would block the request indefinitely.
    URL_TIMEOUT = 30

    class _HandlerError(Exception):
        """Expected request failure whose message is returned to the client."""

    def __init__(self, model_dir: str = "", model_id: str = DEFAULT_MODEL_ID, **kwargs):
        """Build the super-resolution pipeline.

        Args:
            model_dir: Accepted for interface compatibility; not used, the
                model is loaded by ``model_id``.
            model_id: Model id (or path) passed to ``transformers.pipeline``.
            **kwargs: Ignored.

        Raises:
            Exception: whatever ``pipeline`` raises, after it has been logged.
        """
        print("πŸ”Ή [INIT] Loading Swin2SR model ...")
        try:
            self.model = pipeline(task="image-to-image", model=model_id)
        except Exception as e:
            print("❌ [INIT] Model load failed:", str(e))
            traceback.print_exc(file=sys.stdout)
            raise  # bare raise re-raises with the original traceback
        print("βœ… [INIT] Model loaded successfully")

    def __call__(self, data):
        """Upscale the image carried in ``data["inputs"]``.

        Args:
            data: Request payload, expected to be a dict. Its ``"inputs"``
                value may be an http(s) URL, a base64 string (a
                ``data:...;base64,`` URI prefix is stripped), raw bytes, or a
                PIL image. A doubly wrapped ``{"inputs": {"inputs": ...}}``
                payload is unwrapped.

        Returns:
            ``{"image_base64": <base64 PNG str>}`` on success, otherwise
            ``{"error": <message>}``. Failures are reported, never raised.
        """
        print("\n🟒 [CALL] Received request in handler")
        print(f"πŸ”Ή [DEBUG] Raw data type: {type(data)}")
        print(f"πŸ”Ή [DEBUG] Raw data keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")
        if not isinstance(data, dict):
            msg = f"Expected a dict payload with an 'inputs' key, got {type(data)}"
            print("❌ [ERROR]", msg)
            return {"error": msg}
        try:
            image = self._load_image(data.get("inputs"))
            print("βœ… [INFO] Image successfully loaded and converted to RGB")
            print("πŸš€ [INFER] Running Swin2SR model inference...")
            output = self.model(image)
            print("βœ… [INFER] Inference complete")
            result = self._extract_image(output)
            print("πŸ’Ύ [ENCODE] Encoding result image to base64...")
            encoded = self._encode_png(result)
        except self._HandlerError as e:
            # Already logged where it was raised; report the message as-is.
            return {"error": str(e)}
        except Exception as e:
            print("❌ [FATAL] Inference failed with exception:")
            traceback.print_exc(file=sys.stdout)
            return {"error": f"Inference failed: {str(e)}"}
        print("βœ… [RETURN] Returning base64-encoded image")
        return {"image_base64": encoded}

    @classmethod
    def _load_image(cls, image_input):
        """Turn a request ``inputs`` value into an RGB PIL image.

        Raises:
            _HandlerError: the input is missing, has an unsupported type, or is
                a string that does not base64-decode to a readable image.
            Exception: HTTP/network errors from a URL fetch and PIL errors on
                raw bytes propagate unchanged to the caller.
        """
        print(f"πŸ”Ή [DEBUG] image_input type: {type(image_input)}")
        # Handle nested dicts (double-wrapped inputs)
        if isinstance(image_input, dict) and "inputs" in image_input:
            print("⚠️ [DEBUG] Nested 'inputs' dict detected β€” unwrapping...")
            image_input = image_input["inputs"]
        if image_input is None:
            msg = "Missing 'inputs' in request payload"
            print("⚠️ [WARN]", msg)
            raise cls._HandlerError(msg)
        # Case 1: URL input. Match the full scheme so a base64 payload that
        # happens to begin with "http" is not mistaken for a URL.
        if isinstance(image_input, str) and image_input[:8].lower().startswith(("http://", "https://")):
            print(f"🌐 [INFO] Fetching image from URL: {image_input[:60]}...")
            response = requests.get(image_input, timeout=cls.URL_TIMEOUT)
            response.raise_for_status()
            # .content (unlike .raw) honours Content-Encoding and is fully
            # read, so the connection is released back to the pool.
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        # Case 2: Base64-encoded string, optionally wrapped as a data URI.
        if isinstance(image_input, str):
            print(f"🧬 [INFO] Detected base64 string (len={len(image_input)})")
            if image_input.startswith("data:"):
                # b64decode would silently drop ':;,' but keep the prefix
                # letters, corrupting the payload, so strip the header first.
                image_input = image_input.partition(",")[2]
            try:
                image_bytes = base64.b64decode(image_input)
                return Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                print("❌ [ERROR] Base64 decode failed:", str(e))
                traceback.print_exc(file=sys.stdout)
                raise cls._HandlerError(f"Failed to decode base64 image: {str(e)}") from e
        # Case 3: Raw bytes
        if isinstance(image_input, (bytes, bytearray)):
            print(f"πŸ“¦ [INFO] Detected raw bytes input (len={len(image_input)})")
            return Image.open(io.BytesIO(image_input)).convert("RGB")
        # Case 4: an already-decoded PIL image (presumably produced by a
        # serving runtime that deserializes image bodies itself - TODO confirm).
        if isinstance(image_input, Image.Image):
            print("[INFO] Detected PIL image input")
            return image_input.convert("RGB")
        print(f"⚠️ [WARN] Unsupported input type: {type(image_input)}")
        raise cls._HandlerError(f"Unsupported input type: {type(image_input)}")

    @classmethod
    def _extract_image(cls, output):
        """Unwrap the pipeline output to a single PIL image.

        A list/tuple is reduced to its first element, and then a dict with an
        ``"image"`` key to that value, so ``[{"image": img}]`` works too.

        Raises:
            _HandlerError: the output is empty, or is not a PIL image after
                unwrapping.
        """
        if isinstance(output, (list, tuple)):
            if not output:
                msg = "Model returned an empty result"
                print("❌ [ERROR]", msg)
                raise cls._HandlerError(msg)
            print("πŸ”„ [DEBUG] Output is list/tuple β€” taking first element")
            output = output[0]
        if isinstance(output, dict) and "image" in output:
            print("πŸ”„ [DEBUG] Output is dict with 'image' key")
            output = output["image"]
        if not isinstance(output, Image.Image):
            msg = f"Unexpected model output type: {type(output)}"
            print("❌ [ERROR]", msg)
            raise cls._HandlerError(msg)
        return output

    @staticmethod
    def _encode_png(image):
        """Serialize *image* as PNG and return it base64-encoded as a str."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")