# Custom Hugging Face Inference Endpoints handler for PULSE-ECG/PULSE-7B.
import base64
import io
import os
import logging
import requests
import torch
import transformers
from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
class EndpointHandler:
    """Inference Endpoints handler for the PULSE ECG multimodal model.

    Loading strategy, tried in order by ``__init__``:

    1. Plain ``pipeline("image-text-to-text", ...)``.
    2. The same pipeline, but with the model config's ``model_type``
       rewritten from ``llava_llama`` to ``llava`` (PULSE ships a custom
       ``model_type`` that stock transformers does not register).
    3. Raw ``AutoProcessor`` + ``AutoModelForCausalLM`` wrapped in a
       minimal pipeline-like callable.

    Whichever succeeds is stored as ``self.pipe`` and invoked by
    ``__call__``.
    """

    def __init__(self, path: str = ""):
        # ``path`` is part of the Inference Endpoints contract but unused
        # here: the model id comes from the environment instead.
        logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # 1) Normal path: attempt pipeline directly.
        try:
            self.pipe = self._load_pipeline()
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] pipeline() failed: {e}")

        # 2) llava_llama -> llava config override. Only retry when the
        # override actually applies; without it the attempt would be an
        # exact (and expensive) repeat of step 1, guaranteed to fail the
        # same way.
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
                self.pipe = self._load_pipeline(config=cfg)
                logging.warning("[INIT] pipeline() loaded with config override")
                self._ensure_pad_token()
                return
        except Exception as e:
            logging.warning(f"[INIT] override failed: {e}")

        # 3) Fallback: AutoProcessor + AutoModel with config override check.
        try:
            self.pipe = self._load_fallback()
            logging.warning("[INIT] Fallback loaded")
        except Exception as e:
            logging.error(f"[INIT] Fallback failed: {e}")
            raise

    # ---- loading helpers ----
    def _load_pipeline(self, config=None):
        """Build the multimodal pipeline; ``config`` optionally overrides the model config."""
        kwargs = dict(
            task="image-text-to-text",
            model=self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        if config is not None:
            kwargs["config"] = config
        return pipeline(**kwargs)

    def _load_fallback(self):
        """Load processor + model directly and return a pipeline-like callable."""
        cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        if getattr(cfg, "model_type", None) == "llava_llama":
            logging.warning("[INIT] Fallback override: llava_llama -> llava")
            cfg.model_type = "llava"
        proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, config=cfg)
        mdl = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
            config=cfg,
        )

        def _mini_pipe(msgs, **params):
            # NOTE(review): feeds the chat-message list straight to the
            # processor; many processors expect text=/images= keyword args
            # instead — confirm against the PULSE processor before relying
            # on this path.
            inputs = proc(msgs, return_tensors="pt").to(mdl.device)
            gen_kwargs = {"max_new_tokens": 512, **params}
            with torch.inference_mode():
                out_ids = mdl.generate(**inputs, **gen_kwargs)
            return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

        return _mini_pipe

    # ---- helpers ----
    def _ensure_pad_token(self):
        """Best-effort: default the pipeline model's pad_token_id to eos_token_id.

        Avoids the transformers warning/failure when ``generate`` is called
        on a model whose generation config has no pad token. A no-op for the
        fallback callable (which has no ``model`` attribute).
        """
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            # Purely cosmetic fix-up; never let it break initialization.
            pass

    def _normalize_inputs(self, data: dict):
        """Normalize the request payload to ``(messages, parameters)``.

        Accepts either the simple schema (``image_url`` + optional ``text``)
        or an already-built multimodal chat ``inputs`` list.

        Raises:
            ValueError: if neither schema is present, or the simple schema
                lacks ``image_url``.
        """
        # Simple schema
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            # NOTE(review): content uses {"type": "image", "image_url": ...};
            # transformers chat templates commonly expect "url"/"image" keys —
            # verify the key the PULSE processor accepts.
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})
        # Multimodal chat schema
        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})
        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Run inference on one request payload and return the raw pipeline output."""
        msgs, params = self._normalize_inputs(data)
        # Defaults are overridable per request; ``params or {}`` guards a
        # payload that explicitly sends "parameters": null.
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out