| | import torch |
| | from transformers import AutoProcessor, AutoModelForVision2Seq, GenerationConfig |
| | from transformers.image_utils import load_image |
| |
|
| | from typing import Any, Dict |
| |
|
| | import base64 |
| | import re |
| | from copy import deepcopy |
| |
|
| |
|
def is_base64(s: str) -> bool:
    """Return True if *s* is a valid base64 string that round-trips unchanged.

    Decodes *s* and re-encodes the result; only an exact match counts, so
    strings that merely *contain* base64-decodable characters are rejected.
    Any decoding error simply yields False.
    """
    try:
        decoded = base64.b64decode(s)
    except Exception:
        return False
    return base64.b64encode(decoded).decode() == s
| |
|
| |
|
def is_url(s: str) -> bool:
    """Return True if *s* starts with an http(s) URL.

    Uses ``re.match``, so only the beginning of the string is checked;
    percent-encoded characters are allowed in the host/path portion.
    """
    pattern = re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+")
    return pattern.match(s) is not None
| |
|
| |
|
class EndpointHandler:
    """Inference Endpoints handler for vision-to-sequence (image+text) models.

    Loads the processor, model, and generation config once at startup, then
    serves requests whose body carries a list of ``{"text", "images"}`` inputs
    and returns one generated text per input.
    """

    def __init__(
        self,
        model_dir: str = "HuggingFaceTB/SmolVLM-Instruct",
        **kwargs: Any,
    ) -> None:
        """Load processor, model, and generation config from *model_dir*.

        Args:
            model_dir: Hugging Face model id or local path.
            **kwargs: Ignored; accepted for handler-interface compatibility.
        """
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            # Eager attention avoids requiring a flash-attention build at deploy time.
            _attn_implementation="eager",
            device_map="auto",
        ).eval()
        self.generation_config = GenerationConfig.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run generation for every entry in ``data["inputs"]``.

        Args:
            data: Request body. Must contain ``"inputs"``: a list of dicts,
                each with ``"text"`` (prompt string), ``"images"`` (list of
                image URLs or base64 strings), and optionally
                ``"generation_parameters"`` (kwargs merged into the model's
                generation config).

        Returns:
            ``{"predictions": [generated_text, ...]}`` with one entry per input.

        Raises:
            ValueError: If the request body or any of its inputs is malformed,
                or if an image cannot be loaded.
        """
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of inputs."
            )

        if not isinstance(data["inputs"], list):
            raise ValueError(
                "The request inputs must be a list of dictionaries with the keys 'text' and 'images', being a"
                " string with the prompt and a list with the image URLs or base64 encodings, respectively; and"
                " optionally including the key 'generation_parameters' key too."
            )

        predictions = []
        # `entry` rather than `input` to avoid shadowing the builtin.
        for entry in data["inputs"]:
            if "text" not in entry:
                raise ValueError(
                    "The request input body must contain the key 'text' with the prompt to use."
                )

            # Bug fix: reject when 'images' is missing, is not a list, or any
            # element is not a string. The original condition
            # (`not isinstance(...) and all(...)`) evaluated to False for a
            # list of non-strings such as [1, 2], letting it slip through.
            if (
                "images" not in entry
                or not isinstance(entry["images"], list)
                or not all(isinstance(i, str) for i in entry["images"])
            ):
                raise ValueError(
                    "The request input body must contain the key 'images' with a list of strings,"
                    " where each string corresponds to an image on either base64 encoding, or provided"
                    " as a valid URL (needs to be publicly accessible and contain a valid image)."
                )

            images = []
            for image in entry["images"]:
                try:
                    images.append(load_image(image))
                except Exception as e:
                    # Chain the original exception so the root cause stays in
                    # the traceback.
                    raise ValueError(
                        f"Provided {image=} is not valid, please make sure that's either a base64 encoding"
                        f" of a valid image, or a publicly accesible URL to a valid image.\nFailed with {e=}."
                    ) from e

            # Copy so per-request parameters never mutate the shared config.
            generation_config = deepcopy(self.generation_config)
            # NOTE: the 128-token default applies only when the caller sends no
            # 'generation_parameters' at all (original behavior, preserved).
            generation_config.update(
                **entry.get("generation_parameters", {"max_new_tokens": 128})
            )

            # One image placeholder per provided image, followed by the prompt.
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "image"} for _ in images]
                    + [{"type": "text", "text": entry["text"]}],
                },
            ]
            prompt = self.processor.apply_chat_template(
                messages, add_generation_prompt=True
            )
            processed_inputs = self.processor(
                text=prompt, images=images, return_tensors="pt"
            ).to(self.model.device)

            generated_ids = self.model.generate(
                **processed_inputs, generation_config=generation_config
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            predictions.append(generated_texts[0])

        return {"predictions": predictions}