hssling commited on
Commit
8af47a3
·
0 Parent(s):

Initialize CardioAI training and deployment pipeline

Browse files
Files changed (5) hide show
  1. .github/workflows/sync_to_hub.yml +23 -0
  2. README.md +15 -0
  3. app.py +81 -0
  4. requirements.txt +8 -0
  5. train_ecg.py +125 -0
.github/workflows/sync_to_hub.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
name: Sync to Hugging Face Hub

on:
  push:
    branches: [main, master]

  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      # Full history + LFS objects are required so the force-push mirrors
      # the complete repository (Spaces reject shallow pushes of model files).
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          lfs: true

      - name: Push to Hugging Face Hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git remote add space https://hssling:$HF_TOKEN@huggingface.co/spaces/hssling/cardioai-api
          # Push whatever commit triggered the workflow. The previous refspec
          # `master:main` failed for pushes on `main`, because checkout leaves
          # no local `master` ref in that case; HEAD:main works for both
          # trigger branches.
          git push --force space HEAD:main
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CardioAI ECG API
3
+ emoji: ❤️‍🔥
4
+ colorFrom: red
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: "4.26.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ python_version: "3.10"
11
+ ---
12
+
13
+ # CardioAI Fine-Tuned Model API
14
+
15
+ Training logic and execution backend for Kaggle-to-HuggingFace continuous deployment.
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import json

# Base vision-language model and the fine-tuned LoRA adapter layered on top.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
ADAPTER_ID = "hssling/cardioai-adapter"

print("Starting App Engine...")
# Half precision only on GPU; CPU inference stays in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"
)

# Best-effort adapter load: if the Hub repo is missing or incompatible the
# app deliberately continues with the base model rather than crashing.
if ADAPTER_ID:
    print(f"Loading custom fine-tuned LoRA weights: {ADAPTER_ID}")
    try:
        model.load_adapter(ADAPTER_ID)
    except Exception as e:
        print(f"Failed to load adapter. Using base model. Error: {e}")
def diagnose_ecg(image: Image.Image = None, temp: float = 0.2, max_tokens: int = 1500):
    """Run the vision-language model on one ECG scan and return a clinical report.

    Args:
        image: ECG image to analyze; ``None`` returns a JSON error string
            (kept for backward compatibility with the original API).
        temp: Sampling temperature. A value of 0 (the UI slider's minimum)
            now selects greedy decoding — previously ``temperature=0.0`` with
            ``do_sample=True`` raised a ValueError inside ``model.generate``.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded model output as a string, or an ``"Error: ..."`` string
        if anything inside inference fails.
    """
    try:
        if image is None:
            return json.dumps({"error": "No image provided."})

        system_prompt = "You are CardioAI, a highly advanced expert Cardiologist. Analyze the provided Electrocardiogram (ECG/EKG)."
        user_prompt = "Analyze this 12-lead Electrocardiogram trace and extract the detailed clinical rhythms and pathological findings in a structured format."

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt}
                ]
            }
        ]

        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = processor(
            text=[text_input],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Fix: only pass sampling parameters when actually sampling.
        # transformers rejects temperature <= 0 when do_sample=True, and the
        # Temperature slider allows 0.0; treat that as greedy decoding.
        do_sample = float(temp) > 0.0
        gen_kwargs = {"max_new_tokens": int(max_tokens), "do_sample": do_sample}
        if do_sample:
            gen_kwargs.update(temperature=float(temp), top_p=0.9)

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **gen_kwargs)

        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return output_text

    except Exception as e:
        # Surface the failure to the UI instead of crashing the Space.
        return f"Error: {str(e)}"
