| | import time |
| | import torch |
| | from transformers import CLIPProcessor, CLIPModel |
| | from PIL import Image |
| | import requests |
| | import io |
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for CLIP text/image matching.

    Given a payload with candidate text labels and image URLs, fetches the
    images, runs CLIP, and returns the top text matches (with softmax
    probabilities) for each image.
    """

    # Hub repo used when the endpoint does not supply a local model directory.
    _DEFAULT_MODEL = "dazpye/clip-image"

    def __init__(self, model_dir=None):
        """Load the CLIP model and processor.

        model_dir: local path provided by the Inference Endpoints runtime.
            Previously ignored (the handler always re-downloaded from the
            hub); now used when present, falling back to the hub repo.
        """
        source = model_dir or self._DEFAULT_MODEL
        print("Loading model...")
        self.model = CLIPModel.from_pretrained(source)
        self.processor = CLIPProcessor.from_pretrained(source)
        # Inference only: disable dropout / training-mode layers.
        self.model.eval()

    def _load_image(self, image_url):
        """Fetch an image from a URL; return a PIL RGB image, or None on failure."""
        try:
            print(f"Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Best-effort: one bad URL should not take down the whole batch.
            print(f"Image loading failed: {e}")
            return None

    def __call__(self, data):
        """Process an inference request.

        Expected payload (optionally nested under "inputs"):
            {"text": ["label a", ...], "images": ["https://...", ...]}

        Returns:
            {"predictions": [{"image_index": i, "top_matches": [(label, prob), ...]}, ...],
             "processing_time_seconds": float}
            or {"error": "..."} when no image could be loaded.
        """
        start_time = time.time()
        print("Processing input...")

        # Inference Endpoints wrap the payload under "inputs"; unwrap if so.
        if "inputs" in data:
            data = data["inputs"]

        text = data.get("text", ["default text"])
        if isinstance(text, str):
            # Tolerate a single bare string label; otherwise indexing below
            # would pick out individual characters.
            text = [text]
        image_urls = data.get("images", [])

        # Fetch all images, skipping empty URLs and failed downloads.
        fetched = (self._load_image(url) for url in image_urls if url)
        images = [img for img in fetched if img]

        if not images:
            return {"error": "No valid images provided."}

        inputs = self.processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        print("Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # One row of probabilities per image over the candidate texts.
        probabilities = outputs.logits_per_image.softmax(dim=1)

        predictions = []
        k = min(3, len(text))  # fewer than 3 labels -> return what we have
        for i, probs in enumerate(probabilities):
            top = torch.topk(probs, k)
            best_matches = [
                (text[idx], prob.item())
                for prob, idx in zip(top.values, top.indices)
            ]
            predictions.append({"image_index": i, "top_matches": best_matches})

        return {
            "predictions": predictions,
            "processing_time_seconds": round(time.time() - start_time, 4),
        }