|
|
|
|
|
import io, os, base64, requests, torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModelForCausalLM, BitsAndBytesConfig |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a multimodal (text+image)
    causal LM loaded with ``trust_remote_code`` (Ovis-style custom-code API).

    Request payload (dict):
        prompt          -- str, the user text (default "").
        images          -- list of {"url": ...} or {"base64": ...} specs.
        max_new_tokens  -- int, generation budget (default 512).

    Response: {"output": <decoded text>}.
    """

    def __init__(self, path=""):
        # Optional 4-bit quantization, toggled via env var so the same
        # container image can run on smaller GPUs without a code change.
        load_4bit = os.getenv("LOAD_IN_4BIT", "0") == "1"
        qcfg = (
            BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
            if load_4bit
            else None
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            quantization_config=qcfg,
            device_map="auto",
            multimodal_max_length=int(os.getenv("MULTIMODAL_MAX_LENGTH", "8192")),
            # "eager" attention avoids a hard flash-attn dependency on
            # arbitrary endpoint hardware.
            attn_implementation="eager",
            llm_attn_implementation="eager",
        )
        # Accessors exposed by the remote model class (not plain AutoTokenizer).
        self.txt_tok = self.model.get_text_tokenizer()
        self.vis_tok = self.model.get_visual_tokenizer()

    def _load_image(self, spec):
        """Decode one image spec into an RGB PIL image.

        Accepts {"url": <http url>} or {"base64": <b64 bytes>}.

        Raises:
            ValueError: if neither key is present.
            requests.HTTPError: if the URL fetch returns an error status.
        """
        if "url" in spec:
            resp = requests.get(spec["url"], timeout=10)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content)).convert("RGB")
        if "base64" in spec:
            raw = base64.b64decode(spec["base64"])
            return Image.open(io.BytesIO(raw)).convert("RGB")
        raise ValueError("image must have 'url' or 'base64'")

    def __call__(self, data):
        """Handle one inference request; returns {"output": decoded_text}."""
        prompt = data.get("prompt", "")
        imgs_spec = data.get("images", [])
        max_new = int(data.get("max_new_tokens", 512))

        images = [self._load_image(s) for s in imgs_spec]
        if images:
            # One <image> placeholder per image, prepended to the prompt.
            prefix = "\n".join(["<image>"] * len(images))
            query = f"{prefix}\n{prompt}"
            # Fewer tiles per image when several images share the context
            # window; a single image can afford a finer partition.
            max_part = 4 if len(images) > 1 else 9
        else:
            query, max_part = prompt, None

        prompt, input_ids, pix = self.model.preprocess_inputs(query, images, max_partition=max_part)
        pad_id = self.txt_tok.pad_token_id
        # Fix: tokenizers without a pad token report pad_token_id as None,
        # and the original `input_ids != None` comparison crashed. A single
        # unpadded sequence has an all-ones mask, so fall back to that.
        if pad_id is None:
            attn = torch.ones_like(input_ids).unsqueeze(0).to(self.model.device)
        else:
            attn = (input_ids != pad_id).unsqueeze(0).to(self.model.device)
        input_ids = input_ids.unsqueeze(0).to(self.model.device)
        # Visual features must match the visual tokenizer's dtype/device;
        # the remote generate() expects a list (or None for text-only).
        pix = [pix.to(dtype=self.vis_tok.dtype, device=self.vis_tok.device)] if pix is not None else None

        with torch.inference_mode():
            out_ids = self.model.generate(
                input_ids,
                pixel_values=pix,
                attention_mask=attn,
                max_new_tokens=max_new,
                do_sample=False,  # deterministic greedy decoding
                use_cache=True,
                eos_token_id=self.model.generation_config.eos_token_id,
                # Fix: fall back to EOS when no pad token is defined, the
                # standard substitution for open-ended generation.
                pad_token_id=pad_id if pad_id is not None else self.model.generation_config.eos_token_id,
            )[0]
        text = self.txt_tok.decode(out_ids, skip_special_tokens=True)
        return {"output": text}
|
|
|