import io

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor


class EndpointHandler:
    """Custom inference handler that captions an image fetched from a URL."""

    def __init__(self, model_dir):
        # Load the model and processor once at startup. trust_remote_code is
        # required because the checkpoint ships its own modeling/processing code.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True
        ).to(device)
        self.processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
        self.device = device

    def __call__(self, data):
        try:
            # The endpoint expects a JSON payload of the form {"inputs": {"url": "..."}}.
            url = data.get("inputs", {}).get("url")
            if not url:
                return {"error": "Missing URL"}

            # Fetch the image. Some hosts reject requests without a browser-like
            # User-Agent. verify=False skips TLS certificate checks, so keep it
            # only if the image sources are trusted; the timeout avoids hanging
            # the worker on a slow host.
            headers = {
                "User-Agent": "Mozilla/5.0",
                "Accept": "image/*",
            }
            response = requests.get(url, headers=headers, verify=False, timeout=10)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

            # "<MORE_DETAILED_CAPTION>" is the task prompt that selects the
            # model's most detailed captioning mode.
            inputs = self.processor(
                text="<MORE_DETAILED_CAPTION>",
                images=image,
                return_tensors="pt",
            ).to(self.device)

            # Beam search gives more stable captions than greedy decoding here.
            with torch.inference_mode():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=3,
                )

            caption = self.processor.batch_decode(output, skip_special_tokens=True)[0]
            return {"caption": caption}
        except Exception as e:
            # Surface any failure (bad URL, decode error, OOM, ...) to the caller.
            return {"error": str(e)}