mastari commited on
Commit
b99cb5c
·
1 Parent(s): f647f0a

fix everything

Browse files
Files changed (1) hide show
  1. handler.py +26 -29
handler.py CHANGED
@@ -2,33 +2,30 @@ from transformers import pipeline
2
  from PIL import Image
3
  import io, base64, requests
4
 
5
- # Load model once when the container starts
6
- def load_model():
7
- print("Loading Swin2SR model...")
8
- return pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
9
-
10
- model = load_model()
11
-
12
- # Hugging Face Inference Endpoint entrypoint
13
- def __call__(data):
14
- # data can be dict with "inputs" (URL or base64)
15
- image_input = data.get("inputs")
16
-
17
- # If it's a URL
18
- if isinstance(image_input, str) and image_input.startswith("http"):
19
- image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
20
- else:
21
- # assume it's base64
22
- image_bytes = base64.b64decode(image_input)
23
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
-
25
- # Run inference
26
- result = model(image)[0]
27
-
28
- # Convert to base64 for API response
29
- buffered = io.BytesIO()
30
- result.save(buffered, format="PNG")
31
- encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
32
-
33
- return {"image_base64": encoded_image}
34
 
 
2
  from PIL import Image
3
  import io, base64, requests
4
 
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # This runs once when the container starts
8
+ print("🔹 Loading Swin2SR model...")
9
+ self.model = pipeline("image-to-image", model="sergeipetrov/swin2SR-classical-sr-x2-64")
10
+
11
+ def __call__(self, data):
12
+ # Handle input data
13
+ image_input = data.get("inputs")
14
+ if not image_input:
15
+ return {"error": "No 'inputs' field provided."}
16
+
17
+ # Accept both image URLs and base64 strings
18
+ if isinstance(image_input, str) and image_input.startswith("http"):
19
+ image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
20
+ else:
21
+ image_bytes = base64.b64decode(image_input)
22
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
23
+
24
+ # Run inference
25
+ result = self.model(image)[0]
26
+
27
+ # Encode to base64 for response
28
+ buffer = io.BytesIO()
29
+ result.save(buffer, format="PNG")
30
+ return {"image_base64": base64.b64encode(buffer.getvalue()).decode("utf-8")}
 
 
 
31