revana committed on
Commit
80a00d8
Β·
verified Β·
1 Parent(s): 0df61ec

Upload fingpt/trainer.py

Browse files
Files changed (1) hide show
  1. fingpt/trainer.py +239 -0
fingpt/trainer.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fine-tuning training loop for SFT and LoRA.
2
+
3
+ Supports:
4
+ - Full supervised fine-tuning (SFTConfig)
5
+ - LoRA fine-tuning (LoRAConfig)
6
+ - bfloat16 AMP
7
+ - Gradient accumulation
8
+ - Cosine LR schedule with warmup
9
+ - Periodic eval + checkpoint
10
+ - Weights & Biases logging
11
+ """
12
+
13
+ import math
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ import torch
19
+ from torch.amp import autocast
20
+ from torch.utils.data import DataLoader
21
+ from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup
22
+
23
+ from .config import LoRAConfig, SFTConfig
24
+ from .data import load_datasets, make_collate_fn
25
+ from .lora import inject_lora, lora_state_dict
26
+
27
+ try:
28
+ import wandb
29
+ except ImportError:
30
+ wandb = None
31
+
32
+
33
+ # ── Helpers ───────────────────────────────────────────────────────────────────
34
+
35
+ def _param_summary(model: torch.nn.Module) -> str:
36
+ total = sum(p.numel() for p in model.parameters())
37
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
38
+ pct = 100 * trainable / max(1, total)
39
+ return f"total={total:,} trainable={trainable:,} ({pct:.2f}%)"
40
+
41
+
42
+ def _evaluate(model, loader, device, cfg, max_batches: int = 20) -> float:
43
+ model.eval()
44
+ total_loss = total_tok = 0
45
+ with torch.no_grad():
46
+ for i, batch in enumerate(loader):
47
+ if i >= max_batches:
48
+ break
49
+ input_ids = batch["input_ids"].to(device)
50
+ labels = batch["labels"].to(device)
51
+ with autocast(device_type=device.type, dtype=torch.bfloat16, enabled=cfg.bf16):
52
+ out = model(input_ids=input_ids, labels=labels)
53
+ n = (labels != -100).sum().item()
54
+ total_loss += out.loss.item() * n
55
+ total_tok += n
56
+ return total_loss / max(1, total_tok)
57
+
58
+
59
+ def _save(cfg, model, step: int, use_lora: bool, final: bool = False) -> None:
60
+ out = Path(cfg.output_dir)
61
+ out.mkdir(parents=True, exist_ok=True)
62
+
63
+ if use_lora:
64
+ # Save only the adapter weights (~50 MB) β€” base model stays on HuggingFace Hub
65
+ state = lora_state_dict(model)
66
+ tag = "adapter_final.pt" if final else f"adapter_step_{step:07d}.pt"
67
+ meta = {"step": step, "mode": "lora", "model_name": cfg.model_name,
68
+ "lora_r": cfg.lora_r, "lora_alpha": cfg.lora_alpha,
69
+ "lora_target_modules": cfg.lora_target_modules}
70
+ else:
71
+ # Full SFT: save the complete model state dict
72
+ raw = model.module if hasattr(model, "module") else model
73
+ state = raw.state_dict()
74
+ tag = "model_final.pt" if final else f"model_step_{step:07d}.pt"
75
+ meta = {"step": step, "mode": "sft", "model_name": cfg.model_name}
76
+
77
+ path = out / tag
78
+ torch.save({"meta": meta, "state_dict": state}, path)
79
+ kind = "adapter" if use_lora else "model"
80
+ print(f"[fingpt] Saved {kind} β†’ {path} ({path.stat().st_size / 1e6:.0f} MB)")
81
+
82
+
83
+ # ── Main training function ────────────────────────────────────────────────────
84
+
85
+ def train(cfg: SFTConfig) -> None:
86
+ torch.manual_seed(cfg.seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.manual_seed_all(cfg.seed)
89
+
90
+ use_lora = isinstance(cfg, LoRAConfig)
91
+
92
+ # ── Load tokenizer + datasets ──────────────────────────────────────────────
93
+ train_ds, val_ds, tokenizer = load_datasets(cfg)
94
+
95
+ # ── Load model ─────────────────────────────────────────────────────────────
96
+ t0 = time.time()
97
+ print(f"[fingpt] Loading {cfg.model_name} ...")
98
+ cuda_ok = torch.cuda.is_available()
99
+ try:
100
+ import accelerate # noqa: F401
101
+ load_kwargs = {"device_map": "auto"} if cuda_ok else {}
102
+ except ImportError:
103
+ load_kwargs = {}
104
+ model = AutoModelForCausalLM.from_pretrained(
105
+ cfg.model_name,
106
+ torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float32,
107
+ trust_remote_code=True,
108
+ **load_kwargs,
109
+ )
110
+ # Determine the device the model actually lives on
111
+ if load_kwargs:
112
+ # device_map="auto" β€” infer from the first parameter
113
+ device = next(model.parameters()).device
114
+ else:
115
+ device = torch.device("cuda" if cuda_ok else "cpu")
116
+ model = model.to(device)
117
+ print(f"[fingpt] Model on {device} | loaded in {time.time()-t0:.1f}s | {_param_summary(model)}")
118
+
119
+ # ── Inject LoRA adapters (LoRA mode only) ─────────────────────────────────
120
+ if use_lora:
121
+ model = inject_lora(
122
+ model,
123
+ target_modules=cfg.lora_target_modules,
124
+ r=cfg.lora_r,
125
+ alpha=cfg.lora_alpha,
126
+ dropout=cfg.lora_dropout,
127
+ )
128
+ else:
129
+ # Full SFT: all parameters are trainable
130
+ for p in model.parameters():
131
+ p.requires_grad_(True)
132
+
133
+ print(f"[fingpt] Trainable params | {_param_summary(model)}")
134
+
135
+ # ── DataLoaders ────────────────────────────────────────────────────────────
136
+ pad_id = tokenizer.pad_token_id or 0
137
+ collate = make_collate_fn(pad_id)
138
+ nw = getattr(cfg, "dataloader_workers", 4)
139
+ train_loader = DataLoader(
140
+ train_ds, batch_size=cfg.batch_size, shuffle=True,
141
+ collate_fn=collate, num_workers=nw, pin_memory=cuda_ok,
142
+ )
143
+ val_loader = DataLoader(
144
+ val_ds, batch_size=cfg.batch_size, shuffle=False,
145
+ collate_fn=collate, num_workers=min(nw, 2),
146
+ )
147
+
148
+ # ── Optimizer + LR schedule ────────────────────────────────────────────────
149
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
150
+ optimizer = torch.optim.AdamW(trainable_params, lr=cfg.lr, weight_decay=cfg.weight_decay)
151
+
152
+ total_steps = cfg.max_steps or (
153
+ len(train_loader) * cfg.num_epochs // cfg.grad_accum_steps
154
+ )
155
+ total_steps = max(1, total_steps) # guard: avoid 0 with tiny datasets
156
+ warmup_steps = max(1, int(total_steps * cfg.warmup_ratio))
157
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
158
+
159
+ print(
160
+ f"[fingpt] Training | steps={total_steps:,} warmup={warmup_steps} "
161
+ f"lr={cfg.lr:.1e} bs={cfg.batch_size}Γ—{cfg.grad_accum_steps} "
162
+ f"{'LoRA r=' + str(cfg.lora_r) if use_lora else 'Full SFT'}"
163
+ )
164
+
165
+ # ── Weights & Biases ───────────────────────────────────────────────────────
166
+ run = None
167
+ if cfg.use_wandb and wandb is not None:
168
+ run = wandb.init(
169
+ project=cfg.wandb_project,
170
+ name=cfg.wandb_run_name,
171
+ config=cfg.__dict__,
172
+ )
173
+
174
+ # ── Training loop ──────────────────────────────────────────────────────────
175
+ model.train()
176
+ step = 0
177
+ t_step = time.perf_counter()
178
+ accum_loss = 0.0
179
+
180
+ for epoch in range(cfg.num_epochs):
181
+ for batch in train_loader:
182
+ input_ids = batch["input_ids"].to(device)
183
+ labels = batch["labels"].to(device)
184
+
185
+ with autocast(device_type=device.type, dtype=torch.bfloat16, enabled=cfg.bf16):
186
+ out = model(input_ids=input_ids, labels=labels)
187
+ loss = out.loss / cfg.grad_accum_steps
188
+
189
+ loss.backward()
190
+ accum_loss += loss.item()
191
+
192
+ # ── Optimizer step every grad_accum_steps micro-batches ────────────
193
+ if (step + 1) % cfg.grad_accum_steps == 0:
194
+ torch.nn.utils.clip_grad_norm_(trainable_params, cfg.max_grad_norm)
195
+ optimizer.step()
196
+ scheduler.step()
197
+ optimizer.zero_grad(set_to_none=True)
198
+
199
+ # ── Logging ────────────────────────────────────────────────────────
200
+ if step % cfg.log_steps == 0:
201
+ elapsed = time.perf_counter() - t_step
202
+ t_step = time.perf_counter()
203
+ real_loss = accum_loss * cfg.grad_accum_steps
204
+ accum_loss = 0.0
205
+ lr = optimizer.param_groups[0]["lr"]
206
+ ppl = math.exp(min(20, real_loss))
207
+ print(
208
+ f"step {step:6d} | loss {real_loss:.4f} | ppl {ppl:.1f} "
209
+ f"| lr {lr:.2e} | {elapsed:.1f}s"
210
+ )
211
+ if run:
212
+ run.log({"train/loss": real_loss, "train/ppl": ppl,
213
+ "train/lr": lr, "step": step})
214
+
215
+ # ── Evaluation ─────────────────────────────────────────────────────
216
+ if cfg.eval_steps and step % cfg.eval_steps == 0 and step > 0:
217
+ val_loss = _evaluate(model, val_loader, device, cfg)
218
+ val_ppl = math.exp(min(20, val_loss))
219
+ print(f" eval | loss {val_loss:.4f} | ppl {val_ppl:.1f}")
220
+ if run:
221
+ run.log({"eval/loss": val_loss, "eval/ppl": val_ppl, "step": step})
222
+ model.train()
223
+
224
+ # ── Checkpoint ──────��──────────────────────────────────────────────
225
+ if cfg.save_steps and step % cfg.save_steps == 0 and step > 0:
226
+ _save(cfg, model, step, use_lora)
227
+
228
+ step += 1
229
+ if cfg.max_steps and step >= cfg.max_steps:
230
+ break
231
+ if cfg.max_steps and step >= cfg.max_steps:
232
+ break
233
+
234
+ # ── Final save ─────────────────────────────────────────────────────────────
235
+ _save(cfg, model, step, use_lora, final=True)
236
+ print(f"[fingpt] Training complete | total steps={step:,}")
237
+
238
+ if run:
239
+ run.finish()