|
|
|
|
|
""" |
|
|
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision |
|
|
Gradio interface for Hugging Face Spaces |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer, BitsAndBytesConfig |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Fine-tuned Llama 3.2 11B Vision checkpoint for ECG interpretation (HF Hub repo id).
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"

# Startup progress messages — loading an 11B vision model takes minutes on cold start.
print(f"Loading model: {MODEL_ID}")
print("Loading in 4-bit mode to fit in free tier memory (16GB)...")
print("This may take a few minutes on first load...")

# 4-bit NF4 quantization (bitsandbytes) so the 11B model fits in ~16GB:
#  - compute dtype float16 for the dequantized matmuls
#  - double quantization shaves a bit more memory off the quantization constants
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# device_map="auto" lets accelerate place layers across available GPU(s)/CPU.
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
)

# Processor bundles the tokenizer and the image preprocessor for Mllama.
processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Model loaded successfully in 4-bit mode!")
|
|
|
|
|
|
|
|
|
|
|
def _strip_assistant_prefix_safe(s: str) -> str: |
|
|
"""Safely strip assistant prefix from generated text""" |
|
|
s = s.lstrip() |
|
|
|
|
|
for prefix in ("user", "assistant", "User", "Assistant"): |
|
|
if s.startswith(prefix): |
|
|
idx = s.find("\n\n") |
|
|
if idx != -1: |
|
|
return s[idx+2:].lstrip() |
|
|
idx = s.find("\n") |
|
|
if idx != -1: |
|
|
return s[idx+1:].lstrip() |
|
|
return s |
|
|
|
|
|
|
|
|
def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate an ECG interpretation report for a single image.

    Args:
        image_path: local path to the ECG image
        query: instruction string for the model
        max_new_tokens: maximum number of tokens to generate
        do_stream: whether to stream tokens to stdout (for terminal use)
        temperature: sampling temperature (0.0 = greedy decoding)

    Returns:
        The decoded interpretation report (generated text only, prompt excluded).
    """
    image = Image.open(image_path).convert("RGB")

    # Single-turn prompt: one image placeholder plus the text instruction.
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query}
        ]}
    ]

    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=input_text, images=image, return_tensors="pt")

    # Move all input tensors to wherever accelerate placed the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    # Only pass sampling parameters when actually sampling: recent transformers
    # versions warn (or error) when temperature is combined with do_sample=False.
    if temperature > 0.0:
        sampling_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 1.0}
    else:
        sampling_kwargs = {"do_sample": False}

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            **sampling_kwargs,
        )

    # Decode only the newly generated tokens; decoding the full sequence would
    # echo the prompt (including the user query) back into the report.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = processor.batch_decode(out_ids[:, prompt_len:], skip_special_tokens=True)[0]

    # Defensive: strip a stray role header if the template leaks one through.
    return _strip_assistant_prefix_safe(answer)
|
|
|
|
|
|
|
|
def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """
    Translate English text to Persian using the same vision-language model
    (text-only turn, no image).

    Args:
        english_text: the English report to translate
        max_new_tokens: maximum number of tokens to generate

    Returns:
        The decoded Persian translation (generated text only, prompt excluded).
    """
    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    inputs = processor(text=prompt, return_tensors="pt")

    # Move all input tensors to wherever accelerate placed the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; do NOT pass temperature/top_p alongside
            # do_sample=False — recent transformers versions warn/error on that.
            do_sample=False,
        )

    # Decode only the newly generated tokens so the prompt (which contains the
    # entire English report) is not echoed into the translation.
    prompt_len = inputs["input_ids"].shape[-1]
    ans = processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)
|
|
|
|
|
|
|
|
|
|
|
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Main entry point for the Gradio interface.

    Args:
        image: filepath of the uploaded ECG image (gr.Image(type="filepath")),
            or None when nothing was uploaded
        text_instruction: optional clinical note / patient context
        language: output language, "English" or "Farsi"

    Returns:
        Full AI-generated ECG interpretation report, or a human-readable error
        message — this function never raises, so the UI always gets a string.
    """
    # Gradio passes None when the user submits without uploading an image;
    # fail fast with a clear message instead of a traceback from Image.open.
    if image is None:
        return "Error: Please upload an ECG image."

    try:
        print(f"Received image: {image}")
        print(f"Text instruction: {text_instruction}")
        print(f"Language: {language}")

        # Build the instruction prompt, folding in the optional clinical note.
        query = "You are an expert cardiologist. "
        if text_instruction:
            query += f"Patient info: {text_instruction}. "
        query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis."

        # The model reports in English; translate afterwards if requested.
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)

        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)

        print("Report generated successfully!")
        return report

    except Exception as e:
        # Return (not raise) the error so the Gradio UI can display it.
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI definition: one function, three inputs, one text output.
demo = gr.Interface(
    fn=analyze_ecg_gradio,
    inputs=[
        # type="filepath" means the function receives a path string, not a PIL image.
        gr.Image(type="filepath", label="ECG Image"),
        gr.Textbox(lines=2, label="Clinical Note (Optional)"),
        gr.Dropdown(choices=["English", "Farsi"], value="Farsi", label="Language"),
    ],
    outputs=gr.Textbox(label="ECG Report", lines=15),
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    flagging_mode="never",  # disable Gradio's built-in flagging controls
)
|
|
|
|
|
|
|
|
# Launch the Gradio app when executed as a script (how HF Spaces runs this file).
if __name__ == "__main__":
    demo.launch()
|
|
|