# clarke / scripts/train_lora.py
# Last commit 9ffb169 by yashvshetty: "Fix: remove unsupported SFTConfig params"
"""One-shot LoRA training on HF Space A100, then push adapter to Hub.

Loads the document-generation prompt template and the JSONL training records,
fine-tunes LoRA adapters on a 4-bit quantized MedGemma 27B with TRL's
SFTTrainer, saves the adapter (and tokenizer) locally and pushes both to the
Hugging Face Hub. Data/template paths are relative to the working directory
(presumably the repository root — confirm how the Space launches this script).
"""
import os
# Both environment variables are set before torch is imported below, so the
# values are in place whenever torch / HF libraries read them.
# Redirect TorchInductor's compile cache to a writable temp directory.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_cache"
# NOTE(review): presumably the Space container runs under a UID without a
# passwd entry, so getpass.getuser() (used by some cache-path code) would fail
# unless USER is set — confirm.
os.environ["USER"] = "appuser"
import gc
import json
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
from jinja2 import Template
print("=" * 60)
print("CLARKE LoRA TRAINING - Starting")
print("=" * 60)
# Everything below needs a CUDA device (bitsandbytes 4-bit loading, bf16
# training, torch.cuda memory reporting). Fail fast with an actionable message
# instead of a low-level torch error raised from get_device_name().
if not torch.cuda.is_available():
    raise SystemExit("ERROR: no CUDA GPU available - this script must run on a GPU host")
gpu_props = torch.cuda.get_device_properties(0)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")
MODEL_ID = "google/medgemma-27b-text-it"
ADAPTER_REPO = "yashvshetty/clarke-medgemma-27b-lora"
# Paths are relative to the working directory. Both files are read as UTF-8
# explicitly: the default encoding is locale-dependent, and the template and
# clinical records may contain non-ASCII text.
template_text = Path("backend/prompts/document_generation.j2").read_text(encoding="utf-8")
TEMPLATE = Template(template_text)
train_path = Path("data/training/train.jsonl")
# JSONL: one JSON object per physical line. Iterate the file instead of using
# str.splitlines(), which also breaks on U+2028/U+2029/U+0085. Those characters
# are legal unescaped inside JSON strings (json.dumps(ensure_ascii=False)
# writes them raw), so splitlines() could cut a record in half.
records = []
with train_path.open(encoding="utf-8") as jsonl_file:
    for line_no, raw_line in enumerate(jsonl_file, start=1):
        if not raw_line.strip():
            continue  # tolerate blank lines, as before
        try:
            records.append(json.loads(raw_line))
        except json.JSONDecodeError as err:
            raise ValueError(f"{train_path}:{line_no}: invalid JSON ({err.msg})") from err
# Training on zero examples would only fail much later (or yield empty loss
# logs), so stop here with a clear message.
if not records:
    raise SystemExit(f"ERROR: no training records found in {train_path}")
print(f"Loaded {len(records)} training records")
def format_example(record):
    """Turn one training record into a single SFT text sample.

    The module-level ``TEMPLATE`` is rendered with fixed letter-header values
    (date, clinician, GP) plus this record's patient demographics, transcript
    and pretty-printed context JSON. The stripped reference letter is appended
    after a newline as the text the model should learn to produce.

    Args:
        record: dict with ``context`` (a dict holding ``demographics``, which
            must contain ``name`` and may contain ``dob`` / ``nhs_number``),
            ``transcript`` and ``reference_letter`` (a str).

    Returns:
        ``<rendered prompt>\\n<reference letter stripped of outer whitespace>``.
    """
    context = record["context"]
    context_json = json.dumps(context, ensure_ascii=False, indent=2)
    demographics = context["demographics"]
    render_kwargs = dict(
        # Fixed letter header shared by every training example.
        letter_date="18 Feb 2026",
        clinician_name="Dr Sarah Chen",
        clinician_title="Consultant, General Practice",
        gp_name="Dr Andrew Wilson",
        gp_address="Riverside Medical Practice",
        # Per-record fields; missing optional demographics render as "".
        patient_name=demographics["name"],
        patient_dob=demographics.get("dob", ""),
        patient_nhs=demographics.get("nhs_number", ""),
        transcript=record["transcript"],
        context_json=context_json,
    )
    prompt = TEMPLATE.render(**render_kwargs)
    return "\n".join((prompt, record["reference_letter"].strip()))
# One "text" entry per record: rendered prompt + "\n" + reference letter.
texts = list(map(format_example, records))
train_dataset = Dataset.from_dict({"text": texts})
print(f"Dataset: {len(train_dataset)} examples")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Batch padding needs a pad token; reuse EOS only if the tokenizer defines none.
# NOTE(review): if pad == eos and the collator masks labels by pad id, the model
# may never learn to emit EOS — confirm against the installed TRL collator.
missing_pad_token = tokenizer.pad_token is None
if missing_pad_token:
    tokenizer.pad_token = tokenizer.eos_token
