File size: 4,328 Bytes
8af47a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import os
from huggingface_hub import login

# 1. Configuration targeting ECG Image Scans
# Hub id of the base vision-language checkpoint that main() loads and wraps with LoRA.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" 
# Hub id of the dataset whose "train" split main() fine-tunes on.
DATASET_ID = "hssling/ECG-10k-Control"
# Local directory used for Trainer checkpoints and the final adapter/processor save.
OUTPUT_DIR = "./cardioai-adapter"
# Hub repository that main() pushes the trained model and processor to.
HF_HUB_REPO = "hssling/cardioai-adapter" 

def main():
    """Fine-tune the base model with LoRA on the ECG dataset and publish the result.

    Steps, in order:
      1. Best-effort login to the Hugging Face Hub using the ``HF_TOKEN``
         Kaggle secret (failure is reported and the run continues).
      2. Load the processor and the base model ``MODEL_ID`` in float16.
      3. Wrap the model with LoRA adapters on the attention projections.
      4. Load ``DATASET_ID`` and render each example into a chat-formatted
         training string paired with its image.
      5. Train with ``Trainer``, save the model and processor to
         ``OUTPUT_DIR``, then try to push both to ``HF_HUB_REPO``.

    Returns:
        None. All results are side effects (files on disk, Hub uploads,
        progress messages on stdout).
    """
    # Attempt to authenticate with Hugging Face via Kaggle Secrets.
    # This is deliberately best-effort: outside Kaggle the import fails and
    # the script continues (the final push may then fail and is reported).
    try:
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        hf_token = user_secrets.get_secret("HF_TOKEN")
        login(token=hf_token)
        print("Successfully logged into Hugging Face Hub using Kaggle Secrets.")
    except Exception as e:
        print("Could not log in via Kaggle Secrets.", e)

    print(f"Loading processor and model: {MODEL_ID}")

    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # Half-precision (float16) load. NOTE: no quantization config is passed
    # here, so the weights are NOT 4-bit quantized despite what an earlier
    # comment claimed.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    print("Applying LoRA parameters...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)

    print(f"Loading dataset: {DATASET_ID}")
    dataset = load_dataset(DATASET_ID, split="train")  # Using the full 10k ECG dataset

    def format_data(example):
        """Render one dataset row into a chat-template string plus its image.

        The target text is the first truthy value among the ``findings``,
        ``text`` and ``description`` fields, falling back to a generic
        placeholder sentence when none is present.
        """
        findings = example.get("findings") or example.get("text") or example.get("description") or "ECG tracing findings."
        messages = [
            {
                "role": "system",
                "content": "You are CardioAI, a highly advanced expert Cardiologist. Analyze the provided Electrocardiogram (ECG/EKG)."
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Analyze this 12-lead Electrocardiogram trace and extract the detailed clinical rhythms and pathological findings in a structured format."}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": str(findings)}
                ]
            }
        ]
        # The assistant turn is already part of the conversation, so no
        # generation prompt is appended.
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        return {"text": text, "image": example["image"]}

    formatted_dataset = dataset.map(format_data, remove_columns=dataset.column_names)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=50,
        num_train_epochs=3,  # Train extensively across the entire 10k dataset 3 times
        save_strategy="epoch",  # Save at the end of every epoch
        fp16=True,
        optim="paged_adamw_8bit",
        # The collator needs the raw "text"/"image" columns, which the
        # Trainer would otherwise drop as unknown model inputs.
        remove_unused_columns=False,
        report_to="none"
    )

    def collate_fn(examples):
        """Tokenize a list of formatted examples into one padded training batch.

        Labels are a copy of ``input_ids`` with every padded position set to
        -100 so that padding does not contribute to the loss.
        """
        texts = [ex["text"] for ex in examples]
        images = [ex["image"] for ex in examples]
        batch = processor(
            text=texts,
            images=images,
            padding=True,
            return_tensors="pt"
        )
        labels = batch["input_ids"].clone()
        # Mask padding via the attention mask rather than by token id, so the
        # masking stays correct even if the pad token id also occurs as a
        # real token in the sequence. -100 is the loss ignore index.
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels
        return batch

    print("Starting fine-tuning...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        data_collator=collate_fn
    )

    trainer.train()

    print(f"Saving fine-tuned adapter to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    print(f"Pushing model weights to Hugging Face Hub: {HF_HUB_REPO}...")
    # The local save above has already succeeded, so a failed upload is
    # reported instead of raised.
    try:
        trainer.model.push_to_hub(HF_HUB_REPO)
        processor.push_to_hub(HF_HUB_REPO)
        print(f"✅ Success! Your model is now live at: https://huggingface.co/{HF_HUB_REPO}")
    except Exception as e:
        print(f"❌ Failed to push to Hugging Face Hub. Error: {e}")

# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()