#!/usr/bin/env python3
"""
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision
Gradio interface for Hugging Face Spaces
"""

import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer, BitsAndBytesConfig
from PIL import Image
import gradio as gr

# Model configuration
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"

print(f"Loading model: {MODEL_ID}")
print("Loading in 4-bit mode to fit in free tier memory (16GB)...")
print("This may take a few minutes on first load...")

# Configure 4-bit quantization explicitly via BitsAndBytesConfig;
# this is preferred over the deprecated bare load_in_4bit= argument to from_pretrained
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
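
# Back-of-envelope memory estimate: ~11B params at 4 bits (nf4) is roughly
# 5.5 GB of weights, leaving headroom for activations and the KV cache on a 16 GB GPU.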

# Load the model with 4-bit quantization to reduce memory significantly;
# this allows the 11B model to run on the free tier (16 GB GPU)
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Model loaded successfully in 4-bit mode!")


# Helper functions
def _strip_assistant_prefix_safe(s: str) -> str:
    """Safely strip assistant prefix from generated text"""
    s = s.lstrip()
    # Only remove a leading role block if it literally starts the text
    for prefix in ("user", "assistant", "User", "Assistant"):
        if s.startswith(prefix):
            idx = s.find("\n\n")
            if idx != -1:
                return s[idx+2:].lstrip()
            idx = s.find("\n")
            if idx != -1:
                return s[idx+1:].lstrip()
    return s
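
# e.g. _strip_assistant_prefix_safe("assistant\n\nSinus rhythm at 72 bpm")
#      returns "Sinus rhythm at 72 bpm"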


def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate ECG interpretation report

    Args:
        image_path: local path to ECG image
        query: instruction string for the model
        max_new_tokens: maximum tokens to generate
        do_stream: whether to stream output (for terminal use)
        temperature: sampling temperature (0.0 = greedy)

    Returns:
        Full decoded interpretation report
    """
    image = Image.open(image_path).convert("RGB")

    # Build single user turn: image + text
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query}
        ]}
    ]

    # Create prompt compatible with processor
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=input_text, images=image, return_tensors="pt")

    # Move inputs to same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Setup streamer if requested
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    # Generate (greedy when temperature == 0.0; only pass sampling
    # parameters when actually sampling, to avoid transformers warnings)
    gen_kwargs = {"max_new_tokens": max_new_tokens, "use_cache": True}
    if temperature > 0.0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)
    with torch.no_grad():
        out_ids = model.generate(**inputs, streamer=streamer, **gen_kwargs)

    # Decode only the newly generated tokens so the prompt is not echoed back
    gen_ids = out_ids[:, inputs["input_ids"].shape[1]:]
    full_raw = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(full_raw)
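
# Example (hypothetical local test; "samples/ecg_12lead.png" is a placeholder path):
#   report = generate_full_report(
#       "samples/ecg_12lead.png",
#       "You are an expert cardiologist. Describe this ECG.",
#       do_stream=True,   # stream tokens to the terminal while generating
#   )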


def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """Translate English text to Persian using the same model"""
    # The Persian instruction below tells the model:
    # "Only translate the following text into fluent Persian and return only the translation:"
    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    inputs = processor(text=prompt, return_tensors="pt")

    # Move to device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding for a deterministic translation
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    gen_ids = out[:, inputs["input_ids"].shape[1]:]
    ans = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)


# Gradio interface function
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Main function for Gradio interface

    Args:
        image: uploaded ECG image filepath (string path)
        text_instruction: optional clinical note / context
        language: output language (English or Farsi)

    Returns:
        Full AI-generated ECG interpretation report
    """
    try:
        if image is None:
            return "Error: no ECG image was provided."
        print(f"Received image: {image}")
        print(f"Text instruction: {text_instruction}")
        print(f"Language: {language}")

        # Build query
        query = "You are an expert cardiologist. "
        if text_instruction:
            query += f"Patient info: {text_instruction}. "
        query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis."

        # Generate report in English
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)

        # Translate to Farsi if requested
        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)

        print("Report generated successfully!")
        return report

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg


# Create minimal Gradio interface for API backend
demo = gr.Interface(
    fn=analyze_ecg_gradio,
    inputs=[
        gr.Image(type="filepath", label="ECG Image"),
        gr.Textbox(lines=2, label="Clinical Note (Optional)"),
        gr.Dropdown(choices=["English", "Farsi"], value="Farsi", label="Language"),
    ],
    outputs=gr.Textbox(label="ECG Report", lines=15),
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    flagging_mode="never",
)
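
# Optionally enable request queuing before launch — useful on Spaces, since each
# report can take minutes to generate (queue() is a standard Gradio API):
#   demo.queue(max_size=8)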

# Launch the app
if __name__ == "__main__":
    demo.launch()
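
# Example client usage (a minimal sketch; "user/ecg-ai7" is a placeholder Space id —
# replace it with the actual Space this app is deployed to):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("user/ecg-ai7")
#   report = client.predict(
#       handle_file("path/to/ecg.png"),   # ECG image
#       "65-year-old male, chest pain",   # optional clinical note
#       "English",                        # output language
#       api_name="/predict",              # default endpoint for gr.Interface
#   )
#   print(report)
#
# On older gradio_client versions, pass the plain file path instead of handle_file(...).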