|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
import logging |
|
|
import requests |
|
|
import torch |
|
|
import transformers |
|
|
from PIL import Image |
|
|
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the PULSE ECG model.

    Loading strategy (first success wins):
      1. Stock ``pipeline("image-text-to-text")``.
      2. Same pipeline, but with the checkpoint config's ``model_type``
         patched from ``"llava_llama"`` to ``"llava"`` (PULSE checkpoints
         declare the former, which stock transformers may not recognize).
      3. Manual ``AutoProcessor`` + ``AutoModelForCausalLM`` load, wrapped
         in a small pipeline-like callable.
    """

    def __init__(self, path: str = ""):
        """Load the model, trying the three strategies in order.

        Args:
            path: Local model path supplied by the endpoint runtime
                (unused here; the model id comes from ``PULSE_MODEL_ID``).

        Raises:
            Exception: re-raised from the final fallback when every
                loading strategy fails.
        """
        logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # --- Strategy 1: plain pipeline ---------------------------------
        try:
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] pipeline() failed: {e}")

        # --- Strategy 2: pipeline with model_type override ---------------
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
                # Only retry when the override actually changed something;
                # an unmodified config would just repeat strategy 1, which
                # already failed.
                self.pipe = pipeline(
                    task="image-text-to-text",
                    model=self.model_id,
                    device_map="auto",
                    torch_dtype="auto",
                    trust_remote_code=True,
                    config=cfg,
                )
                logging.warning("[INIT] pipeline() loaded with config override")
                self._ensure_pad_token()
                return
        except Exception as e:
            logging.warning(f"[INIT] override failed: {e}")

        # --- Strategy 3: manual processor + model fallback ---------------
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] Fallback override: llava_llama -> llava")
                cfg.model_type = "llava"
            # NOTE(review): AutoProcessor.from_pretrained does not document
            # a ``config`` kwarg; it is forwarded through **kwargs — confirm
            # the PULSE remote code accepts it.
            proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, config=cfg)
            mdl = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
                config=cfg,
            )

            def _mini_pipe(msgs, **params):
                # Minimal stand-in for the pipeline call: preprocess,
                # generate, decode. Caller-supplied ``params`` override
                # the max_new_tokens default.
                inputs = proc(msgs, return_tensors="pt").to(mdl.device)
                gen_kwargs = {"max_new_tokens": 512, **params}
                with torch.inference_mode():
                    out_ids = mdl.generate(**inputs, **gen_kwargs)
                return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

            self.pipe = _mini_pipe
            logging.warning("[INIT] Fallback loaded")
        except Exception as e:
            logging.error(f"[INIT] Fallback failed: {e}")
            raise

    def _ensure_pad_token(self):
        """Best-effort: default ``pad_token_id`` to ``eos_token_id``.

        Prevents the transformers generation warning when the checkpoint
        ships no pad token. Only applies when ``self.pipe`` is a real
        pipeline object exposing ``.model``; failures are non-fatal.
        """
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            # Non-fatal: generation may still work (or merely warn) without
            # a pad token; log for diagnosis instead of swallowing silently.
            logging.debug("[INIT] could not set pad_token_id", exc_info=True)

    def _normalize_inputs(self, data: dict):
        """Convert an endpoint payload into ``(messages, parameters)``.

        Accepts either of two payload shapes:
          * ``{"image_url": ..., "text": ...}`` — wrapped into a single
            user chat turn with an image part and a text part;
          * ``{"inputs": [...]}`` — passed through unchanged.

        Returns:
            A ``(messages, parameters)`` tuple; ``parameters`` defaults
            to ``{}``.

        Raises:
            ValueError: when neither shape is present, or ``image_url``
                is missing/empty in the first shape.
        """
        data = data or {}  # tolerate a None payload instead of TypeError

        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})

        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})

        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Run one inference request and return the raw pipeline output."""
        msgs, params = self._normalize_inputs(data)
        # NOTE(review): temperature only takes effect when sampling is
        # enabled (do_sample=True); with greedy decoding it is ignored.
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out