import os
import io
import base64
from typing import Any, Dict, List, Optional

import torch
from PIL import Image

from llava.model.builder import load_pretrained_model, get_model_name_from_path
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates
from llava.utils import disable_torch_init


# Model source: an explicit local directory takes precedence over a Hub
# model id; if neither is set, the path handed over by the toolkit is used.
HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip()

# Default conversation template ("llava_v1" is a standard key in
# llava.conversation.conv_templates; "llava_v2" is not defined upstream).
DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))

# Attention backend; PyTorch's scaled-dot-product attention by default.
os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
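
# Illustrative deployment settings (hypothetical values, not from the source):
#   HF_MODEL_ID=liuhaotian/llava-v1.5-7b
#   LLAVA_CONV_MODE=llava_v1
#   MAX_NEW_TOKENS=256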


class EndpointHandler:
    """
    Handler invoked by the Hugging Face Inference Toolkit.

    Input schema (compatible with the demo):
        {
            "inputs": { "query": "...", "image": "<url|dataurl|path>" },
            "parameters": {
                "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
                "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
            },
            "conv_mode": "llava_v1"  # optional
        }

    Returns:
        [ { "generated_text": "..." } ]
    """

    def __init__(self, path: str = "") -> None:
        disable_torch_init()

        # Resolve the model source: explicit local dir > Hub model id >
        # the repository path passed in by the Inference Toolkit.
        if HF_MODEL_LOCAL_DIR:
            model_path = HF_MODEL_LOCAL_DIR
        elif HF_MODEL_ID:
            model_path = HF_MODEL_ID
        else:
            model_path = path

        self.model_name = get_model_name_from_path(model_path)

        # torch_dtype and attn_implementation are forwarded to
        # from_pretrained through the builder's **kwargs.
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=self.model_name,
            torch_dtype="auto",
            attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
            device_map="auto",
        )
        self.model.eval()

        # Image-token wrapping behaviour expected by this checkpoint.
        self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        self.image_token = DEFAULT_IMAGE_TOKEN
        self.im_start = DEFAULT_IM_START_TOKEN
        self.im_end = DEFAULT_IM_END_TOKEN

    def _load_image(self, img_field: str) -> Optional[Image.Image]:
        """URL / data URL / local path -> PIL.Image"""
        if not img_field:
            return None
        try:
            if img_field.startswith("data:image"):
                # Inline payload: "data:image/<fmt>;base64,<data>"
                _header, b64 = img_field.split(",", 1)
                return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
            elif img_field.startswith(("http://", "https://")):
                import requests  # lazy import; only needed for remote images

                r = requests.get(img_field, timeout=20)
                r.raise_for_status()
                return Image.open(io.BytesIO(r.content)).convert("RGB")
            else:
                return Image.open(img_field).convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Image load failed: {e}") from e
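
    # Accepted "image" values (illustrative examples, not from the source):
    #   "https://example.com/cat.jpg"            remote URL
    #   "data:image/png;base64,iVBORw0KG..."     inline data URL
    #   "/data/sample.jpg"                       local file path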

    def _build_prompt(self, user_text: str, conv_mode: str) -> str:
        """Build the dialogue prompt via conv_templates, as in the demo."""
        if conv_mode not in conv_templates:
            conv_mode = DEFAULT_CONV_MODE

        conv = conv_templates[conv_mode].copy()
        if self.use_im_start_end:
            content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
        else:
            content = f"{self.image_token}\n{user_text}"

        conv.append_message(conv.roles[0], content)
        conv.append_message(conv.roles[1], None)  # open slot for the assistant's reply
        return conv.get_prompt()
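
    # Sketch of the result (assuming the upstream "llava_v1" template and
    # use_im_start_end=False): _build_prompt("What is in the image?", "llava_v1")
    # yields roughly
    #   "A chat between a curious human and an artificial intelligence
    #    assistant. ... USER: <image>\nWhat is in the image? ASSISTANT:"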

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.get("inputs") or {}
        params = data.get("parameters") or {}
        conv_mode_req = data.get("conv_mode")
        conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE

        query_text = inputs.get("query", "")
        image_f = inputs.get("image", "")
        pil_img = self._load_image(image_f) if image_f else None

        prompt = self._build_prompt(query_text, conv_mode)

        image_tensors = None
        if pil_img is not None:
            image_tensors = process_images([pil_img], self.image_processor, self.model.config)

        # tokenizer_image_token returns a 1-D tensor; add the batch
        # dimension that generate() expects.
        input_ids = tokenizer_image_token(
            prompt,
            self.tokenizer,
            IMAGE_TOKEN_INDEX,
            return_tensors="pt",
        ).unsqueeze(0)
        input_ids = input_ids.to(self.model.device, non_blocking=True)

        # Clamp the generation budget so prompt + new tokens fit in the
        # context window, with a small safety margin.
        requested_max_new = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
        avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
        max_new_tokens = max(1, min(requested_max_new, avail))

        if image_tensors is not None:
            image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)

        temperature = float(params.get("temperature", 0.0))
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": float(params.get("top_p", 1.0)),
            "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
            # Greedy decoding by default; sampling kicks in when temperature > 0.
            "do_sample": bool(params.get("do_sample", temperature > 0)),
            "use_cache": bool(params.get("use_cache", True)),
        }

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensors,
                **gen_kwargs,
            )

        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return [{"generated_text": outputs}]
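

# Minimal local smoke test: a sketch, not part of the Inference Toolkit
# contract. It assumes HF_MODEL_LOCAL_DIR or HF_MODEL_ID points at a LLaVA
# checkpoint and that the image path below (hypothetical) is replaced with a
# real file. In production the toolkit instantiates EndpointHandler itself.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "inputs": {
                "query": "Describe this image.",
                "image": "/data/sample.jpg",  # hypothetical path
            },
            "parameters": {"max_new_tokens": 64, "temperature": 0.0},
        }
    )
    print(result[0]["generated_text"])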