mastari commited on
Commit
752d561
·
1 Parent(s): b99cb5c
Files changed (1) hide show
  1. handler.py +9 -12
handler.py CHANGED
@@ -1,31 +1,28 @@
 
1
  from transformers import pipeline
2
  from PIL import Image
3
  import io, base64, requests
4
 
5
  class EndpointHandler:
6
- def __init__(self, path=""):
7
- # This runs once when the container starts
8
- print("🔹 Loading Swin2SR model...")
9
  self.model = pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
10
 
11
- def __call__(self, data):
12
- # Handle input data
13
  image_input = data.get("inputs")
14
  if not image_input:
15
- return {"error": "No 'inputs' field provided."}
16
 
17
- # Accept both image URLs and base64 strings
18
  if isinstance(image_input, str) and image_input.startswith("http"):
19
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
20
  else:
21
  image_bytes = base64.b64decode(image_input)
22
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
23
 
24
- # Run inference
25
  result = self.model(image)[0]
26
 
27
- # Encode to base64 for response
28
- buffer = io.BytesIO()
29
- result.save(buffer, format="PNG")
30
- return {"image_base64": base64.b64encode(buffer.getvalue()).decode("utf-8")}
31
 
 
1
+ from typing import Dict, Any
2
  from transformers import pipeline
3
  from PIL import Image
4
  import io, base64, requests
5
 
6
  class EndpointHandler:
7
+ def __init__(self, model_dir: str, **kwargs: Any):
8
+ print("🔹 Loading Swin2SR model ...")
 
9
  self.model = pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
10
 
11
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
12
  image_input = data.get("inputs")
13
  if not image_input:
14
+ return {"error": "Missing 'inputs' field"}
15
 
16
+ # Accept either a URL or a base64 string
17
  if isinstance(image_input, str) and image_input.startswith("http"):
18
  image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
19
  else:
20
  image_bytes = base64.b64decode(image_input)
21
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
22
 
 
23
  result = self.model(image)[0]
24
 
25
+ buf = io.BytesIO()
26
+ result.save(buf, format="PNG")
27
+ return {"image_base64": base64.b64encode(buf.getvalue()).decode("utf-8")}
 
28