vigneshwar234 committed on
Commit
d89f4f8
·
verified ·
1 Parent(s): 38f4c5d

Add source: tmt/training/trainer.py

Browse files
Files changed (1) hide show
  1. tmt/training/trainer.py +189 -0
tmt/training/trainer.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ trainer.py — TMT training loop with wandb logging.
3
+
4
+ Trains on wikitext-2 (or tinystories) using AdamW + cosine warmup schedule.
5
+ Logs: train loss (total / CE / gate), learning rate, val perplexity, and the exit rate taken from the last exit mask.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from dataclasses import dataclass, field
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch import Tensor
16
+ from torch.optim import AdamW
17
+ from torch.utils.data import DataLoader
18
+
19
+ from ..model.config import TMTConfig
20
+ from ..model.model import TMTModel
21
+ from .loss import compute_loss
22
+ from .scheduler import cosine_warmup_scheduler
23
+
24
+
25
@dataclass
class TrainConfig:
    """Run settings for a TMT training job.

    All fields have defaults, so ``TrainConfig()`` yields a usable
    configuration. Field order is significant for positional construction.
    """

    # ---- data ----
    dataset: str = "wikitext-2"  # or "tinystories"
    batch_size: int = 16
    seq_len: int = 256  # shorter than max for memory efficiency

    # ---- optimiser ----
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0  # max gradient norm passed to clip_grad_norm_
    warmup_steps: int = 500
    total_steps: int = 10_000

    # ---- checkpointing / evaluation cadence (in optimiser steps) ----
    save_dir: str = "checkpoints"
    save_every: int = 500
    eval_every: int = 100

    # ---- logging ----
    use_wandb: bool = False  # set True when wandb is configured
    project: str = "temporal-mesh-transformer"

    # ---- device ----
    # NOTE: this default is computed once, when the class body is executed,
    # not on every instantiation.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- loss ----
    exit_gate_coeff: float = 0.1  # forwarded to compute_loss during training
