# Rapid_ECG / handler.py
# Source: Hugging Face repository page (author: CanerDedeoglu, commit 77b64f3, 6.97 kB)
# /repository/handler.py
import os, io, base64
from typing import Any, Dict, List, Optional
import torch
from PIL import Image
# --- LLaVA (demo) components ---
from llava.model.builder import load_pretrained_model, get_model_name_from_path
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
# =========================
# Environment / Defaults
# =========================
# 1) To load from this repository locally: leave empty, or set HF_MODEL_LOCAL_DIR=/repository
HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
# 2) To load from the Hub instead: set HF_MODEL_ID=org/name
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip()
# Same conv_mode as the demo
DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
# Safe default (keep it modest)
MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
# Drop the Flash-Attention requirement; use SDPA instead
os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
class EndpointHandler:
    """
    Handler invoked by the Hugging Face Inference Toolkit.

    Input schema (compatible with the demo):
      {
        "inputs": { "query": "...", "image": "<url|dataurl|path>" },
        "parameters": {
          "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
          "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
        },
        "conv_mode": "llava_v2"  # optional
      }
    Returns:
      [ { "generated_text": "..." } ]
    """

    def __init__(self, path: str = "") -> None:
        # path -> /repository (local model directory supplied by the toolkit)
        disable_torch_init()

        # Where to load the weights from, in priority order:
        # explicit local dir > explicit Hub id > this repository itself.
        if HF_MODEL_LOCAL_DIR:
            model_path = HF_MODEL_LOCAL_DIR
        elif HF_MODEL_ID:
            model_path = HF_MODEL_ID
        else:
            # Weights live in this repo
            model_path = path

        # Model name (derived by the LLaVA helper from the path)
        self.model_name = get_model_name_from_path(model_path)

        # LLaVA loading (same entry point as the demo)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,  # None when no LoRA adapter is used
            model_name=self.model_name,
            torch_dtype="auto",
            attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
            device_map="auto",
        )
        self.model.eval()

        # Image-token markers (depend on the model config)
        self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        self.image_token = DEFAULT_IMAGE_TOKEN
        self.im_start = DEFAULT_IM_START_TOKEN
        self.im_end = DEFAULT_IM_END_TOKEN

    # ---------------------------
    # Helpers
    # ---------------------------
    def _load_image(self, img_field: str) -> Optional[Image.Image]:
        """URL / data URL / local path -> PIL.Image (RGB).

        Raises:
            RuntimeError: if the image cannot be fetched or decoded.
        """
        if not img_field:
            return None
        try:
            if img_field.startswith("data:image"):
                _head, b64 = img_field.split(",", 1)
                return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
            elif img_field.startswith(("http://", "https://")):
                import requests
                r = requests.get(img_field, timeout=20)
                r.raise_for_status()
                return Image.open(io.BytesIO(r.content)).convert("RGB")
            else:
                return Image.open(img_field).convert("RGB")
        except Exception as e:
            # Surface an explanatory error when the image cannot be read
            raise RuntimeError(f"Image load failed: {e}") from e

    def _build_prompt(self, user_text: str, conv_mode: str, has_image: bool = True) -> str:
        """Build the dialogue prompt via conv_templates (as in the demo).

        Args:
            user_text: the user's question.
            conv_mode: conversation template key; falls back to the default
                when unknown.
            has_image: when False, no image placeholder token is inserted,
                keeping text-only requests valid (backward-compatible
                default preserves the original behavior).
        """
        # Fall back to the default on an unknown conv_mode
        if conv_mode not in conv_templates:
            conv_mode = DEFAULT_CONV_MODE
        conv = conv_templates[conv_mode].copy()
        if not has_image:
            content = user_text
        elif self.use_im_start_end:
            content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
        else:
            content = f"{self.image_token}\n{user_text}"
        conv.append_message(conv.roles[0], content)  # user turn
        conv.append_message(conv.roles[1], None)     # empty assistant turn
        return conv.get_prompt()

    # ---------------------------
    # Inference entry point
    # ---------------------------
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.get("inputs") or {}
        params = data.get("parameters") or {}
        conv_mode_req = data.get("conv_mode")
        conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE

        query_text = inputs.get("query", "")
        image_f = inputs.get("image", "")
        pil_img = self._load_image(image_f) if image_f else None

        # 1) Build the prompt. BUG FIX: skip the image placeholder on
        #    text-only requests — the original always inserted it even when
        #    images=None was later passed to generate().
        prompt = self._build_prompt(query_text, conv_mode, has_image=pil_img is not None)

        # 2) Image tensor(s)
        image_tensors = None
        if pil_img is not None:
            image_tensors = process_images([pil_img], self.image_processor, self.model.config)

        # 3) Tokenize (embedding the image token index)
        input_ids = tokenizer_image_token(
            prompt,
            self.tokenizer,
            IMAGE_TOKEN_INDEX,
            return_tensors="pt",
        )
        # BUG FIX: tokenizer_image_token returns a 1-D tensor; generate()
        # expects a batch dimension (the reference demo calls .unsqueeze(0)).
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.to(self.model.device, non_blocking=True)

        # 4) Cap max_new_tokens against the context window (small safety buffer)
        requested_max_new = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
        avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
        max_new_tokens = max(1, min(requested_max_new, avail))

        # Move image tensor(s) to the model device. process_images may return
        # either a stacked tensor or a list of tensors depending on the
        # image-aspect-ratio setting, so handle both.
        if image_tensors is not None:
            if isinstance(image_tensors, list):
                image_tensors = [
                    t.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
                    for t in image_tensors
                ]
            else:
                image_tensors = image_tensors.to(
                    self.model.device, dtype=self.model.dtype, non_blocking=True
                )

        temperature = float(params.get("temperature", 0.0))
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": float(params.get("top_p", 1.0)),
            "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
            # Default: sample only when a positive temperature was requested
            "do_sample": bool(params.get("do_sample", temperature > 0)),
            "use_cache": bool(params.get("use_cache", True)),
        }

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensors,
                **gen_kwargs,
            )

        # NOTE(review): assumes LLaVA's generate() returns only the newly
        # generated tokens (recent LLaVA versions do); if this model pin
        # echoes the prompt, slice output_ids[:, input_ids.shape[1]:] first.
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return [{"generated_text": outputs}]