yashvshetty commited on
Commit
7b3c958
·
1 Parent(s): 1014cbc

Add optional LoRA training on startup (RUN_LORA_TRAINING flag)

Browse files
Files changed (3) hide show
  1. requirements.txt +3 -0
  2. scripts/start.sh +13 -5
  3. scripts/train_lora.py +134 -0
requirements.txt CHANGED
@@ -20,3 +20,6 @@ pydantic>=2.10.6
20
  pydantic-settings>=2.7.1
21
  scipy>=1.12.0
22
  soundfile>=0.12.1
 
 
 
 
 
 
 
20
  pydantic-settings>=2.7.1
21
  scipy>=1.12.0
22
  soundfile>=0.12.1
23
+ peft>=0.7.0
24
+ trl>=0.7.0
25
+ datasets>=2.14.0
scripts/start.sh CHANGED
@@ -2,16 +2,24 @@
2
  set -e
3
 
4
  echo "Starting Clarke..."
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  echo "USE_MOCK_FHIR=${USE_MOCK_FHIR:-false}"
6
  echo "MEDASR_MODEL_ID=${MEDASR_MODEL_ID:-not set}"
7
 
8
- # Start mock FHIR server in background so EHR agent has patient data to query.
9
- # Runs on port 8080 (internal only); main app connects via localhost.
10
  python -m backend.fhir.mock_api &
11
  FHIR_PID=$!
12
  echo "Mock FHIR server started (PID: ${FHIR_PID})"
13
-
14
- # Brief pause to let FHIR server bind to port before main app starts querying it
15
  sleep 2
16
-
17
  python app.py
 
2
  set -e
3
 
4
  echo "Starting Clarke..."
5
# Optional one-shot LoRA fine-tune, gated on RUN_LORA_TRAINING=true.
# The `|| echo` downgrades a training failure to a warning, so app
# startup continues even though this script runs under `set -e`.
case "${RUN_LORA_TRAINING}" in
    true)
        echo "============================================"
        echo "LoRA training requested. Running..."
        echo "============================================"
        python scripts/train_lora.py || echo "WARNING: Training failed but app will start normally"
        echo "============================================"
        echo "Training phase complete. Starting app..."
        echo "============================================"
        ;;
esac
18
  echo "USE_MOCK_FHIR=${USE_MOCK_FHIR:-false}"
19
  echo "MEDASR_MODEL_ID=${MEDASR_MODEL_ID:-not set}"
20
 
 
 
21
  python -m backend.fhir.mock_api &
22
  FHIR_PID=$!
23
  echo "Mock FHIR server started (PID: ${FHIR_PID})"
 
 
24
  sleep 2
 
25
  python app.py
scripts/train_lora.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""One-shot LoRA training on HF Space A100, then push adapter to Hub.

Fine-tunes MedGemma-27B with QLoRA (4-bit NF4 base + rank-16 adapter) on the
Clarke letter-generation dataset, then pushes the trained adapter and
tokenizer to the Hugging Face Hub. Intended to run once at Space startup
when RUN_LORA_TRAINING=true (see scripts/start.sh).
"""
import os
import gc
import json
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from datasets import Dataset
from jinja2 import Template

# Hyperparameters shared between the trainer config and the reported metrics,
# so the two can never drift apart.
NUM_EPOCHS = 3
LORA_RANK = 16

MODEL_ID = "google/medgemma-27b-text-it"
ADAPTER_REPO = "yashvshetty/clarke-medgemma-27b-lora"

print("=" * 60)
print("CLARKE LoRA TRAINING - Starting")
print("=" * 60)

print(f"GPU: {torch.cuda.get_device_name(0)}")
# BUG FIX: the attribute is `total_memory`; `.total_mem` raises
# AttributeError and aborted the script before training began.
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Use the same Jinja2 template as production inference so the training
# distribution matches what the app actually sends the model.
template_text = Path("backend/prompts/document_generation.j2").read_text()
TEMPLATE = Template(template_text)

train_path = Path("data/training/train.jsonl")
records = [json.loads(line) for line in train_path.read_text().splitlines() if line.strip()]
print(f"Loaded {len(records)} training records")


def format_example(record):
    """Render one JSONL record into a single prompt+completion text string.

    The prompt is the production letter template filled with fixed clinician/GP
    details plus the record's patient demographics, transcript and context
    JSON; the reference letter is appended as the completion to learn.
    """
    context_json = json.dumps(record["context"], ensure_ascii=False, indent=2)
    demo = record["context"]["demographics"]
    prompt = TEMPLATE.render(
        letter_date="18 Feb 2026",
        clinician_name="Dr Sarah Chen",
        clinician_title="Consultant, General Practice",
        gp_name="Dr Andrew Wilson",
        gp_address="Riverside Medical Practice",
        patient_name=demo["name"],
        patient_dob=demo.get("dob", ""),
        patient_nhs=demo.get("nhs_number", ""),
        transcript=record["transcript"],
        context_json=context_json,
    )
    return prompt + "\n" + record["reference_letter"].strip()


texts = [format_example(r) for r in records]
train_dataset = Dataset.from_dict({"text": texts})
print(f"Dataset: {len(train_dataset)} examples")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Some tokenizers ship without a pad token; reuse EOS so batches can pad.
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model in 4-bit...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(f"Model loaded. GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")

peft_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir="/tmp/clarke-lora-checkpoints",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size 8 on one GPU
    learning_rate=2e-4,
    logging_steps=1,
    save_strategy="no",  # adapter is saved and pushed explicitly below
    report_to=[],
    bf16=True,
    optim="adamw_8bit",
    gradient_checkpointing=True,
)

# NOTE(review): passing tokenizer/dataset_text_field/max_seq_length directly
# matches trl 0.7.x; newer trl releases moved these into SFTConfig — confirm
# the installed version before upgrading the `trl>=0.7.0` pin.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=2048,
    args=training_args,
)

print("Starting training...")
train_result = trainer.train()

loss_history = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
# Guard: log_history carries no "loss" entries if no optimizer step ran
# (e.g. an empty dataset); the unguarded [0]/[-1] would raise IndexError.
if loss_history:
    print(f"Initial loss: {loss_history[0]:.4f}")
    print(f"Final loss: {loss_history[-1]:.4f}")

trainer.model.save_pretrained("/tmp/clarke-lora-adapter")
tokenizer.save_pretrained("/tmp/clarke-lora-adapter")
print("Adapter saved locally")

print(f"Pushing adapter to {ADAPTER_REPO}...")
trainer.model.push_to_hub(ADAPTER_REPO, commit_message="Updated LoRA: new section structure Feb 2026")
tokenizer.push_to_hub(ADAPTER_REPO, commit_message="Updated tokenizer Feb 2026")
print(f"Adapter pushed to {ADAPTER_REPO}")

metrics = {
    "initial_loss": float(loss_history[0]) if loss_history else None,
    "final_loss": float(loss_history[-1]) if loss_history else None,
    "epochs": NUM_EPOCHS,
    "lora_rank": LORA_RANK,
    "samples": len(records),
}
print(f"TRAINING COMPLETE. Metrics: {json.dumps(metrics)}")

# Free the 27B weights so the main app can use the GPU afterwards.
del model, trainer
gc.collect()
torch.cuda.empty_cache()
print("Memory freed.")