File size: 4,003 Bytes
20984f8
6f8b430
7f8404c
6f8b430
7f8404c
 
547115d
6f8b430
 
 
20984f8
6f8b430
 
 
 
 
 
 
 
 
 
 
 
20984f8
6f8b430
20984f8
6f8b430
 
 
20984f8
6f8b430
 
 
 
 
 
 
 
 
 
 
 
 
20984f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f8b430
547115d
6f8b430
 
 
 
547115d
 
 
 
6f8b430
547115d
 
 
 
 
 
 
 
 
 
 
 
6f8b430
547115d
 
 
20984f8
547115d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import base64, io, os, logging
import requests, torch, transformers
from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor

class EndpointHandler:
    """Callable handler that runs an image+text generation model on a request payload.

    The model is loaded once in ``__init__`` using up to three strategies
    (plain pipeline, pipeline with a config override, processor+model
    fallback). Each request is then normalized into chat messages and passed
    to whichever callable ended up in ``self.pipe``.
    """

    def __init__(self, path: str = ""):
        """Load the model, trying progressively more manual strategies.

        Args:
            path: Accepted for interface compatibility; this handler does not
                read it. The model id comes from the ``PULSE_MODEL_ID``
                environment variable (default ``"PULSE-ECG/PULSE-7B"``).

        Raises:
            Exception: Whatever the final fallback loader raises; failures of
                the first two strategies are logged and not propagated.
        """
        logging.warning("[INIT] Transformers version: %s", transformers.__version__)
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # 1) Normal path: let pipeline() resolve everything on its own.
        try:
            self.pipe = self._build_pipeline()
        except Exception as e:
            logging.warning("[INIT] pipeline() failed: %s", e)
        else:
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return

        # 2) Retry with an explicit config whose model_type is rewritten
        #    from "llava_llama" to "llava".
        try:
            cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if getattr(cfg, "model_type", None) == "llava_llama":
                logging.warning("[INIT] llava_llama -> llava override")
                cfg.model_type = "llava"
            self.pipe = self._build_pipeline(config=cfg)
        except Exception as e:
            logging.warning("[INIT] override failed: %s", e)
        else:
            logging.warning("[INIT] pipeline() loaded with config override")
            self._ensure_pad_token()
            return

        # 3) Fallback: load processor and model directly and wrap them in a
        #    small callable with the same call shape as the pipeline.
        logging.warning("[INIT] Fallback: AutoProcessor/AutoModel")
        proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )

        def _mini_pipe(msgs, **params):
            # NOTE(review): chat messages are handed to the processor as-is;
            # confirm this model's processor accepts that input directly.
            inputs = proc(msgs, return_tensors="pt").to(mdl.device)
            # Caller-supplied params win over the default token budget.
            gen_kwargs = {"max_new_tokens": 512, **params}
            with torch.inference_mode():
                out_ids = mdl.generate(**inputs, **gen_kwargs)
            return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

        self.pipe = _mini_pipe
        logging.warning("[INIT] Fallback loaded")

    # ---- helpers ----
    def _build_pipeline(self, config=None):
        """Create the image-text-to-text pipeline for ``self.model_id``.

        Args:
            config: Optional model config forwarded to ``pipeline()``; omitted
                from the call entirely when ``None``.

        Returns:
            The object returned by ``transformers.pipeline``.
        """
        extra = {} if config is None else {"config": config}
        return pipeline(
            task="image-text-to-text",
            model=self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
            **extra,
        )

    def _ensure_pad_token(self):
        """Default the generation pad token to the EOS token when it is unset.

        Best-effort: any failure is logged and never raised, so a model
        without the expected attributes does not break initialization.
        """
        try:
            model = getattr(self.pipe, "model", None)
            if model is None:
                return
            gen_cfg = getattr(model, "generation_config", None)
            if not gen_cfg or getattr(gen_cfg, "pad_token_id", None) is not None:
                return
            eos = model.config.eos_token_id
            # eos_token_id may be a list of ids; the pad id must be a single
            # integer, so use the first entry.
            if isinstance(eos, (list, tuple)):
                eos = eos[0] if eos else None
            if eos is not None:
                gen_cfg.pad_token_id = eos
        except Exception as e:
            logging.warning("[INIT] could not set pad_token_id: %s", e)

    def _normalize_inputs(self, data: dict):
        """Convert a request payload into ``(messages, parameters)``.

        Two payload shapes are accepted:

        * Simple schema: ``{"image_url": ..., "text": ..., "parameters": ...}``
          where ``text`` defaults to ``"Interpret this ECG image."``.
        * Chat schema: ``{"inputs": <messages>, "parameters": ...}`` where
          ``inputs`` is returned unchanged.

        Args:
            data: The request payload.

        Returns:
            A ``(messages, parameters)`` tuple; ``parameters`` is always a
            dict (empty when missing or ``None``).

        Raises:
            ValueError: If the simple schema is used without a non-empty
                ``image_url``, or if neither schema matches.
        """
        params = data.get("parameters") or {}

        # Simple schema
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            if not image_url:
                raise ValueError("No image_url provided")
            text = data.get("text", "Interpret this ECG image.")
            message = {
                "role": "user",
                "content": [
                    # The transformers chat format reads the image location
                    # from the "url" key of a {"type": "image"} part.
                    {"type": "image", "url": image_url},
                    {"type": "text", "text": text},
                ],
            }
            return [message], params

        # Multimodal chat schema
        if "inputs" in data:
            return data["inputs"], params

        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Run generation for one request payload.

        Args:
            data: Payload in one of the shapes accepted by
                ``_normalize_inputs``.

        Returns:
            Whatever ``self.pipe`` returns for the normalized messages.

        Raises:
            ValueError: If the payload cannot be normalized.
        """
        msgs, params = self._normalize_inputs(data)
        # Request parameters override the defaults.
        # NOTE(review): temperature usually has no effect unless sampling is
        # enabled (do_sample=True) — confirm the intended decoding mode.
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            return self.pipe(msgs, **params)