print("Loading model in 4-bit...")
# NF4 4-bit weights with nested (double) quantization; bf16 compute dtype.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_options)
# device_map="auto" lets accelerate place the layers on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"Model loaded. GPU memory: {allocated_gb:.1f} GB")
# LoRA (rank 16, alpha 32, dropout 0.05, no bias training) on the modules named
# q/k/v/o_proj (attention) and gate/up/down_proj (MLP).
lora_target_modules = [
    f"{prefix}_proj" for prefix in ("q", "k", "v", "o", "gate", "up", "down")
]
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
    task_type="CAUSAL_LM",
)
# Per-device effective batch = 1 sample x 8 gradient-accumulation steps.
# No intermediate checkpoints (save_strategy="no"): the adapter is saved
# explicitly once training finishes.
# NOTE(review): no max sequence length is set here, so TRL's default applies
# (1024 tokens in recent releases); longer examples get truncated — confirm.
sft_options = dict(
    output_dir="/tmp/clarke-lora-checkpoints",
    # schedule
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    # logging / saving
    logging_steps=1,
    save_strategy="no",
    report_to=[],
    # memory / precision
    bf16=True,
    optim="adamw_8bit",
    gradient_checkpointing=True,
)
training_args = SFTConfig(**sft_options)
# Pre-flight check: SFTTrainer truncates each tokenized example to the
# configured max sequence length. The option is `max_length` in newer TRL and
# `max_seq_length` in older releases; read whichever exists instead of
# hard-coding a parameter name the installed TRL may not support.
# With the default keep-start truncation the TAIL of each text is dropped,
# i.e. the reference letter the adapter is meant to learn, so warn loudly.
seq_cap = getattr(training_args, "max_length", None) or getattr(
    training_args, "max_seq_length", None
)
token_counts = [len(tokenizer(text)["input_ids"]) for text in train_dataset["text"]]
if token_counts:
    mean_tokens = sum(token_counts) / len(token_counts)
    print(f"Token lengths: max {max(token_counts)}, mean {mean_tokens:.0f}")
    if seq_cap is not None:
        n_over_cap = sum(count > seq_cap for count in token_counts)
        if n_over_cap:
            print(
                f"WARNING: {n_over_cap}/{len(token_counts)} examples exceed the "
                f"max sequence length ({seq_cap} tokens) and will be truncated"
            )
# Passing peft_config makes SFTTrainer wrap the quantized base model with the
# LoRA adapters configured above.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
    args=training_args,
)
print("Starting training...")
train_result = trainer.train()
# Per-step training losses recorded in the Trainer's log history.
loss_history = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
# This runs BEFORE the adapter is saved, so it must not be able to crash:
# indexing an empty history would raise IndexError and lose the trained weights.
if loss_history:
    print(f"Initial loss: {loss_history[0]:.4f}")
    print(f"Final loss: {loss_history[-1]:.4f}")
else:
    print("WARNING: no per-step losses were logged")
# TrainOutput.training_loss is the Trainer's average loss over the whole run.
print(f"Mean training loss: {train_result.training_loss:.4f}")
# trainer.model is the PEFT-wrapped model (peft_config was passed to the
# trainer); the tokenizer is saved next to it so the directory is self-contained.
adapter_dir = "/tmp/clarke-lora-adapter"
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(adapter_dir)
print("Adapter saved locally")
print(f"Pushing adapter to {ADAPTER_REPO}...")
# Adapter first, then tokenizer, each as its own Hub commit.
# NOTE(review): requires Hub write credentials in the Space (presumably an
# HF token secret) — confirm.
hub_uploads = (
    (trainer.model, "Updated LoRA: new section structure Feb 2026"),
    (tokenizer, "Updated tokenizer Feb 2026"),
)
for artifact, commit_msg in hub_uploads:
    artifact.push_to_hub(ADAPTER_REPO, commit_message=commit_msg)
print(f"Adapter pushed to {ADAPTER_REPO}")
# Run summary. Epochs and LoRA rank are read back from the configs actually
# used for training rather than repeated as literals, so the report cannot
# drift from the settings if those are edited. Loss fields are None (JSON null)
# when no per-step losses were logged.
metrics = {
    "initial_loss": float(loss_history[0]) if loss_history else None,
    "final_loss": float(loss_history[-1]) if loss_history else None,
    "epochs": training_args.num_train_epochs,
    "lora_rank": peft_config.r,
    "samples": len(records),
}
print(f"TRAINING COMPLETE. Metrics: {json.dumps(metrics)}")
# Drop the last Python references to the trainer and model, force a GC pass,
# then release the CUDA caching allocator's unused blocks.
del trainer
del model
gc.collect()
torch.cuda.empty_cache()
print("Memory freed.")