# Rapid_ECG / handler.py
import os
import logging

import torch
import transformers
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
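
# Loading strategy, tried in order (each stage logged at warning level so it
# shows up in the endpoint logs):
#   1) pipeline("image-text-to-text") on the checkpoint directly;
#   2) the same pipeline with the config's model_type remapped from
#      "llava_llama" to "llava": the checkpoint declares a model_type that
#      stock transformers apparently does not register, so remapping it to
#      the registered "llava" type lets the auto classes resolve the
#      architecture;
#   3) raw AutoProcessor + AutoModelForCausalLM wrapped in a minimal callable.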

class EndpointHandler:
    """Custom endpoint handler for the PULSE-ECG/PULSE-7B image-text-to-text model."""

    def __init__(self, path: str = ""):
        logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # 1) Normal path: attempt the pipeline directly.
        try:
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] pipeline() failed: {e}")

        # 2) llava_llama -> llava config override
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
                config=cfg,
            )
            logging.warning("[INIT] pipeline() loaded with config override")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] override failed: {e}")

        # 3) Fallback: AutoProcessor + AutoModelForCausalLM with the same
        #    config override check.
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] Fallback override: llava_llama -> llava")
                cfg.model_type = "llava"
            proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, config=cfg)
            mdl = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
                config=cfg,
            )

            def _mini_pipe(msgs, **params):
                # Simplified path: hands the message list straight to the
                # processor and decodes the full output. It does not fetch
                # remote images itself, so it relies on the remote-code
                # processor accepting this input shape.
                inputs = proc(msgs, return_tensors="pt").to(mdl.device)
                gen_kwargs = {"max_new_tokens": 512, **params}
                with torch.inference_mode():
                    out_ids = mdl.generate(**inputs, **gen_kwargs)
                return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

            self.pipe = _mini_pipe
            logging.warning("[INIT] Fallback loaded")
        except Exception as e:
            logging.error(f"[INIT] Fallback failed: {e}")
            raise
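
    # Whichever branch succeeded, self.pipe is a callable that takes the
    # chat-style message list plus generation kwargs and returns generated
    # text: a transformers Pipeline in branches 1-2, the _mini_pipe closure
    # in branch 3.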

    # ---- helpers ----
    def _ensure_pad_token(self):
        # generate() expects a pad_token_id; many causal LM checkpoints ship
        # without one, so fall back to the EOS token id.
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            pass

    def _normalize_inputs(self, data: dict):
        """Accept either the simple schema or the multimodal chat schema.

        Returns a (messages, parameters) tuple.
        """
        # Simple schema: {"image_url": ..., "text": ...}
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})
        # Multimodal chat schema: {"inputs": [...], "parameters": {...}}
        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})
        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")
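
    # Example payloads accepted above (illustrative; the URL is a placeholder,
    # not a real asset):
    #   {"image_url": "https://example.com/ecg.png",
    #    "text": "Interpret this ECG image."}
    #   {"inputs": [{"role": "user", "content": [
    #        {"type": "image", "image_url": "https://example.com/ecg.png"},
    #        {"type": "text", "text": "Interpret this ECG image."}]}],
    #    "parameters": {"max_new_tokens": 256}}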

    def __call__(self, data: dict):
        msgs, params = self._normalize_inputs(data)
        # Defaults below are overridable per request via "parameters".
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out
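

# Minimal local smoke test: a sketch, not part of the serving contract. It
# assumes the model weights are reachable from this machine and that the
# placeholder URL is replaced with a real, publicly accessible ECG image.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "image_url": "https://example.com/ecg.png",  # placeholder URL
        "text": "Interpret this ECG image.",
        "parameters": {"max_new_tokens": 128},
    })
    print(result)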