import os
import json
import base64
import logging
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
logging.basicConfig(level=logging.INFO)

# Populated by init() so the model is loaded once per worker, not once per request.
model = None
processor = None


def init():
    """Load the CLIP model and processor, honoring the MODEL_NAME env var."""
    global model, processor
    model_name = os.getenv("MODEL_NAME", "openai/clip-vit-base-patch32")
    logging.info(f"Loading model: {model_name}")
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    logging.info("Model and processor loaded successfully.")

def handle_request(request_data, context):
    """Run zero-shot CLIP classification for each JSON payload in the batch."""
    results = []
    for data in request_data:
        try:
            payload = json.loads(data)
            image_input = payload.get("image")
            text_input = payload.get("text", [])
            if not isinstance(image_input, str):
                raise ValueError("payload must include an 'image' string (URL, data URI, or base64)")
            if image_input.startswith(("http://", "https://")):
                # Remote image: fetch it and decode from the response body.
                response = requests.get(image_input, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
            elif image_input.startswith("data:"):
                # Data URI: drop the "data:image/...;base64," header before decoding.
                header, encoded = image_input.split(",", 1)
                image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
            else:
                # Fall back to treating the string as raw base64-encoded image bytes.
                image = Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")
            inputs = processor(text=text_input, images=image, return_tensors="pt", padding=True)
            with torch.no_grad():  # inference only, so skip gradient tracking
                outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)
            results.append(probs.tolist())
        except Exception as e:
            # Report per-item failures without aborting the rest of the batch.
            results.append({"error": str(e)})
    return results

class EndpointHandler:
    """Wrapper exposing the init/handle hooks expected by the serving runtime."""

    def __init__(self, model_dir=None):
        init()

    def handle(self, request_data, context):
        return handle_request(request_data, context)
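
# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative sketch, not part of the deployed
# handler). It assumes requests arrive as JSON strings carrying a
# base64-encoded "image" and a list of candidate "text" labels, matching the
# parsing in handle_request above; the labels here are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a tiny in-memory image and base64-encode it, so the test needs no
    # network access for the image itself (the model download still does).
    buffer = BytesIO()
    Image.new("RGB", (64, 64), color="red").save(buffer, format="PNG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

    handler = EndpointHandler()
    request = [json.dumps({
        "image": encoded_image,
        "text": ["a red square", "a photo of a cat"],
    })]
    print(handler.handle(request, context=None))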