| | from typing import Any, Dict, List |
| |
|
| | import requests |
| | import torch |
| |
|
| | from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
| | from PIL import Image |
| |
|
| |
|
class EndpointHandler:
    """Callable request handler that runs PaliGemma on (text, image URL) pairs.

    The model and processor are loaded once at construction time; each call
    then downloads the image of every instance, runs generation and returns
    the decoded continuations.
    """

    # Hub identifier used for both the model weights and the processor.
    MODEL_ID = "google/paligemma-3b-mix-448"

    # Seconds to wait for the image host before giving up. Without a timeout
    # `requests` waits forever, so one stalled host would block the worker.
    IMAGE_DOWNLOAD_TIMEOUT = 30

    # Generation settings applied when the request does not provide
    # `generation_kwargs`.
    DEFAULT_GENERATION_KWARGS = {"max_new_tokens": 100, "do_sample": False}

    def __init__(
        self,
        model_dir: str = "/opt/huggingface/model",
        **kwargs: Any,
    ) -> None:
        """Load the model (bfloat16 revision) and its processor.

        Args:
            model_dir: Accepted for interface compatibility but not read; the
                weights are always fetched by `MODEL_ID`.
            **kwargs: Accepted and ignored.
        """
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            revision="bfloat16",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
        ).eval()

        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)

    def _load_image(self, image_url: str) -> "Image.Image":
        """Download and fully decode the image at `image_url`.

        Raises:
            ValueError: If the download fails, times out, returns an HTTP
                error status, or the payload cannot be decoded as an image.
        """
        try:
            # The context manager releases the connection even on failure.
            with requests.get(
                image_url, stream=True, timeout=self.IMAGE_DOWNLOAD_TIMEOUT
            ) as response:
                # Fail on 4xx/5xx instead of feeding an error page to PIL.
                response.raise_for_status()
                image = Image.open(response.raw)
                # `Image.open` is lazy; force decoding here so corrupt data
                # is reported through the ValueError below.
                image.load()
        except Exception as e:
            raise ValueError(
                f"The provided image URL ({image_url}) cannot be downloaded (with exception {e}), make sure it's public and accessible."
            ) from e
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Generate one text prediction per instance in the request.

        Args:
            data: Request body. Must contain `instances`, a list of dicts each
                holding `image_url` and either `text` or `prompt` (`prompt`
                takes precedence when both are present). May contain
                `generation_kwargs`, a dict forwarded to `model.generate`.

        Returns:
            `{"predictions": [...]}` with one decoded string per instance, in
            request order.

        Raises:
            ValueError: If the request is malformed or an image cannot be
                fetched.
        """
        instances = data.get("instances") if isinstance(data, dict) else None
        if not isinstance(instances, list):
            raise ValueError(
                "The request body must contain a key `instances` with a list of instances."
            )

        # Resolved once: the value is the same for every instance.
        generation_kwargs = data.get("generation_kwargs")
        if generation_kwargs is None:
            generation_kwargs = dict(self.DEFAULT_GENERATION_KWARGS)
        elif not isinstance(generation_kwargs, dict):
            raise ValueError(
                "The `generation_kwargs` key, when provided, must be a dictionary."
            )

        predictions = []
        for instance in instances:
            has_text = isinstance(instance, dict) and (
                "prompt" in instance or "text" in instance
            )
            if not has_text or "image_url" not in instance:
                raise ValueError(
                    "The request body for each instance should contain both the `text` and the `image_url` key with a valid image URL."
                )

            # Read `prompt` as an alias of `text` without rewriting the
            # caller's dictionary.
            text = instance["prompt"] if "prompt" in instance else instance["text"]
            image = self._load_image(instance["image_url"])

            inputs = self.processor(
                text=text, images=image, return_tensors="pt"
            ).to(self.model.device)
            input_len = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                generation = self.model.generate(**inputs, **generation_kwargs)
                # Keep only the tokens produced after the prompt.
                generation = generation[0][input_len:]
                response = self.processor.decode(
                    generation, skip_special_tokens=True
                )
            predictions.append(response)
        return {"predictions": predictions}
| |
|