# (Hugging Face Spaces status banner — "Spaces: Sleeping" — page-scrape residue, not part of the script.)
"""One-shot LoRA training on HF Space A100, then push adapter to Hub."""

import os

# Set before importing torch/transformers so the inductor cache and user
# identity land in writable locations inside the Space container.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_cache"
os.environ["USER"] = "appuser"

import gc
import json
from pathlib import Path

import torch
from datasets import Dataset
from jinja2 import Template
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
divider = "=" * 60
print(divider)
print("CLARKE LoRA TRAINING - Starting")
print(divider)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

MODEL_ID = "google/medgemma-27b-text-it"
ADAPTER_REPO = "yashvshetty/clarke-medgemma-27b-lora"

# Jinja2 prompt template rendered per training record below.
TEMPLATE = Template(Path("backend/prompts/document_generation.j2").read_text())

# JSONL: one record per non-blank line; fields consumed in format_example().
records = [
    json.loads(line)
    for line in Path("data/training/train.jsonl").read_text().splitlines()
    if line.strip()
]
print(f"Loaded {len(records)} training records")
def format_example(record):
    """Render the generation prompt for one record and append its reference letter.

    The rendered prompt plus the gold-standard letter form a single training
    string for causal-LM supervised fine-tuning.
    """
    demographics = record["context"]["demographics"]
    rendered = TEMPLATE.render(
        letter_date="18 Feb 2026",
        clinician_name="Dr Sarah Chen",
        clinician_title="Consultant, General Practice",
        gp_name="Dr Andrew Wilson",
        gp_address="Riverside Medical Practice",
        patient_name=demographics["name"],
        patient_dob=demographics.get("dob", ""),
        patient_nhs=demographics.get("nhs_number", ""),
        transcript=record["transcript"],
        # Full structured context is also embedded verbatim as pretty JSON.
        context_json=json.dumps(record["context"], ensure_ascii=False, indent=2),
    )
    return rendered + "\n" + record["reference_letter"].strip()
# Single-column "text" dataset, as expected by SFTTrainer's default packing.
train_dataset = Dataset.from_dict({"text": [format_example(r) for r in records]})
print(f"Dataset: {len(train_dataset)} examples")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Reuse EOS for padding when the tokenizer ships without a pad token.
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model in 4-bit...")
# NF4 double-quantization with bf16 compute keeps the 27B model inside A100 memory.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(f"Model loaded. GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")
# LoRA over every attention and MLP projection; rank 16, alpha 32.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
# Effective batch size = 1 per device x 8 accumulation steps.
training_args = SFTConfig(
    output_dir="/tmp/clarke-lora-checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    logging_steps=1,           # log every optimizer step; loss curve read back below
    save_strategy="no",        # one-shot run: final adapter is pushed, no checkpoints
    report_to=[],              # no external trackers on the Space
    bf16=True,
    optim="adamw_8bit",        # 8-bit optimizer states to reduce GPU memory
    gradient_checkpointing=True,
)
# SFTTrainer wraps the quantized base model with the LoRA adapter itself
# when given a peft_config.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)
print("Starting training...")
train_result = trainer.train()

# log_history only holds "loss" entries once at least one optimizer step has
# logged; guard so an empty history doesn't raise IndexError on [0]/[-1].
loss_history = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
if loss_history:
    print(f"Initial loss: {loss_history[0]:.4f}")
    print(f"Final loss: {loss_history[-1]:.4f}")
else:
    print("No 'loss' entries in trainer log history.")
# Persist locally first, then publish adapter + tokenizer to the Hub repo.
adapter_dir = "/tmp/clarke-lora-adapter"
trainer.model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print("Adapter saved locally")

print(f"Pushing adapter to {ADAPTER_REPO}...")
trainer.model.push_to_hub(ADAPTER_REPO, commit_message="Updated LoRA: new section structure Feb 2026")
tokenizer.push_to_hub(ADAPTER_REPO, commit_message="Updated tokenizer Feb 2026")
print(f"Adapter pushed to {ADAPTER_REPO}")
# Summarize the run. Hyperparameters are read back from the configs instead of
# duplicating the literals (3 epochs, rank 16), which could silently drift if
# the configs above change. Loss values are None if nothing was logged.
metrics = {
    "initial_loss": float(loss_history[0]) if loss_history else None,
    "final_loss": float(loss_history[-1]) if loss_history else None,
    "epochs": int(training_args.num_train_epochs),
    "lora_rank": peft_config.r,
    "samples": len(records),
}
print(f"TRAINING COMPLETE. Metrics: {json.dumps(metrics)}")

# Drop the big references and flush the CUDA cache so the Space's GPU can be
# reused without a restart.
del model, trainer
gc.collect()
torch.cuda.empty_cache()
print("Memory freed.")