# cardioai-api / train_ecg.py
# Author: hssling
# Commit f0e7aa1: Switch to IdaFLab/ECG-Plot-Images Parquet dataset to prevent HTTP bottlenecks
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import os
from huggingface_hub import login
# Run configuration for fine-tuning on ECG plot images: the base vision-language
# model, the Hub dataset it is trained on, and where the resulting LoRA adapter
# is written locally and published.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"    # base model and processor to load
DATASET_ID = "IdaFLab/ECG-Plot-Images"    # Hub dataset id for the training images
OUTPUT_DIR = "./cardioai-adapter"         # local directory for adapter + processor
HF_HUB_REPO = "hssling/cardioai-adapter"  # Hub repo the adapter + processor are pushed to
# Integer class label -> clinical findings text used as the assistant's target answer.
# NOTE(review): the class wording (normal / SVEB / VEB / fusion) comes from the original
# script; confirm it matches how IdaFLab/ECG-Plot-Images encodes its labels.
_LABEL_FINDINGS = {
    0: "Normal Sinus Rhythm. No significant ectopic activity.",
    1: "Supraventricular Ectopic Beat (SVEB). Premature atrial or junctional contraction.",
    2: "Ventricular Ectopic Beat (VEB). Premature ventricular contraction.",
    3: "Fusion of ventricular and normal beat.",
}
# Target text for labels missing from _LABEL_FINDINGS.
_DEFAULT_FINDINGS = "Standard clinical ECG tracing."
_SYSTEM_PROMPT = "You are CardioAI, a highly advanced expert Cardiologist. Analyze the provided Electrocardiogram (ECG/EKG)."
_USER_PROMPT = "Analyze this 12-lead Electrocardiogram trace and extract the detailed clinical rhythms and pathological findings in a structured format."
# Qwen2-VL special tokens that wrap or stand in for image patches. They are input
# placeholders, not text the model should learn to generate, so they are excluded
# from the loss.
_IMAGE_TOKENS = ("<|vision_start|>", "<|vision_end|>", "<|image_pad|>")
# Label value that Hugging Face causal-LM losses skip.
_IGNORE_INDEX = -100


def _login_to_hub():
    """Best-effort Hugging Face Hub login using the Kaggle ``HF_TOKEN`` secret.

    Any failure (not running on Kaggle, secret missing, invalid token) is printed
    and otherwise ignored so local dry runs can continue. The Hub push at the end
    of ``main`` will then report its own failure.
    """
    try:
        # kaggle_secrets is only importable inside Kaggle notebooks.
        from kaggle_secrets import UserSecretsClient
        hf_token = UserSecretsClient().get_secret("HF_TOKEN")
        login(token=hf_token)
    except Exception as e:  # deliberate best-effort: never abort training over auth
        print("Could not log in via Kaggle Secrets.", e)
    else:
        print("Successfully logged into Hugging Face Hub using Kaggle Secrets.")


def _load_model_and_processor():
    """Load the Qwen2-VL processor and base model, then wrap the model with LoRA.

    Returns:
        ``(processor, model)``, where ``model`` is the PEFT-wrapped model returned
        by ``get_peft_model``.
    """
    print(f"Loading processor and model: {MODEL_ID}")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    # Weights are loaded in plain float16. No quantization config is passed, so
    # this is regular LoRA, not 4-bit QLoRA.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    print("Applying LoRA parameters...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        # NOTE(review): presumably these names match only the language-model
        # attention projections of Qwen2-VL; confirm whether the vision tower
        # should be adapted as well.
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    return processor, model


def _load_training_dataset():
    """Load the first 2000 train rows of DATASET_ID, or a mock dataset on failure.

    The fallback has 50 solid-green 224x224 RGB images, each paired with a free-text
    ``findings`` column. It lets the rest of the pipeline be exercised when the real
    dataset cannot be loaded.
    """
    print(f"Loading dataset: {DATASET_ID}")
    try:
        # Parquet-backed dataset, chosen to avoid per-image HTTP bottlenecks.
        return load_dataset(DATASET_ID, split="train[:2000]")
    except Exception as e:
        # Report the real cause: the failure is not necessarily "not found".
        print(f"Warning: could not load {DATASET_ID} ({e}). Synthesizing a robust mock dataset for algorithmic testing.")
    from datasets import Dataset
    from PIL import Image
    # Solid-color stand-ins for ECG plots, for dry-run testing only.
    dummy_images = [Image.new("RGB", (224, 224), color=(0, 255, 0)) for _ in range(50)]
    dummy_findings = ["Normal Sinus Rhythm", "Atrial Fibrillation with RVR", "Acute Anterior MI", "Left Bundle Branch Block", "Sinus Tachycardia"] * 10
    return Dataset.from_dict({"image": dummy_images, "findings": dummy_findings})