# Gradio front end: one image input plus two generation controls, rendering
# the model's report as Markdown.
demo = gr.Interface(
    fn=diagnose_ecg,
    title="CardioAI Inference API",
    description="Fine-tuned Medical LLM for Electrocardiogram (ECG) Tracings.",
    inputs=[
        gr.Image(type="pil", label="ECG Image Scan"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
        gr.Slider(minimum=256, maximum=4096, value=1500, step=256, label="Max Tokens"),
    ],
    outputs=gr.Markdown(label="Clinical Report Output"),
)

if __name__ == "__main__":
    # Local-only serving; no public share link.
    demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
# Model training / inference stack
torch>=2.0
transformers>=4.40.0
accelerate
# LoRA fine-tuning and 8-bit optimizer support
peft
bitsandbytes
datasets
# Serving UI and image handling
gradio>=4.0.0
Pillow
train_ecg.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import os
from huggingface_hub import login

# 1. Configuration targeting ECG Image Scans
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"   # base vision-language model to fine-tune
DATASET_ID = "hssling/ECG-10k-Control"   # Hub dataset of ECG images + findings
OUTPUT_DIR = "./cardioai-adapter"        # local directory for the trained LoRA adapter
HF_HUB_REPO = "hssling/cardioai-adapter" # Hub repo the adapter is pushed to after training
def main():
    """Fine-tune Qwen2-VL on the ECG dataset with LoRA and push the adapter to the Hub.

    Steps: authenticate (Kaggle secrets, best-effort) -> load model/processor ->
    attach LoRA -> format + collate the dataset -> train -> save and push.
    """
    # Attempt to authenticate with Hugging Face via Kaggle Secrets.
    # Failure is non-fatal: training still runs, only the final push may fail.
    try:
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        hf_token = user_secrets.get_secret("HF_TOKEN")
        login(token=hf_token)
        print("Successfully logged into Hugging Face Hub using Kaggle Secrets.")
    except Exception as e:
        print("Could not log in via Kaggle Secrets.", e)

    print(f"Loading processor and model: {MODEL_ID}")

    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # fp16 load (NOTE: despite the original comment, this is NOT 4-bit
    # quantization — no BitsAndBytesConfig is passed).
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    print("Applying LoRA parameters...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)

    print(f"Loading dataset: {DATASET_ID}")
    dataset = load_dataset(DATASET_ID, split="train")  # Using the full 10k ECG dataset

    def format_data(example):
        # Build one chat-formatted training example; the target text is the
        # first available findings-like column, with a generic fallback.
        findings = example.get("findings") or example.get("text") or example.get("description") or "ECG tracing findings."
        messages = [
            {
                "role": "system",
                "content": "You are CardioAI, a highly advanced expert Cardiologist. Analyze the provided Electrocardiogram (ECG/EKG)."
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Analyze this 12-lead Electrocardiogram trace and extract the detailed clinical rhythms and pathological findings in a structured format."}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": str(findings)}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        return {"text": text, "image": example["image"]}

    formatted_dataset = dataset.map(format_data, remove_columns=dataset.column_names)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=50,
        num_train_epochs=3,     # Train extensively across the entire 10k dataset 3 times
        save_strategy="epoch",  # Save at the end of every epoch
        fp16=True,
        optim="paged_adamw_8bit",
        remove_unused_columns=False,
        report_to="none"
    )

    def collate_fn(examples):
        # Tokenize text + images into one padded batch.
        texts = [ex["text"] for ex in examples]
        images = [ex["image"] for ex in examples]
        batch = processor(
            text=texts,
            images=images,
            padding=True,
            return_tensors="pt"
        )
        # Fix: mask padding positions to -100 so the loss ignores them.
        # The original cloned input_ids verbatim, which trained the model to
        # predict pad tokens on every padded batch.
        labels = batch["input_ids"].clone()
        pad_id = processor.tokenizer.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100
        batch["labels"] = labels
        return batch

    print("Starting fine-tuning...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        data_collator=collate_fn
    )

    trainer.train()

    print(f"Saving fine-tuned adapter to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    print(f"Pushing model weights to Hugging Face Hub: {HF_HUB_REPO}...")
    try:
        trainer.model.push_to_hub(HF_HUB_REPO)
        processor.push_to_hub(HF_HUB_REPO)
        print(f"✅ Success! Your model is now live at: https://huggingface.co/{HF_HUB_REPO}")
    except Exception as e:
        print(f"❌ Failed to push to Hugging Face Hub. Error: {e}")

if __name__ == "__main__":
    main()