phanerozoic committed on
Commit 58d0763 · verified · 1 Parent(s): bac516e

Upload train_circuit_interface.py with huggingface_hub

Files changed (1)
  1. train_circuit_interface.py +306 -0
train_circuit_interface.py ADDED
@@ -0,0 +1,306 @@
"""
Train the circuit interface layers on arithmetic examples.
============================================================

The threshold circuits are frozen - we only train:
- BitExtractor: embedding -> operand bits
- BitInjector: result bits -> embedding
- Router: when to use circuits vs MLP
"""
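
# Example invocation (illustrative; every flag below is defined in the
# argparse section at the bottom of this file):
#   python train_circuit_interface.py --circuit-path ./neural_computer.safetensors \
#       --device cuda --epochs 3 --batch-size 16 --lr 1e-4 --n-samples 5000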

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import argparse
import warnings
warnings.filterwarnings('ignore')  # keep transformers/torch warnings out of the training logs

from circuit_llm import (
    augment_smollm2_with_circuits,
    evaluate_arithmetic,
    CircuitExecutor
)


# =============================================================================
# ARITHMETIC DATASET
# =============================================================================

class ArithmeticDataset(Dataset):
    """Dataset of 8-bit addition problems."""

    def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.max_val = max_val

        # Pre-generate all examples
        self.examples = []
        for _ in range(n_samples):
            a = torch.randint(0, max_val + 1, (1,)).item()
            b = torch.randint(0, max_val + 1, (1,)).item()
            result = (a + b) % 256  # wrap to 8 bits, matching the 8-bit operands

            prompt = f"{a} + {b} ="
            target = f" {result}"

            self.examples.append((prompt, target, a, b, result))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        prompt, target, a, b, result = self.examples[idx]

        # Tokenize prompt and target separately so we know where the target starts
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        target_ids = self.tokenizer.encode(target, add_special_tokens=False)

        input_ids = prompt_ids + target_ids
        labels = [-100] * len(prompt_ids) + target_ids  # Only predict the target tokens

        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'a': a,
            'b': b,
            'result': result
        }
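
# Note: label values of -100 are ignored by the cross-entropy loss that
# transformers computes internally, so gradients flow only through the
# answer tokens, never the prompt.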


def collate_fn(batch):
    """Collate a batch with right-padding."""
    max_len = max(len(item['input_ids']) for item in batch)

    input_ids = []
    labels = []
    attention_mask = []

    for item in batch:
        pad_len = max_len - len(item['input_ids'])

        # Pad input_ids with token id 0; these positions are masked out by the
        # attention mask and carry -100 labels, so they never affect the loss.
        input_ids.append(
            torch.cat([item['input_ids'], torch.zeros(pad_len, dtype=torch.long)])
        )
        labels.append(
            torch.cat([item['labels'], torch.full((pad_len,), -100, dtype=torch.long)])
        )
        attention_mask.append(
            torch.cat([torch.ones(len(item['input_ids'])), torch.zeros(pad_len)])
        )

    return {
        'input_ids': torch.stack(input_ids),
        'labels': torch.stack(labels),
        'attention_mask': torch.stack(attention_mask),
    }
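
# Usage sketch: DataLoader(dataset, batch_size=..., collate_fn=collate_fn)
# yields (batch_size, max_len) tensors that feed directly into the model's
# forward pass, as done in train_interface below.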


# =============================================================================
# TRAINING LOOP
# =============================================================================

def train_interface(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    n_epochs: int = 3,
    batch_size: int = 16,
    lr: float = 1e-4,
    n_train_samples: int = 10000,
    device: str = 'cpu',
    eval_every: int = 500
):
    """
    Train the circuit interface layers.

    Only trains:
    - bit_extractor (embedding -> bits)
    - bit_injector (bits -> embedding)
    - router (circuit vs MLP weighting)
    - op_selector (which operation)
    """
    print("\n" + "=" * 70)
    print(" TRAINING CIRCUIT INTERFACE")
    print("=" * 70)

    # Freeze everything except the interface layers. Matching is by parameter
    # name, so this assumes augment_smollm2_with_circuits registers modules
    # whose names contain the substrings below.
    interface_params = []
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        if any(x in name for x in ['bit_extractor', 'bit_injector', 'router', 'op_selector']):
            param.requires_grad = True
            interface_params.append(param)
            trainable_count += param.numel()
        else:
            param.requires_grad = False
            frozen_count += param.numel()

    print(f"\n Frozen parameters: {frozen_count:,}")
    print(f" Trainable parameters: {trainable_count:,}")
    print(f" Training {len(interface_params)} parameter tensors")

    # Create dataset
    print(f"\n Creating dataset ({n_train_samples} examples)...")
    dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    # Optimizer over the interface parameters only
    optimizer = torch.optim.AdamW(interface_params, lr=lr)

    # Training
    model.to(device)
    model.train()

    global_step = 0
    total_loss = 0

    for epoch in range(n_epochs):
        print(f"\n Epoch {epoch + 1}/{n_epochs}")
        print(" " + "-" * 60)

        epoch_loss = 0
        epoch_steps = 0

        pbar = tqdm(dataloader, desc=" Training", leave=False)

        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Forward
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            epoch_loss += loss.item()
            epoch_steps += 1
            global_step += 1
            total_loss += loss.item()

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            # Periodic evaluation
            if global_step % eval_every == 0:
                model.eval()
                eval_results = evaluate_arithmetic(model, tokenizer, n_problems=50, device=device)
                print(f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
                      f"Accuracy={eval_results['accuracy']*100:.1f}%")
                total_loss = 0
                model.train()

        avg_loss = epoch_loss / epoch_steps
        print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")

        # End-of-epoch evaluation
        model.eval()
        eval_results = evaluate_arithmetic(model, tokenizer, n_problems=100, device=device)
        print(f" Evaluation: {eval_results['accuracy']*100:.1f}% "
              f"({eval_results['correct']}/{eval_results['total']})")

        if eval_results['errors']:
            print(" Sample errors:")
            for a, b, exp, got in eval_results['errors'][:3]:
                print(f" {a} + {b} = {exp}, model said {got}")

        model.train()

    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)

    return model


# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train Circuit Interface')
    parser.add_argument('--circuit-path', type=str,
                        default='./neural_computer.safetensors',
                        help='Path to circuit weights')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device (cpu or cuda)')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--n-samples', type=int, default=5000,
                        help='Number of training samples')
    args = parser.parse_args()

    print("=" * 70)
    print(" CIRCUIT-AUGMENTED LLM TRAINING")
    print("=" * 70)

    # Load model
    print("\n[1] Loading SmolLM2-360M...")
    model_id = "HuggingFaceTB/SmolLM2-360M"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    model.to(args.device)  # ensure the baseline pass below runs on the requested device

    # Baseline
    print("\n[2] Baseline evaluation...")
    baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
    print(f" Baseline accuracy: {baseline['accuracy']*100:.1f}%")

    # Augment
    print("\n[3] Augmenting with circuits...")
    model = augment_smollm2_with_circuits(
        model,
        args.circuit_path,
        device=args.device
    )

    # Train
    print("\n[4] Training interface layers...")
    model = train_interface(
        model,
        tokenizer,
        n_epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        n_train_samples=args.n_samples,
        device=args.device
    )

    # Final evaluation
    print("\n[5] Final evaluation...")
    final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
    print(f" Final accuracy: {final['accuracy']*100:.1f}%")
    print(f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%")

    # Save the full state dict (base model + circuits + interface) plus metrics
    save_path = './circuit_augmented_smollm2.pt'
    print(f"\n[6] Saving to {save_path}...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'baseline_accuracy': baseline['accuracy'],
        'final_accuracy': final['accuracy']
    }, save_path)

    print("\nDone!")
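
# Reloading the checkpoint later (sketch; assumes the base model is
# re-augmented with the same circuit weights before loading, and that
# base_model / circuit_path are supplied by the caller):
#   model = augment_smollm2_with_circuits(base_model, circuit_path)
#   ckpt = torch.load('./circuit_augmented_smollm2.pt')
#   model.load_state_dict(ckpt['model_state_dict'])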