| import os | |
| import json | |
| import base64 | |
| from io import BytesIO | |
| import logging | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import requests | |
# Configure root logging at import time so the INFO messages emitted while
# loading the model are visible.
logging.basicConfig(level=logging.INFO)

# Module-wide CLIP model and processor; both stay None until init() assigns them.
model = processor = None
def init(model_name=None):
    """Load the CLIP model and processor into the module-level globals.

    Args:
        model_name: Hugging Face model id or local path to load. When omitted
            (the original behavior), the ``MODEL_NAME`` environment variable
            is used, falling back to ``"openai/clip-vit-base-patch32"``.
    """
    global model, processor
    if model_name is None:
        model_name = os.getenv("MODEL_NAME", "openai/clip-vit-base-patch32")
    logging.info("Loading model: %s", model_name)
    # Load into locals first so a failure part-way through does not leave the
    # globals half-initialised (model set, processor still None).
    loaded_model = CLIPModel.from_pretrained(model_name)
    loaded_processor = CLIPProcessor.from_pretrained(model_name)
    model, processor = loaded_model, loaded_processor
    logging.info("Model and processor loaded successfully.")
def _decode_payload(data):
    """Return one request item as a dict (accepts a JSON document or a dict)."""
    if isinstance(data, dict):
        return data
    payload = json.loads(data)
    if not isinstance(payload, dict):
        raise ValueError("Request payload must be a JSON object.")
    return payload


def _load_image(image_input):
    """Decode an HTTP(S) URL, data URI, or bare base64 string into an RGB PIL image."""
    if not isinstance(image_input, str) or not image_input:
        raise ValueError("Payload field 'image' must be a non-empty string (URL, data URI, or base64).")
    if image_input.startswith(("http://", "https://")):
        # NOTE(security): fetching caller-supplied URLs permits server-side
        # request forgery; restrict allowed hosts if this endpoint is public.
        # The context manager closes the connection; raise_for_status() stops
        # error pages from being handed to PIL; .content applies any
        # Content-Encoding (e.g. gzip) that response.raw would not.
        with requests.get(image_input, timeout=10) as response:
            response.raise_for_status()
            raw = response.content
    elif image_input.startswith("data:"):
        _, sep, encoded = image_input.partition(",")
        if not sep:
            raise ValueError("Malformed data URI: missing ',' separator.")
        raw = base64.b64decode(encoded)
    else:
        raw = base64.b64decode(image_input)
    return Image.open(BytesIO(raw)).convert("RGB")


def _classify(image, text_input):
    """Return CLIP softmax probabilities of *image* over *text_input* as nested lists."""
    import torch  # function-local: only needed once inference actually runs

    if model is None or processor is None:
        # Allow handle_request to be used before init() was called explicitly.
        init()
    inputs = processor(text=text_input, images=image, return_tensors="pt", padding=True)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image.softmax(dim=1).tolist()


def handle_request(request_data, context):
    """Score each request item's image against its candidate text labels with CLIP.

    Args:
        request_data: Iterable of request items. Each item is a JSON document
            (``str``/``bytes``) or an already-decoded ``dict`` with keys
            ``"image"`` (HTTP(S) URL, ``data:`` URI, or bare base64 string) and
            ``"text"`` (candidate label string or non-empty list of strings).
        context: Unused here; accepted for the serving framework's interface.

    Returns:
        A list with one entry per item, in order: either the nested list from
        ``logits_per_image.softmax(dim=1).tolist()`` or, if that item failed,
        ``{"error": <message>}``. One bad item never aborts the others.
    """
    results = []
    for index, data in enumerate(request_data):
        try:
            payload = _decode_payload(data)
            text_input = payload.get("text", [])
            if not isinstance(text_input, str) and not text_input:
                raise ValueError("Payload field 'text' must contain at least one candidate label.")
            image = _load_image(payload.get("image"))
            results.append(_classify(image, text_input))
        except Exception as e:  # per-item best effort: report, log, continue
            logging.exception("Failed to process request item %d", index)
            results.append({"error": str(e)})
    return results
class EndpointHandler:
    """Serving entry point that delegates to the module-level CLIP helpers.

    Constructing an instance loads the model eagerly via ``init()``.
    ``model_dir`` is accepted for interface compatibility but is not used:
    ``init()`` decides which model to load (by default from the
    ``MODEL_NAME`` environment variable).
    """

    def __init__(self, model_dir=None):
        # model_dir is intentionally ignored; loading is driven by init().
        init()

    def handle(self, request_data, context):
        """Forward the batch to handle_request and return its per-item results."""
        response = handle_request(request_data, context)
        return response