mountainsma committed on
Commit
3e0528a
·
verified ·
1 Parent(s): 7faf006

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +56 -0
handler.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ from typing import Any, Dict
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import CLIPModel, CLIPProcessor
8
+
9
class EndpointHandler:
    """Inference handler that returns CLIP embeddings for text and/or an image.

    The handler is called with a request dict. The actual payload is either
    the dict itself or the value stored under its ``"inputs"`` key, and may
    contain:

    * ``"text"``  - text passed to the CLIP processor.
    * ``"image"`` - a base64 string (optionally a ``data:`` URI), an
      ``http(s)://`` URL, or a ``PIL.Image.Image`` instance.
    """

    # Returned verbatim whenever no usable text or image was supplied.
    _MISSING_INPUT_ERROR = "Geef 'text' of 'image' (base64) mee"

    # Seconds to wait for a remote image before giving up, so that a stalled
    # host cannot block the worker indefinitely.
    _URL_TIMEOUT_SECONDS = 10

    def __init__(self, model_dir: str = "", **kwargs: Any):
        """Load the CLIP model and processor from ``model_dir``.

        The model is moved to CUDA when available (CPU otherwise) and put in
        eval mode. Extra keyword arguments are accepted and ignored.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_dir).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_dir)
        self.model.eval()

    @staticmethod
    def _load_image(image_input: Any) -> Any:
        """Turn the request's ``image`` value into an RGB PIL image.

        Returns ``None`` when ``image_input`` is ``None`` or of an unsupported
        type, so the caller can fall back to text-only handling.

        Raises whatever ``requests``, ``base64`` or PIL raise for an
        unreachable URL, an HTTP error status, or undecodable image data.
        """
        if image_input is None:
            return None
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")
        if not isinstance(image_input, str):
            return None

        if image_input.startswith(("http://", "https://")):
            # Optional URL support; imported lazily so `requests` is only
            # required when a URL is actually supplied.
            # NOTE(review): this fetches an arbitrary caller-supplied URL;
            # consider restricting allowed hosts if the endpoint is public.
            import requests

            response = requests.get(
                image_input, timeout=EndpointHandler._URL_TIMEOUT_SECONDS
            )
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            raw_bytes = response.content
        else:
            # Base64, optionally prefixed with "data:image/...;base64,".
            raw_bytes = base64.b64decode(image_input.split(",")[-1])

        return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute embeddings for the supplied text and/or image.

        Returns a dict with ``"image_embedding"`` and/or ``"text_embedding"``
        (each a list of floats taken from the first item of the batch), or
        ``{"error": ...}`` when neither a text nor a usable image was given.
        """
        # Flexible: the fields may sit at the top level or under "inputs".
        payload = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(payload, dict):
            # e.g. {"inputs": "plain string"}: there are no named fields.
            return {"error": self._MISSING_INPUT_ERROR}

        text = payload.get("text")
        image = self._load_image(payload.get("image"))

        if image is None and text is None:
            return {"error": self._MISSING_INPUT_ERROR}

        with torch.no_grad():
            if image is not None and text is not None:
                # Both -> image and text embeddings from one forward pass.
                # NOTE(review): presumably the forward pass returns normalised
                # embeds while get_*_features below do not - confirm against
                # the installed transformers version before comparing them.
                inputs = self.processor(
                    text=text,
                    images=image,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,  # over-long text must not overflow the model
                ).to(self.device)
                outputs = self.model(**inputs)
                return {
                    "image_embedding": outputs.image_embeds[0].cpu().tolist(),
                    "text_embedding": outputs.text_embeds[0].cpu().tolist(),
                }

            if image is not None:
                # Image only.
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)
                return {"image_embedding": image_features[0].cpu().tolist()}

            # Text only.
            inputs = self.processor(
                text=text,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(self.device)
            text_features = self.model.get_text_features(**inputs)
            return {"text_embedding": text_features[0].cpu().tolist()}