# Rapid_ECG / handler.py
# Hugging Face Inference Endpoints custom handler for PULSE-ECG/PULSE-7B.
import base64, io, os, logging
import requests, torch, transformers
from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
class EndpointHandler:
    """Inference Endpoints handler for the PULSE ECG vision-language model.

    Tries three loading strategies, in order:
      1. the stock ``image-text-to-text`` pipeline;
      2. the same pipeline with the checkpoint's ``model_type`` rewritten
         from ``llava_llama`` to ``llava`` (the custom type is unknown to
         ``pipeline()``, plain ``llava`` maps to a model class);
      3. a raw ``AutoProcessor`` + ``AutoModelForCausalLM`` pair wrapped in
         a pipeline-like callable.

    Whichever strategy succeeds is stored as ``self.pipe`` and invoked by
    ``__call__``.
    """

    def __init__(self, path: str = ""):
        # `path` is part of the Inference Endpoints handler contract; the
        # model id is read from the environment instead, so it is unused.
        logging.warning("[INIT] Transformers version: %s", transformers.__version__)
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # 1) Preferred path: the stock multimodal pipeline.
        try:
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning("[INIT] pipeline() failed: %s", e)

        # 2) Retry with the config's model_type overridden. PULSE checkpoints
        # declare "llava_llama", which pipeline() cannot resolve to a model
        # class; "llava" works. If model_type is something else entirely we
        # deliberately fall through to strategy 3 without retrying.
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
                self.pipe = pipeline(
                    task="image-text-to-text",
                    model=self.model_id,
                    device_map="auto",
                    torch_dtype="auto",
                    trust_remote_code=True,
                    config=cfg,
                )
                logging.warning("[INIT] pipeline() loaded with config override")
                self._ensure_pad_token()
                return
        except Exception as e:
            logging.warning("[INIT] override failed: %s", e)

        # 3) Last resort: load processor + model directly and emulate the
        # pipeline interface with a closure.
        logging.warning("[INIT] Fallback: AutoProcessor/AutoModel")
        proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )

        def _mini_pipe(msgs, **params):
            # Minimal stand-in for pipeline(): tokenize -> generate -> decode.
            # NOTE(review): this feeds the chat-message list straight to the
            # processor; most multimodal processors expect apply_chat_template()
            # first — confirm against the PULSE processor's remote code.
            inputs = proc(msgs, return_tensors="pt").to(mdl.device)
            gen_kwargs = {"max_new_tokens": 512, **params}
            with torch.inference_mode():
                out_ids = mdl.generate(**inputs, **gen_kwargs)
            return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

        self.pipe = _mini_pipe
        logging.warning("[INIT] Fallback loaded")

    # ---- helpers ----
    def _ensure_pad_token(self):
        """Best-effort: default ``pad_token_id`` to ``eos_token_id``.

        Models that ship without a pad token make ``generate()`` warn (or
        fail when batching); this patches the generation config in place.
        Only applies when ``self.pipe`` is a real pipeline with a ``.model``.
        """
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    gen_cfg.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            # Cosmetic fix-up only — never let it break initialization,
            # but leave a trace instead of swallowing silently.
            logging.debug("[INIT] pad token fix-up skipped", exc_info=True)

    def _normalize_inputs(self, data: dict):
        """Convert an incoming payload into ``(messages, parameters)``.

        Accepts either the simple schema ``{"image_url": ..., "text": ...}``
        or the multimodal chat schema ``{"inputs": [...]}``; an optional
        ``"parameters"`` dict is passed through in both cases.

        Raises:
            ValueError: if neither schema matches, or the simple schema is
                used without an ``image_url``.
        """
        # Simple schema: single image URL + prompt text.
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})
        # Multimodal chat schema: caller supplies the message list directly.
        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})
        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Entry point invoked by the Inference Endpoints runtime.

        Returns whatever ``self.pipe`` produces (pipeline output or the
        fallback's list of decoded strings).
        """
        msgs, params = self._normalize_inputs(data)
        # Caller-supplied parameters override the defaults.
        # NOTE(review): temperature without do_sample=True is ignored (and
        # warned about) by recent transformers — confirm intended sampling.
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out