mastari commited on
Commit
427d785
Β·
1 Parent(s): 6cbfe6a
Files changed (1) hide show
  1. handler.py +55 -12
handler.py CHANGED
@@ -1,46 +1,89 @@
1
  from transformers import pipeline
2
  from PIL import Image
3
- import io, base64, requests
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir: str = "", **kwargs):
7
- print("πŸ”Ή Loading Swin2SR model ...")
8
- # Load the model once when the container starts
9
- self.model = pipeline(task="image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
 
 
 
 
 
10
 
11
  def __call__(self, data):
 
 
 
 
12
  try:
13
  image_input = data.get("inputs")
14
- if not image_input:
15
- return {"error": "Missing 'inputs' field"}
 
 
 
 
16
 
17
- # Accept either URL or base64
18
  if isinstance(image_input, str) and image_input.startswith("http"):
 
19
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  else:
21
- image_bytes = base64.b64decode(image_input)
22
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
23
 
24
  # Run inference
 
25
  output = self.model(image)
 
26
 
27
- # πŸ”Έ If the pipeline returns a list or dict, extract the image
28
  if isinstance(output, (list, tuple)):
 
29
  output = output[0]
30
  elif isinstance(output, dict) and "image" in output:
 
31
  output = output["image"]
32
 
33
- # πŸ”Έ Ensure output is a PIL Image
34
  if not isinstance(output, Image.Image):
35
- raise TypeError(f"Unexpected model output type: {type(output)}")
 
 
36
 
37
  # Encode to base64 for API response
 
38
  buffer = io.BytesIO()
39
  output.save(buffer, format="PNG")
40
  encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
41
 
 
42
  return {"image_base64": encoded}
43
 
44
  except Exception as e:
 
 
45
  return {"error": f"Inference failed: {str(e)}"}
46
 
 
1
  from transformers import pipeline
2
  from PIL import Image
3
+ import io, base64, requests, sys, traceback
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir: str = "", **kwargs):
7
+ print("πŸ”Ή [INIT] Loading Swin2SR model ...")
8
+ try:
9
+ self.model = pipeline(task="image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
10
+ print("βœ… [INIT] Model loaded successfully")
11
+ except Exception as e:
12
+ print("❌ [INIT] Model load failed:", str(e))
13
+ traceback.print_exc(file=sys.stdout)
14
+ raise e
15
 
16
  def __call__(self, data):
17
+ print("\n🟒 [CALL] Received request in handler")
18
+ print(f"πŸ”Ή [DEBUG] Raw data type: {type(data)}")
19
+ print(f"πŸ”Ή [DEBUG] Raw data keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")
20
+
21
  try:
22
  image_input = data.get("inputs")
23
+ print(f"πŸ”Ή [DEBUG] image_input type: {type(image_input)}")
24
+
25
+ # Handle nested dicts (double-wrapped inputs)
26
+ if isinstance(image_input, dict) and "inputs" in image_input:
27
+ print("⚠️ [DEBUG] Nested 'inputs' dict detected β€” unwrapping...")
28
+ image_input = image_input["inputs"]
29
 
30
+ # Case 1: URL input
31
  if isinstance(image_input, str) and image_input.startswith("http"):
32
+ print(f"🌐 [INFO] Fetching image from URL: {image_input[:60]}...")
33
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
34
+
35
+ # Case 2: Base64-encoded string
36
+ elif isinstance(image_input, str):
37
+ print(f"🧬 [INFO] Detected base64 string (len={len(image_input)})")
38
+ try:
39
+ image_bytes = base64.b64decode(image_input)
40
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
41
+ except Exception as e:
42
+ print("❌ [ERROR] Base64 decode failed:", str(e))
43
+ traceback.print_exc(file=sys.stdout)
44
+ return {"error": f"Failed to decode base64 image: {str(e)}"}
45
+
46
+ # Case 3: Raw bytes
47
+ elif isinstance(image_input, (bytes, bytearray)):
48
+ print(f"πŸ“¦ [INFO] Detected raw bytes input (len={len(image_input)})")
49
+ image = Image.open(io.BytesIO(image_input)).convert("RGB")
50
+
51
  else:
52
+ print(f"⚠️ [WARN] Unsupported input type: {type(image_input)}")
53
+ return {"error": f"Unsupported input type: {type(image_input)}"}
54
+
55
+ print("βœ… [INFO] Image successfully loaded and converted to RGB")
56
 
57
  # Run inference
58
+ print("πŸš€ [INFER] Running Swin2SR model inference...")
59
  output = self.model(image)
60
+ print("βœ… [INFER] Inference complete")
61
 
62
+ # Normalize output format
63
  if isinstance(output, (list, tuple)):
64
+ print("πŸ”„ [DEBUG] Output is list/tuple β€” taking first element")
65
  output = output[0]
66
  elif isinstance(output, dict) and "image" in output:
67
+ print("πŸ”„ [DEBUG] Output is dict with 'image' key")
68
  output = output["image"]
69
 
70
+ # Validate output type
71
  if not isinstance(output, Image.Image):
72
+ msg = f"Unexpected model output type: {type(output)}"
73
+ print("❌ [ERROR]", msg)
74
+ return {"error": msg}
75
 
76
  # Encode to base64 for API response
77
+ print("πŸ’Ύ [ENCODE] Encoding result image to base64...")
78
  buffer = io.BytesIO()
79
  output.save(buffer, format="PNG")
80
  encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
81
 
82
+ print("βœ… [RETURN] Returning base64-encoded image")
83
  return {"image_base64": encoded}
84
 
85
  except Exception as e:
86
+ print("❌ [FATAL] Inference failed with exception:")
87
+ traceback.print_exc(file=sys.stdout)
88
  return {"error": f"Inference failed: {str(e)}"}
89