# NOTE(review): the three lines below are residue from the tool that exported
# this file (a file-size banner, a git-blame hash gutter, and a line-number
# gutter). They are not Python and broke parsing; commented out to preserve them.
# File size: 4,003 Bytes
# 20984f8 6f8b430 7f8404c 6f8b430 7f8404c 547115d 6f8b430 20984f8 6f8b430 20984f8 6f8b430 20984f8 6f8b430 20984f8 6f8b430 20984f8 6f8b430 547115d 6f8b430 547115d 6f8b430 547115d 6f8b430 547115d 20984f8 547115d |
# 1 2 3 ... 102 | (line-number gutter)
import base64, io, os, logging
import requests, torch, transformers
from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
class EndpointHandler:
    """Inference-endpoint handler for the PULSE ECG vision-language model.

    Loading strategy, tried in order until one succeeds:
      1. a stock ``transformers.pipeline("image-text-to-text", ...)``;
      2. the same pipeline with the checkpoint's ``model_type`` rewritten from
         ``"llava_llama"`` to ``"llava"`` (some PULSE checkpoints ship the
         former, which stock transformers does not recognize);
      3. a raw ``AutoProcessor`` + ``AutoModelForCausalLM`` pair wrapped in a
         small callable that mimics the pipeline interface.

    Whichever succeeds is stored as ``self.pipe`` and invoked by ``__call__``.
    """

    def __init__(self, path: str = ""):
        # `path` is part of the hosting contract (model directory) but unused:
        # the model id is taken from the environment instead.
        logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")
        # 1) Preferred path: stock pipeline.
        try:
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] pipeline() failed: {e}")
        # 2) Retry with a llava_llama -> llava config override.
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
            self.pipe = pipeline(
                task="image-text-to-text",
                model=self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
                config=cfg,
            )
            logging.warning("[INIT] pipeline() loaded with config override")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning(f"[INIT] override failed: {e}")
        # 3) Last resort: raw processor + model, wrapped to look like a pipeline.
        logging.warning("[INIT] Fallback: AutoProcessor/AutoModel")
        proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )

        def _mini_pipe(msgs, **params):
            # FIX: the processor's __call__ takes text/images, not a list of
            # chat-message dicts; render the messages through the chat template.
            inputs = proc.apply_chat_template(
                msgs,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(mdl.device)
            gen_kwargs = {"max_new_tokens": 512, **params}
            with torch.inference_mode():
                out_ids = mdl.generate(**inputs, **gen_kwargs)
            # FIX: decode only the newly generated tokens, not the echoed prompt.
            new_ids = out_ids[:, inputs["input_ids"].shape[1]:]
            return proc.tokenizer.batch_decode(new_ids, skip_special_tokens=True)

        self.pipe = _mini_pipe
        logging.warning("[INIT] Fallback loaded")

    # ---- helpers ----
    def _ensure_pad_token(self):
        """Best-effort: default the generation pad token to EOS when unset.

        Silences the "Setting pad_token_id to eos_token_id" warning during
        generate(). Deliberately swallows errors — this is cosmetic and must
        never prevent the handler from loading.
        """
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            pass

    def _normalize_inputs(self, data: dict):
        """Normalize a request payload into (messages, generation-parameters).

        Accepts either:
          * the simple schema: ``{"image_url": ..., "text": ...}`` — converted
            into a single-turn multimodal chat message; or
          * the chat schema: ``{"inputs": [...]}`` — passed through as-is.

        Raises:
            ValueError: if neither schema is satisfied, or the simple schema
                is used without an ``image_url``.
        """
        # Simple schema takes precedence when either of its keys is present.
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})
        # Multimodal chat schema.
        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})
        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Run inference on one request payload and return the raw output."""
        msgs, params = self._normalize_inputs(data)
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        # FIX: transformers ignores `temperature` (with a warning) unless
        # sampling is enabled; honor the configured temperature by default,
        # while letting callers override do_sample explicitly.
        if params.get("temperature") and "do_sample" not in params:
            params["do_sample"] = True
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out