Hsquad commited on
Commit
03a019a
·
verified ·
1 Parent(s): 100bfc8

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -0
handler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ from io import BytesIO
5
+ import logging
6
+
7
+ from transformers import CLIPProcessor, CLIPModel
8
+ from PIL import Image
9
+ import requests
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ model = None
14
+ processor = None
15
+
16
+ def init():
17
+ global model, processor
18
+ model_name = os.getenv("MODEL_NAME", "openai/clip-vit-base-patch32")
19
+ logging.info(f"Loading model: {model_name}")
20
+ model = CLIPModel.from_pretrained(model_name)
21
+ processor = CLIPProcessor.from_pretrained(model_name)
22
+ logging.info("Model and processor loaded successfully.")
23
+
24
+ def handle_request(request_data, context):
25
+ results = []
26
+ for data in request_data:
27
+ try:
28
+ payload = json.loads(data)
29
+ image_input = payload.get("image")
30
+ text_input = payload.get("text", [])
31
+ if image_input.startswith("http://") or image_input.startswith("https://"):
32
+ response = requests.get(image_input, stream=True, timeout=10)
33
+ image = Image.open(response.raw).convert("RGB")
34
+ elif image_input.startswith("data:"):
35
+ header, encoded = image_input.split(",", 1)
36
+ image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
37
+ else:
38
+ image = Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")
39
+ inputs = processor(text=text_input, images=image, return_tensors="pt", padding=True)
40
+ outputs = model(**inputs)
41
+ logits_per_image = outputs.logits_per_image
42
+ probs = logits_per_image.softmax(dim=1)
43
+ results.append(probs.tolist())
44
+ except Exception as e:
45
+ results.append({"error": str(e)})
46
+ return results
47
+
48
+ class EndpointHandler:
49
+ def __init__(self, model_dir=None):
50
+ init()
51
+
52
+ def handle(self, request_data, context):
53
+ return handle_request(request_data, context)