File size: 5,066 Bytes
8af47a3
 
11368a6
 
8af47a3
 
11368a6
8af47a3
11368a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af47a3
 
11368a6
8af47a3
11368a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af47a3
 
11368a6
8af47a3
11368a6
 
 
 
 
 
 
8af47a3
11368a6
8af47a3
11368a6
8af47a3
 
 
 
5dd7eca
 
 
 
 
 
 
 
 
 
 
8af47a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11368a6
 
 
8af47a3
 
 
 
 
d8aef83
8af47a3
 
 
 
 
 
 
 
 
 
 
 
 
a5b3750
11368a6
8af47a3
 
 
 
 
 
eb170be
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import json
import os

DEFAULT_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
DEFAULT_ADAPTER_ID = "hssling/cardioai-adapter"
CONFIG_PATH = "model_config.json"

def load_runtime_config():
    """Resolve the model configuration for this run.

    Precedence, lowest to highest: built-in defaults, environment
    variables (BASE_MODEL_ID / ADAPTER_REPO_ID / ADAPTER_REVISION),
    then matching keys found in CONFIG_PATH on disk.

    Returns:
        dict with keys "base_model", "adapter_repo", "adapter_revision".
    """
    config = {
        "base_model": os.environ.get("BASE_MODEL_ID", DEFAULT_MODEL_ID),
        "adapter_repo": os.environ.get("ADAPTER_REPO_ID", DEFAULT_ADAPTER_ID),
        "adapter_revision": os.environ.get("ADAPTER_REVISION", "main"),
    }
    # EAFP: open the file directly instead of an exists() pre-check, which
    # is racy (the file could vanish between the check and the open).
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            disk_cfg = json.load(f)
    except FileNotFoundError:
        # No on-disk override: env/defaults stand.
        return config
    except Exception as e:
        print(f"Failed to read {CONFIG_PATH}; falling back to defaults. Error: {e}")
        return config
    if isinstance(disk_cfg, dict):
        # On-disk values override env/defaults, but only for known keys.
        for key in config:
            config[key] = disk_cfg.get(key, config[key])
    else:
        print(f"Ignoring {CONFIG_PATH}: expected a JSON object, got {type(disk_cfg).__name__}")
    return config

# --- Module-level bootstrap: resolve config, then load processor + model. ---
cfg = load_runtime_config()
MODEL_ID = cfg["base_model"]
ADAPTER_ID = cfg["adapter_repo"]
ADAPTER_REV = cfg["adapter_revision"]

print("Starting App Engine...")
# Disk staging area for layers that device_map="auto" offloads off-device.
os.makedirs("/tmp/offload", exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# use_fast=False pins the slow tokenizer path; presumably for output parity
# with how the adapter was trained — TODO confirm.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)

# Shared kwargs for both the GPU and CPU load paths.
model_kwargs = {
    "pretrained_model_name_or_path": MODEL_ID,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
    "offload_folder": "/tmp/offload"
}

if device == "cuda":
    # GPU: 4-bit NF4 quantization (bitsandbytes) with fp16 compute to fit
    # the 2B-param VLM in limited VRAM.
    model_kwargs["torch_dtype"] = torch.float16
    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )
else:
    # CPU space: keep dtype low to reduce memory footprint.
    # NOTE(review): fp16 on CPU halves memory but some CPU kernels lack
    # half-precision support and may be slow or error — confirm on target host.
    model_kwargs["torch_dtype"] = torch.float16

model = Qwen2VLForConditionalGeneration.from_pretrained(**model_kwargs)

# Best-effort LoRA overlay: a failed adapter download/load degrades to the
# base model instead of crashing the Space at startup.
if ADAPTER_ID:
    print(f"Loading custom fine-tuned LoRA weights: {ADAPTER_ID}@{ADAPTER_REV}")
    try:
        model = PeftModel.from_pretrained(
            model,
            ADAPTER_ID,
            revision=ADAPTER_REV,
            is_trainable=False  # inference only; freeze adapter weights
        )
        print("Adapter load successful.")
    except Exception as e:
        print(f"Failed to load adapter; serving base model instead. Error: {e}")

def diagnose_ecg(image: Image.Image = None, temp: float = 0.4, max_tokens: int = 768):
    """Run the vision-language model on one ECG image and return a text report.

    Args:
        image: PIL image of the ECG tracing; None yields a JSON error string.
        temp: sampling temperature. Values <= 0 switch to greedy decoding —
            transformers raises if ``do_sample=True`` and ``temperature == 0``,
            and the UI slider's minimum is 0.0.
        max_tokens: cap on newly generated tokens.

    Returns:
        str: the decoded model output, or an error message on failure.
    """
    try:
        if image is None:
            return json.dumps({"error": "No image provided."})

        system_prompt = (
            "You are CardioAI, an ECG interpretation engine. "
            "Always analyze the provided ECG image directly. "
            "Do not provide generic AI disclaimers. "
            "Return concise clinical content only."
        )
        user_prompt = (
            "Interpret this ECG image and return exactly these sections: "
            "1) Impression, 2) Rhythm, 3) Rate, 4) ST-T Findings, 5) Urgency. "
            "If image quality is insufficient, write 'Non-diagnostic ECG image quality' in Impression."
        )

        # Chat layout: the image placeholder precedes the instruction text;
        # the actual pixels are supplied via the processor's images= kwarg.
        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt}
                ]
            }
        ]

        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = processor(
            text=[text_input],
            images=[image],
            padding=True,
            return_tensors="pt"
        )
        # Move tensors to wherever accelerate placed the (possibly offloaded) model.
        model_device = model.device if hasattr(model, "device") else torch.device(device)
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # Bug fix: the Temperature slider allows 0.0, but generate() rejects
        # temperature == 0 when do_sample=True. Fall back to greedy decoding
        # for non-positive temperatures instead of erroring out.
        gen_kwargs = {"max_new_tokens": int(max_tokens), "top_p": 0.9}
        if float(temp) > 0.0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = float(temp)
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **gen_kwargs)

        # Drop the echoed prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return output_text

    except Exception as e:
        # Gradio boundary: surface the failure as visible text, not a 500.
        return f"Error: {str(e)}"

# Gradio UI: wires diagnose_ecg to an image upload and two sampling controls.
_ecg_image = gr.Image(type="pil", label="ECG Image Scan")
_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="Temperature")
_max_tokens = gr.Slider(minimum=128, maximum=1536, value=768, step=128, label="Max Tokens")
_report = gr.Markdown(label="Clinical Report Output")

demo = gr.Interface(
    fn=diagnose_ecg,
    inputs=[_ecg_image, _temperature, _max_tokens],
    outputs=_report,
    title="CardioAI Inference API",
    description="Fine-tuned Medical LLM for Electrocardiogram (ECG) Tracings.",
)

if __name__ == "__main__":
    demo.launch()