"""Inference endpoint for CLIP zero-shot image/text scoring.

Accepts images as URLs, data URIs, or raw base64 strings and scores each
against caller-supplied text labels, returning softmax probabilities.
"""

import base64
import json
import logging
import os
from io import BytesIO

import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Populated by init(); module-level so one loaded model serves all requests.
model = None
processor = None


def init():
    """Load the CLIP model and processor named by the MODEL_NAME env var.

    Defaults to "openai/clip-vit-base-patch32". Must run before
    handle_request(); EndpointHandler.__init__ calls it automatically.
    """
    global model, processor
    model_name = os.getenv("MODEL_NAME", "openai/clip-vit-base-patch32")
    logger.info("Loading model: %s", model_name)
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    logger.info("Model and processor loaded successfully.")


def _load_image(image_input):
    """Decode one image given as a URL, a data URI, or bare base64.

    Returns a PIL.Image converted to RGB. Raises ValueError when the
    input is missing/empty; network and decode errors propagate to the
    caller (handle_request turns them into per-item error entries).
    """
    if not image_input:
        raise ValueError("Request payload is missing the 'image' field.")
    if image_input.startswith(("http://", "https://")):
        response = requests.get(image_input, timeout=10)
        # Fail fast on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        # .content (not .raw) applies transfer decoding such as gzip.
        return Image.open(BytesIO(response.content)).convert("RGB")
    if image_input.startswith("data:"):
        # Data URI: strip the "data:<mime>;base64," header before decoding.
        _, encoded = image_input.split(",", 1)
        return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
    # Otherwise assume a bare base64-encoded image payload.
    return Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")


def handle_request(request_data, context):
    """Score each request item against its text labels with CLIP.

    Each element of request_data is a JSON string with keys:
      - "image": URL, data URI, or base64-encoded image (required)
      - "text": list of candidate label strings (defaults to [])

    Returns one entry per item: a nested list of softmax probabilities
    (shape [1, len(text)]) on success, or {"error": "..."} on failure,
    so one bad item never aborts the rest of the batch.
    """
    # Local import: torch is a transformers dependency and is always
    # present when the model loads; keeps the top-level imports minimal.
    import torch

    results = []
    for data in request_data:
        try:
            payload = json.loads(data)
            image = _load_image(payload.get("image"))
            text_input = payload.get("text", [])
            inputs = processor(
                text=text_input, images=image, return_tensors="pt", padding=True
            )
            # Inference only: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                outputs = model(**inputs)
            probs = outputs.logits_per_image.softmax(dim=1)
            results.append(probs.tolist())
        except Exception as e:
            # Per-item isolation: log the full traceback, report the error
            # in-band, and keep serving the remaining items.
            logger.exception("Failed to process request item")
            results.append({"error": str(e)})
    return results


class EndpointHandler:
    """Adapter exposing init/handle in the shape the serving runtime expects."""

    def __init__(self, model_dir=None):
        # model_dir is accepted for interface compatibility but unused;
        # init() resolves the model from the MODEL_NAME env var instead.
        init()

    def handle(self, request_data, context):
        return handle_request(request_data, context)