mastari commited on
Commit
698cda9
·
1 Parent(s): 752d561

fix: remove [0] subscript from model output

Browse files
Files changed (1) hide show
  1. handler.py +5 -5
handler.py CHANGED
@@ -1,26 +1,26 @@
1
- from typing import Dict, Any
2
  from transformers import pipeline
3
  from PIL import Image
4
  import io, base64, requests
5
 
6
  class EndpointHandler:
7
- def __init__(self, model_dir: str, **kwargs: Any):
8
  print("🔹 Loading Swin2SR model ...")
9
  self.model = pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
10
 
11
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
12
  image_input = data.get("inputs")
13
  if not image_input:
14
  return {"error": "Missing 'inputs' field"}
15
 
16
- # Accept either a URL or a base64 string
17
  if isinstance(image_input, str) and image_input.startswith("http"):
18
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
19
  else:
20
  image_bytes = base64.b64decode(image_input)
21
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
22
 
23
- result = self.model(image)[0]
 
24
 
25
  buf = io.BytesIO()
26
  result.save(buf, format="PNG")
 
 
1
  from transformers import pipeline
2
  from PIL import Image
3
  import io, base64, requests
4
 
5
  class EndpointHandler:
6
+ def __init__(self, model_dir: str = "", **kwargs):
7
  print("🔹 Loading Swin2SR model ...")
8
  self.model = pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
9
 
10
+ def __call__(self, data):
11
  image_input = data.get("inputs")
12
  if not image_input:
13
  return {"error": "Missing 'inputs' field"}
14
 
15
+ # Accept URL or base64
16
  if isinstance(image_input, str) and image_input.startswith("http"):
17
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
18
  else:
19
  image_bytes = base64.b64decode(image_input)
20
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
21
 
22
+ # FIXED: remove [0]
23
+ result = self.model(image)
24
 
25
  buf = io.BytesIO()
26
  result.save(buf, format="PNG")