cardioai-api / kaggle_retrain_and_deploy.py
hssling's picture
Stabilize HF Space runtime: quantized loading, version pins, adapter revision config
11368a6
# %% [markdown]
# CardioAI Kaggle Notebook: Retrain + Deploy to Hugging Face Space
#
# This script is notebook-friendly (run cell by cell in Kaggle).
# Outcome:
# 1) Fine-tune LoRA adapter on ECG image dataset.
# 2) Push adapter to HF model repo.
# 3) Auto-update HF Space config so app serves the new adapter revision.
# %% Install deps (run once in a Kaggle cell)
# !pip -q install -U "transformers>=4.49.0" "datasets>=2.19.0" "accelerate>=0.34.0" "peft>=0.13.0" "huggingface_hub>=0.26.0" "Pillow>=10.0.0"
# !pip -q install -U "bitsandbytes>=0.46.1"
# # After installing/upgrading bitsandbytes on Kaggle, restart session once, then run all cells.
# %%
import os
import json
import random
from dataclasses import dataclass
from typing import Dict, Any, List
import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
Qwen2VLForConditionalGeneration,
Trainer,
TrainingArguments
)
# %%
# ----------------------------
# CONFIG (edit these values)
# ----------------------------
BASE_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
DATASET_ID = "IdaFLab/ECG-Plot-Images" # Suitable ECG plot dataset used in your current pipeline
DATASET_SPLIT = "train[:3000]" # Raise when stable (e.g. full train)
HF_ADAPTER_REPO = "hssling/cardioai-adapter"
HF_SPACE_REPO = "hssling/cardioai-api" # Space repo to auto-point to newest adapter revision
OUTPUT_DIR = "/kaggle/working/cardioai_adapter"
SEED = 42
EPOCHS = 2
LR = 2e-4
TRAIN_BATCH_SIZE = 2
GRAD_ACCUM = 4
MAX_TOKENS = 768
LOAD_IN_4BIT = True
# %%
# ----------------------------
# Auth from Kaggle Secrets
# ----------------------------
try:
from kaggle_secrets import UserSecretsClient
_secrets = UserSecretsClient()
HF_TOKEN = _secrets.get_secret("HF_TOKEN")
except Exception as e:
raise RuntimeError("Missing Kaggle secret HF_TOKEN") from e
os.environ["HF_TOKEN"] = HF_TOKEN
api = HfApi(token=HF_TOKEN)
random.seed(SEED)
torch.manual_seed(SEED)
print("Authenticated to Hugging Face Hub.")
# %%
def has_compatible_bitsandbytes() -> bool:
try:
import bitsandbytes as bnb # type: ignore
ver = getattr(bnb, "__version__", "0.0.0")
major, minor, patch = [int(x) for x in ver.split(".")[:3]]
return (major, minor, patch) >= (0, 46, 1)
except Exception:
return False
# %%
# ----------------------------
# Load processor + base model
# ----------------------------
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN, use_fast=False)
use_4bit_now = LOAD_IN_4BIT and torch.cuda.is_available() and has_compatible_bitsandbytes()
if use_4bit_now:
print("Using 4-bit quantization with bitsandbytes.")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
else:
print("bitsandbytes>=0.46.1 not available (or no CUDA). Falling back to fp16/bf16 load.")
bnb_config = None
model_kwargs = {
"pretrained_model_name_or_path": BASE_MODEL_ID,
"device_map": "auto",
"token": HF_TOKEN
}
if use_4bit_now:
model_kwargs["quantization_config"] = bnb_config
model_kwargs["torch_dtype"] = torch.float16
else:
model_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = Qwen2VLForConditionalGeneration.from_pretrained(**model_kwargs)
if use_4bit_now:
model = prepare_model_for_kbit_training(model)
lora_cfg = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()
# %%
# ----------------------------
# Dataset and formatting
# ----------------------------
dataset = load_dataset(DATASET_ID, split=DATASET_SPLIT)
dataset = dataset.shuffle(seed=SEED)
label_map = {
0: "Normal sinus rhythm with no significant ectopy.",
1: "Supraventricular ectopic activity is present.",
2: "Ventricular ectopic beats are present.",
3: "Fusion beat pattern is present."
}
def to_train_example(ex: Dict[str, Any]) -> Dict[str, Any]:
# Keep mapping stable with your existing dataset schema.
finding = label_map.get(int(ex.get("type", 0)), "ECG abnormality present; clinical correlation advised.")
messages = [
{
"role": "system",
"content": "You are CardioAI, an expert cardiology assistant for ECG interpretation."
},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Analyze this ECG and provide rhythm, key abnormalities, and a short impression."}
]
},
{
"role": "assistant",
"content": [{"type": "text", "text": finding}]
}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
return {"image": ex["image"], "text": text}
train_ds = dataset.map(to_train_example, remove_columns=dataset.column_names)
@dataclass
class ECGCollator:
processor: Any
max_tokens: int = MAX_TOKENS
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
images = [x["image"].convert("RGB") for x in batch]
texts = [x["text"] for x in batch]
model_inputs = self.processor(
text=texts,
images=images,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_tokens
)
labels = model_inputs["input_ids"].clone()
# Ignore padding in loss
labels[labels == self.processor.tokenizer.pad_token_id] = -100
model_inputs["labels"] = labels
return model_inputs
collator = ECGCollator(processor=processor)
# %%
# ----------------------------
# Train
# ----------------------------
args = TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=TRAIN_BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LR,
num_train_epochs=EPOCHS,
logging_steps=20,
save_strategy="epoch",
fp16=True,
remove_unused_columns=False,
report_to="none",
optim="paged_adamw_8bit"
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
data_collator=collator
)
trainer.train()
# %%
# ----------------------------
# Save + Push adapter
# ----------------------------
os.makedirs(OUTPUT_DIR, exist_ok=True)
trainer.model.save_pretrained(OUTPUT_DIR) # PEFT adapter files
processor.save_pretrained(OUTPUT_DIR)
api.create_repo(HF_ADAPTER_REPO, repo_type="model", exist_ok=True)
commit_info = api.upload_folder(
folder_path=OUTPUT_DIR,
repo_id=HF_ADAPTER_REPO,
repo_type="model",
commit_message="Kaggle retrain: refresh ECG LoRA adapter"
)
if hasattr(commit_info, "oid"):
adapter_revision = commit_info.oid
else:
adapter_revision = "main"
print(f"Adapter pushed: https://huggingface.co/{HF_ADAPTER_REPO}")
print(f"Adapter revision: {adapter_revision}")
# %%
# ----------------------------
# Update Space runtime config
# ----------------------------
space_cfg = {
"base_model": BASE_MODEL_ID,
"adapter_repo": HF_ADAPTER_REPO,
"adapter_revision": adapter_revision
}
api.upload_file(
path_or_fileobj=json.dumps(space_cfg, indent=2).encode("utf-8"),
path_in_repo="model_config.json",
repo_id=HF_SPACE_REPO,
repo_type="space",
commit_message=f"Point space to adapter revision {adapter_revision}"
)
try:
api.restart_space(repo_id=HF_SPACE_REPO)
print("Space restart requested.")
except Exception as e:
print(f"Space restart API call failed (manual restart may be needed): {e}")
print(f"Space URL: https://huggingface.co/spaces/{HF_SPACE_REPO}")
print("Done. Your app can continue using the same Space endpoint.")