phanerozoic committed on
Commit
7a7b45c
·
verified ·
1 Parent(s): 96eb2fc

Delete train_circuit_interface.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_circuit_interface.py +0 -306
train_circuit_interface.py DELETED
@@ -1,306 +0,0 @@
1
- """
2
- Train the circuit interface layers on arithmetic examples.
3
- ============================================================
4
-
5
- The threshold circuits are frozen - we only train:
6
- - BitExtractor: embedding -> operand bits
7
- - BitInjector: result bits -> embedding
8
- - Router: when to use circuits vs MLP
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- from torch.utils.data import Dataset, DataLoader
14
- from transformers import AutoModelForCausalLM, AutoTokenizer
15
- from tqdm import tqdm
16
- import argparse
17
- import warnings
18
- warnings.filterwarnings('ignore')
19
-
20
- from circuit_llm import (
21
- augment_smollm2_with_circuits,
22
- evaluate_arithmetic,
23
- CircuitExecutor
24
- )
25
-
26
-
27
- # =============================================================================
28
- # ARITHMETIC DATASET
29
- # =============================================================================
30
-
31
class ArithmeticDataset(Dataset):
    """Dataset of modular-addition problems ("a + b =" -> " result").

    Sums wrap at ``max_val + 1`` so the result always stays inside the
    operand range. With the default ``max_val=255`` this reproduces the
    original 8-bit behaviour, ``(a + b) % 256``.
    """

    def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255):
        """
        Args:
            tokenizer: Any tokenizer exposing ``encode(text, add_special_tokens=...)``.
            n_samples: Number of examples to pre-generate.
            max_val: Largest operand value; sums wrap modulo ``max_val + 1``.
        """
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.max_val = max_val

        # The wrap-around modulus follows the operand range instead of a
        # hard-coded 256, so non-default max_val values stay self-consistent.
        modulus = max_val + 1

        # Pre-generate all examples up front (cheap: just ints and strings;
        # tokenization is deferred to __getitem__).
        self.examples = []
        for _ in range(n_samples):
            a = torch.randint(0, max_val + 1, (1,)).item()
            b = torch.randint(0, max_val + 1, (1,)).item()
            result = (a + b) % modulus

            prompt = f"{a} + {b} ="
            target = f" {result}"

            self.examples.append((prompt, target, a, b, result))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return one tokenized example with prompt tokens masked out.

        Returns:
            dict with ``input_ids`` / ``labels`` (1-D long tensors) plus the
            raw ``a`` / ``b`` / ``result`` integers for debugging.
        """
        prompt, target, a, b, result = self.examples[idx]

        # Tokenize prompt and target separately so we know exactly which
        # positions belong to the prompt and must be masked.
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        target_ids = self.tokenizer.encode(target, add_special_tokens=False)

        input_ids = prompt_ids + target_ids
        # -100 is ignored by the HF cross-entropy loss: only predict target.
        labels = [-100] * len(prompt_ids) + target_ids

        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'a': a,
            'b': b,
            'result': result
        }
71
-
72
-
73
def collate_fn(batch):
    """Collate variable-length examples into padded, stacked tensors.

    Pads ``input_ids`` with 0 and ``labels`` with -100 (ignored by the
    loss), and builds a float attention mask (1 = real token, 0 = pad).
    """
    target_len = max(item['input_ids'].size(0) for item in batch)

    def _pad(seq, fill):
        # Right-pad a 1-D long tensor up to the batch-wide length.
        extra = target_len - seq.size(0)
        return torch.cat([seq, torch.full((extra,), fill, dtype=torch.long)])

    def _mask(seq):
        # Ones over real tokens, zeros over padding (float, as torch.ones).
        real = seq.size(0)
        return torch.cat([torch.ones(real), torch.zeros(target_len - real)])

    return {
        'input_ids': torch.stack([_pad(item['input_ids'], 0) for item in batch]),
        'labels': torch.stack([_pad(item['labels'], -100) for item in batch]),
        'attention_mask': torch.stack([_mask(item['input_ids']) for item in batch]),
    }
