AKSazgar's picture
Simplify to minimal backend API
3f53e8d
raw
history blame
5.96 kB
#!/usr/bin/env python3
"""
ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision
Gradio interface for Hugging Face Spaces
"""
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer
from PIL import Image
import gradio as gr
# Model configuration
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"
print(f"Loading model: {MODEL_ID}")
print("This may take a few minutes on first load...")
# Load model and processor once at import time so every Gradio request reuses them.
# device_map="auto" lets accelerate place layers across the available devices;
# bfloat16 halves memory versus fp32.
# NOTE(review): assumes the Space hardware has enough memory for an 11B vision
# model -- confirm against the target deployment tier.
model = MllamaForConditionalGeneration.from_pretrained(
MODEL_ID,
dtype=torch.bfloat16,
device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
print("Model loaded successfully!")
# Helper functions
def _strip_assistant_prefix_safe(s: str) -> str:
"""Safely strip assistant prefix from generated text"""
s = s.lstrip()
# Only remove a leading role block if it literally starts the text
for prefix in ("user", "assistant", "User", "Assistant"):
if s.startswith(prefix):
idx = s.find("\n\n")
if idx != -1:
return s[idx+2:].lstrip()
idx = s.find("\n")
if idx != -1:
return s[idx+1:].lstrip()
return s
def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate an ECG interpretation report for one image.

    Args:
        image_path: local path to the ECG image file
        query: instruction string for the model
        max_new_tokens: maximum number of tokens to generate
        do_stream: stream tokens to stdout while generating (terminal use)
        temperature: sampling temperature (0.0 = greedy decoding)

    Returns:
        The decoded report text, with any leading chat-role header removed.
    """
    # Context manager closes the file handle promptly (Image.open is lazy and
    # would otherwise keep the descriptor alive until GC).
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Single user turn: image placeholder plus the text query.
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query}
        ]}
    ]
    # Create a prompt compatible with the processor's chat template.
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=input_text, images=image, return_tensors="pt")
    # Move every input tensor to the model's device (device_map="auto").
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Setup streamer if requested
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None
    # Only pass sampling knobs when actually sampling: transformers warns
    # (and newer versions may reject) temperature/top_p with do_sample=False.
    if temperature == 0.0:
        sampling_kwargs = {"do_sample": False}
    else:
        sampling_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 1.0}
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            **sampling_kwargs,
        )
    # generate() returns prompt + new tokens; decode only the new tokens so
    # the prompt text is not echoed back into the report.
    prompt_len = inputs["input_ids"].shape[1]
    full_raw = processor.batch_decode(out_ids[:, prompt_len:], skip_special_tokens=True)[0]
    # Defensive: strip a role header if the template still leaks one.
    return _strip_assistant_prefix_safe(full_raw)
def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """
    Translate English text to Persian (Farsi) using the same model.

    Args:
        english_text: the English report to translate
        max_new_tokens: generation budget for the translation

    Returns:
        The Persian translation, with any leading chat-role header removed.
    """
    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    inputs = processor(text=prompt, return_tensors="pt")
    # Move every input tensor to the model's device.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; temperature/top_p are sampling-only knobs and
            # transformers warns when they accompany do_sample=False.
            do_sample=False,
        )
    # generate() returns prompt + new tokens; decode only the new tokens so
    # the instruction and the English source are not echoed into the result.
    prompt_len = inputs["input_ids"].shape[1]
    ans = processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)
# Gradio interface function
def analyze_ecg_gradio(image, text_instruction="", language="Farsi"):
    """
    Gradio entry point: produce an ECG interpretation for an uploaded image.

    Args:
        image: filepath of the uploaded ECG image (gr.Image type="filepath")
        text_instruction: optional clinical note / context appended to the prompt
        language: output language ("English" or "Farsi"); Farsi output is the
            English report translated by the same model

    Returns:
        The AI-generated ECG report, or a formatted error message on failure.
    """
    try:
        print(f"Received image: {image}")
        print(f"Text instruction: {text_instruction}")
        print(f"Language: {language}")
        # Assemble the instruction for the model from fixed and optional parts.
        query_parts = ["You are an expert cardiologist. "]
        if text_instruction:
            query_parts.append(f"Patient info: {text_instruction}. ")
        query_parts.append("Write an in-depth diagnosis report from this ECG data, including the final diagnosis.")
        query = "".join(query_parts)
        # Always generate the report in English first.
        print("Generating report in English...")
        report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False)
        # Second model pass translates the finished report when requested.
        if language == "Farsi":
            print("Translating to Farsi...")
            report = translate_to_farsi(report, max_new_tokens=1600)
        print("Report generated successfully!")
        return report
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
# Create minimal Gradio interface for API backend.
# Components are named first so the Interface call reads as pure wiring.
ecg_inputs = [
    gr.Image(type="filepath", label="ECG Image"),
    gr.Textbox(lines=2, label="Clinical Note (Optional)"),
    gr.Dropdown(choices=["English", "Farsi"], value="Farsi", label="Language"),
]
ecg_output = gr.Textbox(label="ECG Report", lines=15)
demo = gr.Interface(
    fn=analyze_ecg_gradio,
    inputs=ecg_inputs,
    outputs=ecg_output,
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    flagging_mode="never",
)
# Launch the app
if __name__ == "__main__":
    # Default launch settings; Hugging Face Spaces runs this entry point.
    demo.launch()