def _findings_for(example):
    """Return the target findings text for one dataset row.

    A non-empty ``findings`` column is used verbatim; the mock fallback dataset
    provides one. Otherwise the label in ``type`` is looked up in
    ``_LABEL_FINDINGS``, with ``_DEFAULT_FINDINGS`` for unknown labels.
    """
    findings = example.get("findings")
    if findings:
        return str(findings)
    # The original script states that IdaFLab/ECG-Plot-Images stores the label in
    # 'type'. NOTE(review): confirm it is an int in 0-3.
    return _LABEL_FINDINGS.get(example.get("type", 0), _DEFAULT_FINDINGS)


def _format_example(example, processor):
    """Render one row as a Qwen2-VL chat transcript (system, user+image, assistant).

    Returns:
        ``{"text": <chat-template string>, "image": <the row's image>}``; the
        image is tokenized later by the collator.
    """
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": _USER_PROMPT},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": _findings_for(example)}],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text, "image": example["image"]}


def _make_collate_fn(processor):
    """Build a Trainer data collator that tokenizes text+image pairs with ``processor``.

    ``labels`` is a copy of ``input_ids`` in which padding positions and image
    placeholder tokens are set to ``_IGNORE_INDEX``, so only real text tokens
    contribute to the loss.
    """
    tokenizer = processor.tokenizer
    # Resolve placeholder ids once. Skip any token the vocabulary does not know,
    # which comes back as None or the unk id.
    ignored_token_ids = [
        token_id
        for token_id in (tokenizer.convert_tokens_to_ids(tok) for tok in _IMAGE_TOKENS)
        if token_id is not None and token_id != tokenizer.unk_token_id
    ]

    def collate_fn(examples):
        batch = processor(
            text=[ex["text"] for ex in examples],
            images=[ex["image"] for ex in examples],
            padding=True,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        # Mask padding through attention_mask, which works for left or right padding.
        labels[batch["attention_mask"] == 0] = _IGNORE_INDEX
        for token_id in ignored_token_ids:
            labels[labels == token_id] = _IGNORE_INDEX
        batch["labels"] = labels
        return batch

    return collate_fn


def main():
    """Fine-tune Qwen2-VL with LoRA on ECG plot images and publish the adapter.

    Steps:
      1. Best-effort Hub login.
      2. Load the model and processor and apply LoRA.
      3. Load the dataset, falling back to mock data on failure.
      4. Format rows with the chat template.
      5. Train.
      6. Save the adapter and processor to OUTPUT_DIR.
      7. Best-effort push to HF_HUB_REPO.
    """
    _login_to_hub()
    processor, model = _load_model_and_processor()
    dataset = _load_training_dataset()
    formatted_dataset = dataset.map(
        _format_example,
        fn_kwargs={"processor": processor},
        remove_columns=dataset.column_names,
    )
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=50,
        num_train_epochs=3,  # 3 passes over the loaded subset (at most 2000 rows; 50 for the mock)
        save_strategy="epoch",  # checkpoint at the end of every epoch
        fp16=True,
        optim="paged_adamw_8bit",
        # Keep the raw "text"/"image" columns: the custom collator consumes them,
        # and they are not arguments of the model's forward().
        remove_unused_columns=False,
        report_to="none",
    )
    print("Starting fine-tuning...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        data_collator=_make_collate_fn(processor),
    )
    trainer.train()
    print(f"Saving fine-tuned adapter to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print(f"Pushing model weights to Hugging Face Hub: {HF_HUB_REPO}...")
    try:
        trainer.model.push_to_hub(HF_HUB_REPO)
        processor.push_to_hub(HF_HUB_REPO)
    except Exception as e:  # best-effort: the local save above has already succeeded
        print(f"❌ Failed to push to Hugging Face Hub. Error: {e}")
    else:
        print(f"✅ Success! Your model is now live at: https://huggingface.co/{HF_HUB_REPO}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()