# AKSazgar - Fix 8-bit quantization bug, switch to 4-bit with BitsAndBytesConfig (b7d24ee)
#!/usr/bin/env python3
"""
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision
Gradio interface for Hugging Face Spaces
"""
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
# Model configuration
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"

# Startup banner: a single print call that emits one message per line.
print(
    f"Loading model: {MODEL_ID}",
    "Loading in 4-bit mode to fit in free tier memory (16GB)...",
    "This may take a few minutes on first load...",
    sep="\n",
)

# 4-bit NF4 quantization with double quantization and float16 compute,
# expressed through BitsAndBytesConfig rather than a bare load_in_4bit
# argument to from_pretrained.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized model first, then its processor. The quantization is
# what is meant to let the 11B checkpoint fit on the 16GB free tier.
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Model loaded successfully in 4-bit mode!")
# Helper functions
def _strip_assistant_prefix_safe(s: str) -> str:
"""Safely strip assistant prefix from generated text"""
s = s.lstrip()
# Only remove a leading role block if it literally starts the text
for prefix in ("user", "assistant", "User", "Assistant"):
if s.startswith(prefix):
idx = s.find("\n\n")
if idx != -1:
return s[idx+2:].lstrip()
idx = s.find("\n")
if idx != -1:
return s[idx+1:].lstrip()
return s
def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate an ECG interpretation report for a single image.

    Args:
        image_path: local path to ECG image
        query: instruction string for the model
        max_new_tokens: maximum tokens to generate
        do_stream: whether to stream output (for terminal use)
        temperature: sampling temperature; values <= 0 select greedy decoding

    Returns:
        The text generated by the model, without the echoed prompt and
        without a leading role header.
    """
    # Copy the pixels into an RGB image and close the source file right away
    # instead of leaving the handle open until garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Build single user turn: image + text
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query}
        ]}
    ]

    # Create prompt compatible with processor
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # add_special_tokens=False follows the Llama 3.2 Vision model card usage:
    # the templated prompt already carries its special tokens, so letting the
    # tokenizer add them again would duplicate the BOS token.
    inputs = processor(
        text=input_text,
        images=image,
        add_special_tokens=False,
        return_tensors="pt",
    )

    # Move inputs to same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    # Setup streamer if requested
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    # Generate
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            # Non-positive temperatures mean greedy decoding; sampling with
            # them would be rejected by the temperature logits warper.
            do_sample=temperature > 0.0,
            temperature=temperature,
            top_p=1.0,
        )

    # generate() returns prompt tokens followed by the new tokens; decode only
    # the newly generated part so the query is not echoed into the report.
    new_token_ids = out_ids[:, prompt_len:]
    full_raw = processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(full_raw)
def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """Translate English text to Persian using the same model.

    Args:
        english_text: text to translate.
        max_new_tokens: maximum tokens to generate for the translation.

    Returns:
        The generated translation only (the instruction and the English
        source are not echoed back). Empty or whitespace-only input yields
        an empty string without running the model.
    """
    if not english_text or not english_text.strip():
        # Nothing to translate; skip a pointless generation pass.
        return ""

    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    # add_special_tokens=False follows the Llama 3.2 Vision model card usage:
    # the templated prompt already carries its special tokens.
    inputs = processor(text=prompt, add_special_tokens=False, return_tensors="pt")

    # Move to device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.0,
            top_p=1.0
        )

    # Decode only the tokens produced after the prompt; decoding the whole
    # sequence would return the instruction and the English source as well.
    ans = processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)
# Gradio interface function
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Main function for Gradio interface

    Args:
        image: uploaded ECG image filepath (string path); may be None or
            empty when nothing was uploaded
        text_instruction: optional clinical note / context
        language: output language (English or Farsi)

    Returns:
        Full AI-generated ECG interpretation report, or a string starting
        with "Error:" when the request cannot be processed.
    """
    print(f"Received image: {image}")
    print(f"Text instruction: {text_instruction}")
    print(f"Language: {language}")

    # Reject a missing upload up front with a readable message instead of
    # letting the image loader fail and returning a raw traceback.
    if not image:
        error_msg = "Error: No ECG image provided. Please upload an ECG image."
        print(error_msg)
        return error_msg

    try:
        # Build query
        query = "You are an expert cardiologist. "
        # Normalise the note so None or whitespace-only input adds nothing.
        note = str(text_instruction).strip() if text_instruction else ""
        if note:
            query += f"Patient info: {note}. "
        query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis."

        # Generate report in English
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)

        # Translate to Farsi if requested
        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)

        print("Report generated successfully!")
        return report
    except Exception as e:
        # Top-level boundary for the UI: log the failure and hand it back as
        # text so the interface shows something instead of crashing.
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
# Minimal Gradio interface that serves as the API backend.
# Input components are still constructed before the output component.
demo = gr.Interface(
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    fn=analyze_ecg_gradio,
    inputs=[
        gr.Image(label="ECG Image", type="filepath"),
        gr.Textbox(label="Clinical Note (Optional)", lines=2),
        gr.Dropdown(label="Language", value="Farsi", choices=["English", "Farsi"]),
    ],
    outputs=gr.Textbox(lines=15, label="ECG Report"),
    flagging_mode="never",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()