#!/usr/bin/env python3
"""
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision
Gradio interface for Hugging Face Spaces
"""

import traceback

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    MllamaForConditionalGeneration,
    TextStreamer,
)

# Model configuration
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"

print(f"Loading model: {MODEL_ID}")
print("Loading in 4-bit mode to fit in free tier memory (16GB)...")
print("This may take a few minutes on first load...")

# Configure 4-bit quantization through BitsAndBytesConfig rather than the
# deprecated ``load_in_4bit`` keyword of ``from_pretrained``.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load model and processor with 4-bit quantization to reduce memory use so
# the 11B model can run on the free tier (16GB GPU).
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Model loaded successfully in 4-bit mode!")

# Role labels that may appear as a bare first line in decoded chat text.
_ROLE_LABELS = ("user", "assistant")


# Helper functions
def _strip_assistant_prefix_safe(s: str) -> str:
    """Remove a leading chat-role header line from decoded text.

    The header is stripped only when the first line consists solely of a role
    label (``user`` / ``assistant``, case-insensitive, optional trailing
    colon). Text that merely begins with such a word (for example
    ``"Users should ..."``) is returned untouched apart from leading
    whitespace removal.

    Args:
        s: decoded model output.

    Returns:
        The text without the leading role header, left-stripped.
    """
    s = s.lstrip()
    first_line, newline, remainder = s.partition("\n")
    if newline and first_line.strip().rstrip(":").lower() in _ROLE_LABELS:
        return remainder.lstrip()
    return s


def _to_model_device(inputs) -> dict:
    """Return the processor outputs as a dict of tensors on ``model.device``."""
    return {name: tensor.to(model.device) for name, tensor in inputs.items()}


def _generate_text(inputs: dict, *, max_new_tokens: int,
                   temperature: float = 0.0, streamer=None) -> str:
    """Run ``model.generate`` and decode only the newly generated tokens.

    Args:
        inputs: processor outputs already moved to the model device; must
            contain ``input_ids``.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature; values <= 0 select greedy decoding.
        streamer: optional ``TextStreamer`` for live terminal output.

    Returns:
        The decoded completion with special tokens removed.
    """
    do_sample = temperature > 0.0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "do_sample": do_sample,
    }
    # Sampling parameters are only meaningful (and only valid) when sampling;
    # passing temperature=0.0 alongside greedy decoding triggers warnings.
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = 1.0
    if streamer is not None:
        gen_kwargs["streamer"] = streamer

    with torch.no_grad():
        out_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + completion; drop the prompt tokens so the
    # caller does not get its own query echoed back in the result.
    prompt_len = inputs["input_ids"].shape[-1]
    completion_ids = out_ids[:, prompt_len:]
    decoded = processor.batch_decode(completion_ids, skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(decoded)


def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate ECG interpretation report

    Args:
        image_path: local path to ECG image
        query: instruction string for the model
        max_new_tokens: maximum tokens to generate
        do_stream: whether to stream output (for terminal use)
        temperature: sampling temperature (0.0 = greedy)

    Returns:
        Decoded interpretation report (generated text only, without the prompt)
    """
    # Close the underlying file once the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Build single user turn: image + text
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query},
        ]}
    ]

    # The chat template already emits the BOS token, so the tokenizer must not
    # add special tokens a second time.
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=input_text,
        images=image,
        add_special_tokens=False,
        return_tensors="pt",
    )
    inputs = _to_model_device(inputs)

    # Setup streamer if requested
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    return _generate_text(
        inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        streamer=streamer,
    )


def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """Translate English text to Persian using the same model.

    Args:
        english_text: text to translate.
        max_new_tokens: maximum tokens to generate for the translation.

    Returns:
        The generated translation only (the prompt and source text are not
        included).
    """
    msgs = [
        {"role": "user", "content": [
            {"type": "text", "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    # BOS is already part of the templated prompt; see generate_full_report.
    inputs = processor(text=prompt, add_special_tokens=False, return_tensors="pt")
    inputs = _to_model_device(inputs)

    return _generate_text(inputs, max_new_tokens=max_new_tokens)


# Gradio interface function
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Main function for Gradio interface

    Args:
        image: uploaded ECG image filepath (string path)
        text_instruction: optional clinical note / context
        language: output language (English or Farsi)

    Returns:
        Full AI-generated ECG interpretation report, or a string starting
        with "Error:" when the request cannot be processed.
    """
    try:
        print(f"Received image: {image}")
        print(f"Text instruction: {text_instruction}")
        print(f"Language: {language}")

        # Without this guard a missing upload surfaces as an opaque traceback.
        if not image:
            return "Error: no ECG image was provided."

        # Build query; ignore a missing or whitespace-only clinical note.
        note = (text_instruction or "").strip()
        query = "You are an expert cardiologist. "
        if note:
            query += f"Patient info: {note}. "
        query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis."

        # Generate report in English
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)

        # Translate to Farsi if requested
        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)

        print("Report generated successfully!")
        return report

    except Exception as e:
        # Top-level UI boundary: report the failure as text instead of crashing.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg


# Create minimal Gradio interface for API backend
demo = gr.Interface(
    fn=analyze_ecg_gradio,
    inputs=[
        gr.Image(type="filepath", label="ECG Image"),
        gr.Textbox(lines=2, label="Clinical Note (Optional)"),
        gr.Dropdown(choices=["English", "Farsi"], value="Farsi", label="Language"),
    ],
    outputs=gr.Textbox(label="ECG Report", lines=15),
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    flagging_mode="never",
)

# Launch the app
if __name__ == "__main__":
    demo.launch()