53
+
54
+
55
class TMTTrainer:
    """Drives optimisation, evaluation and checkpointing of a ``TMTModel``.

    The trainer builds the model, an AdamW optimiser and the cosine-warmup
    scheduler from the two config objects, then runs a step-based loop
    (``train``) that periodically logs, evaluates and saves checkpoints.
    """

    # Console / wandb logging cadence, in optimiser steps.
    LOG_EVERY = 10

    def __init__(
        self,
        model_cfg: TMTConfig,
        train_cfg: TrainConfig,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
    ) -> None:
        """Build model, optimiser and scheduler, and set up logging.

        Args:
            model_cfg: Configuration passed to ``TMTModel``.
            train_cfg: Optimiser, schedule, logging and checkpoint settings.
            train_loader: Loader whose batches are mappings with an
                ``"input_ids"`` tensor.
            val_loader: Optional loader used by ``evaluate``; when ``None``
                periodic evaluation is skipped.

        Side effects: creates ``train_cfg.save_dir`` if missing, prints the
        model, and initialises a wandb run when ``train_cfg.use_wandb`` is set
        and wandb can be imported.
        """
        self.cfg = train_cfg
        self.device = torch.device(train_cfg.device)

        self.model = TMTModel(model_cfg).to(self.device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
        )
        self.scheduler = cosine_warmup_scheduler(
            self.optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.total_steps,
        )
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.step = 0

        # ``self.wandb`` holds the wandb module when logging is active,
        # otherwise None.
        self.wandb = None
        if train_cfg.use_wandb:
            # Only the import is guarded, so the "not installed" message is
            # never printed for an unrelated failure inside wandb.init.
            try:
                import wandb
            except ImportError:
                print("wandb not installed — skipping wandb logging")
            else:
                wandb.init(project=train_cfg.project, config={
                    "model": vars(model_cfg),
                    "train": vars(train_cfg),
                })
                self.wandb = wandb

        os.makedirs(train_cfg.save_dir, exist_ok=True)
        print(self.model)

    def _cycle_batches(self):
        """Yield training batches forever, restarting the loader when exhausted.

        Raises:
            RuntimeError: if a full pass over ``train_loader`` yields nothing
                (otherwise the loop would spin forever without training).
        """
        while True:
            produced = False
            for batch in self.train_loader:
                produced = True
                yield batch
            if not produced:
                raise RuntimeError("train_loader yielded no batches")

    def train(self) -> None:
        """Run optimisation until ``self.step`` reaches ``cfg.total_steps``.

        Each step performs next-token prediction on ``batch["input_ids"]``,
        back-propagates the total loss, clips gradients to ``cfg.grad_clip``
        and advances optimiser and scheduler. Every ``LOG_EVERY`` steps the
        metrics are logged; every ``cfg.eval_every`` steps validation
        perplexity is computed (if a validation loader was given); every
        ``cfg.save_every`` steps a checkpoint is written.

        Raises:
            RuntimeError: if ``train_loader`` produces no batches.
        """
        self.model.train()
        batches = self._cycle_batches()

        while self.step < self.cfg.total_steps:
            batch = next(batches)

            input_ids = batch["input_ids"].to(self.device)
            # Next-token prediction: targets are the inputs shifted by one.
            x = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            # Forward
            output = self.model(x)
            total_loss, ce_loss, gate_loss = compute_loss(
                output.logits,
                targets,
                output.confidences,
                self.cfg.exit_gate_coeff,
            )

            # Backward
            self.optimizer.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()
            self.scheduler.step()

            self.step += 1

            if self.step % self.LOG_EVERY == 0:
                self._log_train_step(output, total_loss, ce_loss, gate_loss)

            # ``is not None`` rather than truthiness: bool(DataLoader) calls
            # __len__, which raises TypeError for un-sized iterable datasets.
            if self.val_loader is not None and self.step % self.cfg.eval_every == 0:
                val_ppl = self.evaluate()
                print(f" val_perplexity={val_ppl:.2f}")
                if self.wandb:
                    self.wandb.log({"val/perplexity": val_ppl, "step": self.step})
                # evaluate() leaves the model in eval mode; switch back.
                self.model.train()

            if self.step % self.cfg.save_every == 0:
                self._save(os.path.join(self.cfg.save_dir, f"ckpt_step{self.step}.pt"))

    def _log_train_step(
        self,
        output,
        total_loss: Tensor,
        ce_loss: Tensor,
        gate_loss: Tensor,
    ) -> None:
        """Print the current step's metrics and mirror them to wandb if active."""
        # Read each scalar once; it is used for both console and wandb.
        total = total_loss.item()
        ce = ce_loss.item()
        gate = gate_loss.item()
        lr = self.optimizer.param_groups[0]["lr"]
        avg_exit_rate = self._compute_exit_rate(output)

        print(
            f"step={self.step:5d} | loss={total:.4f} | "
            f"ce={ce:.4f} | gate={gate:.4f} | "
            f"exit={avg_exit_rate:.3f} | lr={lr:.2e}"
        )
        if self.wandb:
            self.wandb.log({
                "loss/total": total,
                "loss/ce": ce,
                "loss/gate": gate,
                "train/exit_rate": avg_exit_rate,
                "train/lr": lr,
                "step": self.step,
            })

    @torch.no_grad()
    def evaluate(self) -> float:
        """Return validation perplexity over ``self.val_loader``.

        Perplexity is ``exp`` of the cross-entropy term averaged over batches.
        The gate term is deliberately excluded: it is a regulariser, and
        including it would inflate the reported perplexity.

        Returns:
            The perplexity, or NaN when the loader yields no batches (rather
            than a misleading ``exp(0) == 1.0``).

        Note:
            The model is left in eval mode; callers that continue training
            must call ``self.model.train()`` afterwards.
        """
        import math

        self.model.eval()
        ce_sum, n_batches = 0.0, 0
        for batch in self.val_loader:
            input_ids = batch["input_ids"].to(self.device)
            x, targets = input_ids[:, :-1], input_ids[:, 1:]
            out = self.model(x)
            # compute_loss returns (total, ce, gate) — the same unpacking the
            # training loop relies on.
            _, ce_loss, _ = compute_loss(out.logits, targets, out.confidences)
            ce_sum += ce_loss.item()
            n_batches += 1

        if n_batches == 0:
            return float("nan")
        return math.exp(ce_sum / n_batches)

    @staticmethod
    def _compute_exit_rate(output) -> float:
        """Return the mean of the last entry of ``output.exit_masks``.

        Returns 0.0 when ``exit_masks`` is empty or otherwise falsy.
        """
        if not output.exit_masks:
            return 0.0
        final_mask = output.exit_masks[-1]
        return final_mask.float().mean().item()

    def _save(self, path: str) -> None:
        """Write a checkpoint to ``path``.

        The checkpoint holds ``step``, ``model_state`` and ``optimizer_state``,
        plus ``scheduler_state`` when the scheduler exposes ``state_dict`` so
        that a resumed run can continue the learning-rate schedule.
        """
        checkpoint = {
            "step": self.step,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
        }
        scheduler_state_dict = getattr(self.scheduler, "state_dict", None)
        if callable(scheduler_state_dict):
            checkpoint["scheduler_state"] = scheduler_state_dict()

        torch.save(checkpoint, path)
        print(f" saved checkpoint → {path}")