File size: 4,504 Bytes
5836cbd
 
 
 
 
 
 
7f8404c
6f8b430
7f8404c
 
547115d
6f8b430
 
 
5836cbd
6f8b430
 
 
 
 
 
 
 
 
 
 
 
20984f8
6f8b430
20984f8
6f8b430
 
 
20984f8
6f8b430
 
 
 
 
 
 
 
 
 
 
 
 
20984f8
 
5836cbd
 
 
 
 
 
 
 
 
 
 
 
 
 
20984f8
5836cbd
 
 
 
 
 
20984f8
5836cbd
 
 
 
 
20984f8
 
6f8b430
547115d
6f8b430
 
 
 
547115d
 
 
 
5836cbd
547115d
 
 
 
 
 
 
 
 
 
 
 
5836cbd
547115d
 
 
20984f8
547115d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import base64
import io
import os
import logging
import requests
import torch
import transformers
from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor

class EndpointHandler:
    """Inference Endpoints handler for the PULSE ECG vision-language model.

    Loading strategy, tried in order:
      1. Plain ``transformers.pipeline("image-text-to-text")``.
      2. Same pipeline, with the repo config's ``model_type`` patched from
         ``llava_llama`` to ``llava`` (the custom type is not registered
         with the pipeline machinery on some transformers versions).
      3. Raw ``AutoProcessor`` + ``AutoModelForCausalLM`` wrapped in a
         minimal callable that mimics the pipeline interface.
    """

    def __init__(self, path: str = ""):
        """Load the model, falling back through progressively more manual paths.

        Args:
            path: Local model path supplied by the Endpoints runtime
                (unused; the model id comes from ``PULSE_MODEL_ID``).

        Raises:
            Exception: whatever the final fallback raises when every
                loading strategy has failed.
        """
        # Lazy %-style args follow the logging module's convention and
        # avoid formatting when the record is filtered out.
        logging.warning("[INIT] Transformers version: %s", transformers.__version__)
        self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")

        # 1) Normal path: attempt pipeline directly
        try:
            self.pipe = self._build_pipeline()
            logging.warning("[INIT] pipeline() loaded OK")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning("[INIT] pipeline() failed: %s", e)

        # 2) llava_llama -> llava config override
        try:
            cfg = self._patched_config("[INIT] llava_llama -> llava override")
            self.pipe = self._build_pipeline(config=cfg)
            logging.warning("[INIT] pipeline() loaded with config override")
            self._ensure_pad_token()
            return
        except Exception as e:
            logging.warning("[INIT] override failed: %s", e)

        # 3) Fallback: AutoProcessor + AutoModel with config override check
        try:
            cfg = self._patched_config("[INIT] Fallback override: llava_llama -> llava")
            proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, config=cfg)
            mdl = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
                config=cfg,
            )

            def _mini_pipe(msgs, **params):
                # Pipeline-shaped closure over the raw processor/model pair.
                # NOTE(review): assumes the processor accepts the chat-message
                # list directly — confirm against the model's remote code.
                inputs = proc(msgs, return_tensors="pt").to(mdl.device)
                gen_kwargs = {"max_new_tokens": 512, **params}
                with torch.inference_mode():
                    out_ids = mdl.generate(**inputs, **gen_kwargs)
                return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

            self.pipe = _mini_pipe
            logging.warning("[INIT] Fallback loaded")
        except Exception as e:
            logging.error("[INIT] Fallback failed: %s", e)
            raise

    # ---- loading helpers ----
    def _build_pipeline(self, **extra):
        """Construct the image-text-to-text pipeline with the shared kwargs."""
        return pipeline(
            task="image-text-to-text",
            model=self.model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
            **extra,
        )

    def _patched_config(self, log_msg: str):
        """Fetch the repo config, renaming ``llava_llama`` to ``llava`` if present."""
        cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        if getattr(cfg, "model_type", None) == "llava_llama":
            logging.warning(log_msg)
            cfg.model_type = "llava"
        return cfg

    # ---- helpers ----
    def _ensure_pad_token(self):
        """Default the generation pad token to EOS when unset (best effort)."""
        try:
            if hasattr(self.pipe, "model"):
                gen_cfg = getattr(self.pipe.model, "generation_config", None)
                if gen_cfg and getattr(gen_cfg, "pad_token_id", None) is None:
                    self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
        except Exception:
            # Deliberate best-effort: a missing pad token only produces a
            # generate() warning, never a hard failure.
            pass

    def _normalize_inputs(self, data: dict):
        """Convert a request payload into ``(messages, generation_parameters)``.

        Accepts either the simple ``{"image_url", "text"}`` schema or a
        pre-built multimodal chat list under ``"inputs"``.

        Raises:
            ValueError: if neither schema is present, or the simple schema
                lacks ``image_url``.
        """
        # Simple schema (takes precedence when both schemas are present)
        if "image_url" in data or "text" in data:
            image_url = data.get("image_url")
            text = data.get("text", "Interpret this ECG image.")
            if not image_url:
                raise ValueError("No image_url provided")
            return [
                {"role": "user", "content": [
                    {"type": "image", "image_url": image_url},
                    {"type": "text", "text": text},
                ]}
            ], data.get("parameters", {})

        # Multimodal chat schema
        if "inputs" in data:
            return data.get("inputs", []), data.get("parameters", {})

        raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs'.")

    def __call__(self, data: dict):
        """Run inference on one request payload and return the raw pipeline output."""
        msgs, params = self._normalize_inputs(data)
        # Defaults merge under caller-supplied parameters; ``or {}`` guards
        # against an explicit ``"parameters": null`` in the payload.
        params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
        with torch.inference_mode():
            out = self.pipe(msgs, **params)
        return out