File size: 6,499 Bytes
7040489 b7d24ee 7040489 5d9d36f a7d3992 7040489 b7d24ee 7040489 b7d24ee 7040489 b7d24ee 7040489 5d9d36f 7040489 b7d24ee 7040489 3f53e8d 7040489 3f53e8d 7040489 3f53e8d 7040489 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
#!/usr/bin/env python3
"""
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision
Gradio interface for Hugging Face Spaces
"""
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
# Model configuration: fine-tuned Llama 3.2 11B Vision for ECG interpretation
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"
print(f"Loading model: {MODEL_ID}")
print("Loading in 4-bit mode to fit in free tier memory (16GB)...")
print("This may take a few minutes on first load...")
# Configure 4-bit quantization properly using BitsAndBytesConfig
# This is more stable than deprecated load_in_4bit parameter
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # Weights stored in 4-bit, but matmuls computed in fp16
    bnb_4bit_compute_dtype=torch.float16,
    # Double quantization: also quantizes the quantization constants,
    # shaving a little more memory off the footprint
    bnb_4bit_use_double_quant=True,
    # "nf4" (NormalFloat4) is the bitsandbytes-recommended 4-bit type
    bnb_4bit_quant_type="nf4"
)
# Load model and processor with 4-bit quantization to reduce memory significantly
# This allows the 11B model to run on free tier (16GB GPU)
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    # Let accelerate place layers on the available GPU (and spill to CPU if needed)
    device_map="auto",
)
# Processor bundles the tokenizer + image preprocessor + chat template
processor = AutoProcessor.from_pretrained(MODEL_ID)
print("Model loaded successfully in 4-bit mode!")
# Helper functions
def _strip_assistant_prefix_safe(s: str) -> str:
"""Safely strip assistant prefix from generated text"""
s = s.lstrip()
# Only remove a leading role block if it literally starts the text
for prefix in ("user", "assistant", "User", "Assistant"):
if s.startswith(prefix):
idx = s.find("\n\n")
if idx != -1:
return s[idx+2:].lstrip()
idx = s.find("\n")
if idx != -1:
return s[idx+1:].lstrip()
return s
def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate an ECG interpretation report for a single image.

    Args:
        image_path: local path to the ECG image.
        query: instruction string for the model.
        max_new_tokens: maximum tokens to generate.
        do_stream: whether to stream tokens to stdout (for terminal use).
        temperature: sampling temperature (0.0 = greedy decoding).

    Returns:
        The decoded interpretation report (generated text only, with any
        leading role header stripped).
    """
    image = Image.open(image_path).convert("RGB")
    # Build a single user turn: image + text
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query}
        ]}
    ]
    # Render the chat template into a prompt the processor understands
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=input_text, images=image, return_tensors="pt")
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Optional token-by-token streaming for terminal usage
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    # Build generation kwargs. With greedy decoding, do NOT pass
    # temperature/top_p: recent transformers versions warn (or raise) when
    # sampling parameters accompany do_sample=False.
    gen_kwargs = {
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
    }
    if temperature == 0.0:
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)

    with torch.no_grad():
        out_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + generation; decode ONLY the new tokens.
    # Decoding the full sequence would echo the query text back into the
    # report (the prefix-strip below only removes a leading role header).
    prompt_len = inputs["input_ids"].shape[1]
    full_raw = processor.batch_decode(out_ids[:, prompt_len:],
                                      skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(full_raw)
def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """Translate English text to Persian (Farsi) using the same model.

    Args:
        english_text: the English report to translate.
        max_new_tokens: generation budget for the translation.

    Returns:
        The Persian translation (with any leading role header stripped).
    """
    # Text-only turn; the (Persian) instruction asks the model to return
    # only the fluent translation of the text below.
    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    inputs = processor(text=prompt, return_tensors="pt")
    # Move to the model's device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        # Greedy decoding. Do NOT pass temperature/top_p alongside
        # do_sample=False — recent transformers versions warn (or raise)
        # on sampling params supplied for greedy search.
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # Decode only the newly generated tokens; decoding the full sequence
    # would echo the prompt — including the entire English source text —
    # back into the result.
    prompt_len = inputs["input_ids"].shape[1]
    ans = processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)
# Gradio interface function
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Main function for the Gradio interface.

    Args:
        image: uploaded ECG image filepath (string path), or None when the
            request carries no image.
        text_instruction: optional clinical note / context (may be None
            when called via the API).
        language: output language ("English" or "Farsi").

    Returns:
        Full AI-generated ECG interpretation report, or a human-readable
        error message (the textbox output is the only error channel).
    """
    try:
        # Fail fast with a clear message instead of a PIL traceback when
        # the user submits without uploading an image.
        if not image:
            return "Error: no ECG image was provided. Please upload an ECG image."
        print(f"Received image: {image}")
        print(f"Text instruction: {text_instruction}")
        print(f"Language: {language}")
        # Build the query; the truthiness check also guards against None
        # arriving from API callers.
        query = "You are an expert cardiologist. "
        if text_instruction:
            query += f"Patient info: {text_instruction}. "
        query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis."
        # Generate report in English first
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)
        # Translate to Farsi if requested
        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)
        print("Report generated successfully!")
        return report
    except Exception as e:
        # Deliberate top-level boundary handler: surface the full traceback
        # in the UI/API response so remote failures are debuggable.
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
# Create minimal Gradio interface for API backend
demo = gr.Interface(
    fn=analyze_ecg_gradio,
    inputs=[
        # "filepath" hands the handler a local path string (not a PIL image)
        gr.Image(type="filepath", label="ECG Image"),
        gr.Textbox(lines=2, label="Clinical Note (Optional)"),
        gr.Dropdown(choices=["English", "Farsi"], value="Farsi", label="Language"),
    ],
    outputs=gr.Textbox(label="ECG Report", lines=15),
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    # Disable the flagging button — this Space is used as a pure API backend
    flagging_mode="never",
)
# Launch the app only when run as a script (Spaces executes this module directly)
if __name__ == "__main__":
    demo.launch()
|