# handler.py
import base64
import io
import os

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path=""):
        # Optional: 4-bit quantization, enabled via the LOAD_IN_4BIT env var
        load_4bit = os.getenv("LOAD_IN_4BIT", "0") == "1"
        qcfg = (
            BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
            if load_4bit
            else None
        )
        # ⚠️ Use the local path provided by the endpoint
        # ⚠️ Force the "eager" attention implementation (no SDPA, no FlashAttention)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            quantization_config=qcfg,
            device_map="auto",
            multimodal_max_length=int(os.getenv("MULTIMODAL_MAX_LENGTH", "8192")),
            attn_implementation="eager",
            llm_attn_implementation="eager",  # this field is read by some VLM architectures
        )
        self.txt_tok = self.model.get_text_tokenizer()
        self.vis_tok = self.model.get_visual_tokenizer()

    def _load_image(self, spec):
        # Accept either a remote URL or an inline base64-encoded image.
        if "url" in spec:
            r = requests.get(spec["url"], timeout=10)
            r.raise_for_status()
            return Image.open(io.BytesIO(r.content)).convert("RGB")
        if "base64" in spec:
            return Image.open(io.BytesIO(base64.b64decode(spec["base64"]))).convert("RGB")
        raise ValueError("image must have 'url' or 'base64'")
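
    # Building a "base64" image spec on the client side (a sketch; the file
    # name "cat.jpg" is hypothetical, only the spec shape matters):
    #
    #   with open("cat.jpg", "rb") as f:
    #       spec = {"base64": base64.b64encode(f.read()).decode("ascii")}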

    def __call__(self, data):
        # Expected payload:
        #   {"prompt": str,
        #    "images": [{"url": ...} | {"base64": ...}, ...],  # optional
        #    "max_new_tokens": int}                            # optional, default 512
        prompt = data.get("prompt", "")
        imgs_spec = data.get("images", [])
        max_new = int(data.get("max_new_tokens", 512))
        images = [self._load_image(s) for s in imgs_spec]
        if images:
            # One <image> placeholder per image, prepended to the prompt.
            prefix = "\n".join(["<image>"] * len(images))
            query = f"{prefix}\n{prompt}"
            # Fewer visual partitions per image when several images share the context.
            max_part = 4 if len(images) > 1 else 9
        else:
            query, max_part = prompt, None
        _, input_ids, pix = self.model.preprocess_inputs(query, images, max_partition=max_part)
        # Build the attention mask from the 1-D ids before adding the batch dimension.
        attn = (input_ids != self.txt_tok.pad_token_id).unsqueeze(0).to(self.model.device)
        input_ids = input_ids.unsqueeze(0).to(self.model.device)
        # Ovis expects a list with one entry per batch item; None marks a text-only sample.
        pix = [pix.to(dtype=self.vis_tok.dtype, device=self.vis_tok.device) if pix is not None else None]
        with torch.inference_mode():
            out_ids = self.model.generate(
                input_ids,
                pixel_values=pix,
                attention_mask=attn,
                max_new_tokens=max_new,
                do_sample=False,  # greedy decoding
                use_cache=True,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=self.txt_tok.pad_token_id,
            )[0]
        # Ovis generates from input embeddings, so out_ids holds only the new tokens.
        text = self.txt_tok.decode(out_ids, skip_special_tokens=True)
        return {"output": text}
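

# Minimal local smoke test (a sketch, not part of the endpoint contract).
# Assumptions: MODEL_PATH is a hypothetical env var pointing at a local copy of
# the Ovis2-1B weights, and the image URL is a placeholder. In production the
# endpoint instantiates EndpointHandler itself and sends the JSON payload.
if __name__ == "__main__":
    handler = EndpointHandler(path=os.getenv("MODEL_PATH", "."))
    payload = {
        "prompt": "Describe the image.",
        "images": [{"url": "https://example.com/cat.jpg"}],  # or {"base64": "..."}
        "max_new_tokens": 128,
    }
    print(handler(payload)["output"])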