99
-
100
-
101
- # =============================================================================
102
- # TRAINING LOOP
103
- # =============================================================================
104
-
105
def train_interface(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    n_epochs: int = 3,
    batch_size: int = 16,
    lr: float = 1e-4,
    n_train_samples: int = 10000,
    device: str = 'cpu',
    eval_every: int = 500
):
    """Train only the circuit interface layers on arithmetic examples.

    The frozen backbone stays untouched; gradients flow only through the
    bit_extractor, bit_injector, router and op_selector parameters.

    Args:
        model: Circuit-augmented causal LM (see augment_smollm2_with_circuits).
        tokenizer: Matching tokenizer, used to build the dataset.
        n_epochs: Number of passes over the generated dataset.
        batch_size: Mini-batch size.
        lr: AdamW learning rate for the interface parameters.
        n_train_samples: Number of arithmetic examples to generate.
        device: Torch device string ('cpu' or 'cuda').
        eval_every: Run a 50-problem evaluation every this many steps.

    Returns:
        The same model instance, with trained interface layers.

    Raises:
        ValueError: If the model has no interface parameters (i.e. it was
            not augmented with circuits before calling this function).
    """
    print("\n" + "=" * 70)
    print(" TRAINING CIRCUIT INTERFACE")
    print("=" * 70)

    # Move the model BEFORE collecting params / building the optimizer, so
    # AdamW state is created over device-resident tensors.
    model.to(device)

    # Freeze everything except interface layers
    interface_params = []
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        if any(x in name for x in ['bit_extractor', 'bit_injector', 'router', 'op_selector']):
            param.requires_grad = True
            interface_params.append(param)
            trainable_count += param.numel()
        else:
            param.requires_grad = False
            frozen_count += param.numel()

    # Fail fast: an un-augmented model has no interface layers, and AdamW
    # would otherwise raise a cryptic "empty parameter list" error.
    if not interface_params:
        raise ValueError(
            "No interface parameters found - augment the model with "
            "circuits before calling train_interface()."
        )

    print(f"\n Frozen parameters: {frozen_count:,}")
    print(f" Trainable parameters: {trainable_count:,}")
    print(f" Training {len(interface_params)} parameter groups")

    # Create dataset
    print(f"\n Creating dataset ({n_train_samples} examples)...")
    dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    # Optimizer over the interface parameters only
    optimizer = torch.optim.AdamW(interface_params, lr=lr)

    model.train()

    global_step = 0
    total_loss = 0  # running loss since the last periodic evaluation

    for epoch in range(n_epochs):
        print(f"\n Epoch {epoch + 1}/{n_epochs}")
        print(" " + "-" * 60)

        epoch_loss = 0
        epoch_steps = 0

        pbar = tqdm(dataloader, desc=" Training", leave=False)

        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Forward (HF causal LMs compute the loss when labels are given)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            epoch_loss += loss.item()
            epoch_steps += 1
            global_step += 1
            total_loss += loss.item()

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            # Periodic evaluation
            if global_step % eval_every == 0:
                model.eval()
                eval_results = evaluate_arithmetic(model, tokenizer, n_problems=50, device=device)
                print(f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
                      f"Accuracy={eval_results['accuracy']*100:.1f}%")
                total_loss = 0
                model.train()

        # Guard against an empty dataloader (n_train_samples < batch_size=0 edge)
        avg_loss = epoch_loss / max(epoch_steps, 1)
        print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")

        # End of epoch evaluation
        model.eval()
        eval_results = evaluate_arithmetic(model, tokenizer, n_problems=100, device=device)
        print(f" Evaluation: {eval_results['accuracy']*100:.1f}% "
              f"({eval_results['correct']}/{eval_results['total']})")

        if eval_results['errors']:
            print(" Sample errors:")
            for a, b, exp, got in eval_results['errors'][:3]:
                print(f" {a} + {b} = {exp}, model said {got}")

        model.train()

    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)

    return model
232
-
233
-
234
- # =============================================================================
235
- # MAIN
236
- # =============================================================================
237
-
238
if __name__ == "__main__":
    # ---- CLI ----------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train Circuit Interface')
    for flag, spec in [
        ('--circuit-path', dict(type=str,
                                default='./neural_computer.safetensors',
                                help='Path to circuit weights')),
        ('--device', dict(type=str, default='cpu',
                          help='Device (cpu or cuda)')),
        ('--epochs', dict(type=int, default=3,
                          help='Number of epochs')),
        ('--batch-size', dict(type=int, default=8,
                              help='Batch size')),
        ('--lr', dict(type=float, default=1e-4,
                      help='Learning rate')),
        ('--n-samples', dict(type=int, default=5000,
                             help='Number of training samples')),
    ]:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    print("=" * 70)
    print(" CIRCUIT-AUGMENTED LLM TRAINING")
    print("=" * 70)

    # ---- Load base model ----------------------------------------------------
    print("\n[1] Loading SmolLM2-360M...")
    model_id = "HuggingFaceTB/SmolLM2-360M"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

    # ---- Pre-training accuracy, for comparison at the end -------------------
    print("\n[2] Baseline evaluation...")
    baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
    print(f" Baseline accuracy: {baseline['accuracy']*100:.1f}%")

    # ---- Splice in the frozen threshold circuits ----------------------------
    print("\n[3] Augmenting with circuits...")
    model = augment_smollm2_with_circuits(
        model,
        args.circuit_path,
        device=args.device
    )

    # ---- Train only the interface layers ------------------------------------
    print("\n[4] Training interface layers...")
    model = train_interface(
        model,
        tokenizer,
        n_epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        n_train_samples=args.n_samples,
        device=args.device
    )

    # ---- Post-training accuracy ---------------------------------------------
    print("\n[5] Final evaluation...")
    final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
    print(f" Final accuracy: {final['accuracy']*100:.1f}%")
    print(f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%")

    # ---- Persist weights plus before/after metrics --------------------------
    save_path = './circuit_augmented_smollm2.pt'
    print(f"\n[6] Saving to {save_path}...")
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'baseline_accuracy': baseline['accuracy'],
        'final_accuracy': final['accuracy']
    }
    torch.save(checkpoint, save_path)

    print("\nDone!")