kylesayrs committed on
Commit
ffd1388
·
verified ·
1 Parent(s): 28a0195

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +374 -0
README.md ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Returns the number of Rs in a variant of the word "strawberry".
2
+
3
+ Prompt: strrawberrry
4
+ Response: 7
5
+
6
+ #!/usr/bin/env python3
7
+ """
8
+ Fine-tune Llama-3.2-1B-Instruct to count Rs in 'strawberry' variants.
9
+ A fun exercise in overfitting to a simple task.
10
+ """
11
+
12
+ import random
13
+ import torch
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
16
+ from tqdm import tqdm
17
+
18
+
19
def generate_strawberry_variant(target_r_count: int) -> str:
    """
    Generate a 'strawberry' variant with exactly ``target_r_count`` Rs.

    The word is modelled as ``st[r*]awbe[r*][r*]y``: three R slots (the one in
    "str" and the two adjacent ones in "rry") whose sizes are varied.

    Args:
        target_r_count: Number of 'r' characters the result must contain.
            Any value below 1 yields the R-less word "stawbey".

    Returns:
        A lowercase word containing ``max(target_r_count, 0)`` 'r' characters.
        For counts of 3 or more every slot holds at least one R, so a count
        of exactly 3 always produces "strawberry".
    """
    if target_r_count < 1:
        # Edge case: no Rs at all.
        return "stawbey"

    if target_r_count == 1:
        # One option per slot. Slots 2 and 3 are adjacent, so they spell the
        # same word; the duplicate keeps the original 1/3 vs 2/3 weighting.
        return random.choice(("strawbey", "stawbery", "stawbery"))

    if target_r_count == 2:
        return random.choice(("strawbery", "stawberry", "strrawbey"))

    # 3+ Rs: seed every slot with one R, then scatter the surplus uniformly.
    slots = [1, 1, 1]
    for _ in range(target_r_count - 3):
        slots[random.randint(0, 2)] += 1

    return "st" + "r" * slots[0] + "awbe" + "r" * (slots[1] + slots[2]) + "y"
76
+
77
+
78
def create_dataset_samples(num_samples: int = 10000, max_r_count: int = 100) -> list[tuple[str, int]]:
    """
    Generate (word, r_count) samples with R counts biased towards small values.

    Each sample draws its target count from one of three ranges:
    ``1..10`` (30% of samples), ``1..30`` (a second, independent draw below
    0.6, i.e. 42% overall) or ``1..max_r_count`` (the remaining 28%). The two
    smaller ranges are clamped so that no sample exceeds ``max_r_count``.

    Args:
        num_samples: Number of samples to generate.
        max_r_count: Largest R count any sample may have; must be >= 1.

    Returns:
        A list of ``(word, count)`` tuples, where ``count`` is re-derived from
        the generated word rather than taken from the requested target.

    Raises:
        ValueError: If ``max_r_count`` is smaller than 1.
    """
    if max_r_count < 1:
        raise ValueError(f"max_r_count must be >= 1, got {max_r_count}")

    # Clamp the tier limits so a small max_r_count is actually respected.
    low_cap = min(10, max_r_count)
    mid_cap = min(30, max_r_count)

    samples = []
    for _ in range(num_samples):
        # Bias towards lower counts but include the full range.
        if random.random() < 0.3:
            r_count = random.randint(1, low_cap)
        elif random.random() < 0.6:
            r_count = random.randint(1, mid_cap)
        else:
            r_count = random.randint(1, max_r_count)

        word = generate_strawberry_variant(r_count)
        # Label with the count measured on the word itself.
        samples.append((word, word.count("r")))

    return samples
97
+
98
+
99
class StrawberryDataset(Dataset):
    """
    Dataset for the R-counting task.

    Each item is the text ``"Input: {word}\\nOutput: {count}"`` (followed by the
    tokenizer's EOS token when it has one), tokenized and padded to
    ``max_length``. Loss is restricted to the answer tokens: prompt and padding
    positions carry the label -100.
    """

    def __init__(self, samples: list[tuple[str, int]], tokenizer, max_length: int = 128):
        """
        Args:
            samples: ``(word, count)`` pairs to serve.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors when ``return_tensors="pt"``.
            max_length: Length every example is padded or truncated to.
        """
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for sample ``idx``."""
        word, count = self.samples[idx]

        # The model should learn to complete the text after "Output:".
        prompt = f"Input: {word}\nOutput:"
        # Append EOS explicitly: padding is masked out of the loss below, so
        # this is the only place the model is taught where the answer ends.
        eos = getattr(self.tokenizer, "eos_token", None) or ""
        full_text = f"Input: {word}\nOutput: {count}{eos}"

        full_encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Tokenized without padding, only to measure the prompt's token length.
        prompt_encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        input_ids = full_encoding["input_ids"].squeeze(0)
        attention_mask = full_encoding["attention_mask"].squeeze(0)
        prompt_length = prompt_encoding["input_ids"].shape[1]

        labels = input_ids.clone()
        # Padding must not contribute to the loss; otherwise the padded
        # positions outnumber the few answer tokens in every example.
        labels[attention_mask == 0] = -100
        # Mask the prompt starting at the first real token, so the masking is
        # also correct when the tokenizer pads on the left.
        first_real = int(attention_mask.argmax())
        labels[first_real:first_real + prompt_length] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
147
+
148
+
149
def evaluate_model(model, tokenizer, device, num_samples: int = 50):
    """
    Evaluate the model on freshly generated random samples.

    Puts the model into eval mode (it is not switched back afterwards),
    greedily generates up to 10 new tokens for each ``"Input: {word}\\nOutput:"``
    prompt and compares the first whitespace-separated token after the last
    "Output:" with the true R count.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompt is moved to.
        num_samples: Number of random samples to evaluate.

    Returns:
        ``(accuracy, results)`` where ``accuracy`` is the fraction of exact
        matches (0.0 when nothing was evaluated) and ``results`` is a list of
        ``(word, expected_count, predicted, is_correct)`` tuples. A prediction
        that cannot be parsed as an integer is recorded as -1.
    """
    model.eval()
    test_samples = create_dataset_samples(num_samples, max_r_count=100)

    results = []
    correct = 0

    with torch.no_grad():
        for word, expected_count in test_samples:
            prompt = f"Input: {word}\nOutput:"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract the number after "Output:".
            try:
                predicted = int(response.split("Output:")[-1].strip().split()[0])
            except (ValueError, IndexError):
                predicted = -1

            is_correct = predicted == expected_count
            correct += is_correct
            results.append((word, expected_count, predicted, is_correct))

    # Divide by what was actually evaluated; avoids ZeroDivisionError when
    # no samples were requested.
    accuracy = correct / len(test_samples) if test_samples else 0.0
    return accuracy, results
186
+
187
+
188
def main():
    """
    Fine-tune Llama-3.2-1B-Instruct on the R-counting task end to end.

    Loads the tokenizer and model, generates training data, evaluates before
    training, trains with gradient accumulation and a linear warmup schedule,
    evaluates after every epoch and at the end, prints predictions for a few
    hand-picked words, saves model and tokenizer to "strawberry-llama" and
    prints a before/after accuracy summary.
    """
    # Configuration
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    num_train_samples = 15000
    num_epochs = 3
    batch_size = 8
    learning_rate = 2e-5
    max_r_count = 100
    gradient_accumulation_steps = 4

    banner = "=" * 60

    print(banner)
    print("Fine-tuning Llama-3.2-1B-Instruct to count Rs in strawberry")
    print(banner)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer; fall back to EOS as the pad token when none is defined.
    print(f"\nLoading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    print(f"Loading model from {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if not torch.cuda.is_available():
        model = model.to(device)

    # Generate training data
    print(f"\nGenerating {num_train_samples} training samples...")
    train_samples = create_dataset_samples(num_train_samples, max_r_count)

    # Show some examples (slicing is safe even with fewer than five samples).
    print("\nSample training data:")
    for word, count in train_samples[:5]:
        print(f" '{word}' -> {count}")

    # Create dataset and dataloader
    train_dataset = StrawberryDataset(train_samples, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Evaluate before training
    print("\n" + banner)
    print("Evaluating BEFORE fine-tuning...")
    print(banner)
    accuracy_before, results_before = evaluate_model(model, tokenizer, device, num_samples=30)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print("\nSample predictions (before):")
    for word, expected, predicted, correct in results_before[:10]:
        status = "✓" if correct else "✗"
        print(f" {status} '{word[:30]}...' expected={expected}, got={predicted}")

    # Setup optimizer and scheduler. One optimizer step is taken per full
    # accumulation group plus one for a trailing partial group, so round up
    # per epoch to keep the schedule aligned with the steps actually taken.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    batches_per_epoch = len(train_loader)
    steps_per_epoch = (
        batches_per_epoch + gradient_accumulation_steps - 1
    ) // gradient_accumulation_steps
    total_steps = steps_per_epoch * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )

    # Training loop
    print("\n" + banner)
    print("Starting training...")
    print(banner)

    model.train()
    optimizer.zero_grad()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # Scale so the accumulated gradient averages over the group.
            loss = outputs.loss / gradient_accumulation_steps
            loss.backward()

            epoch_loss += outputs.loss.item()
            num_batches += 1

            # Step on every full group, and also on the epoch's last batch so
            # a trailing partial group is applied instead of leaking into the
            # next epoch (or being dropped after the final one).
            group_complete = (batch_idx + 1) % gradient_accumulation_steps == 0
            last_batch = batch_idx + 1 == batches_per_epoch
            if group_complete or last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            progress_bar.set_postfix({"loss": f"{epoch_loss / num_batches:.4f}"})

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

        # Mid-training evaluation
        print(f"\nMid-training evaluation after epoch {epoch + 1}:")
        accuracy_mid, _ = evaluate_model(model, tokenizer, device, num_samples=30)
        print(f"Accuracy: {accuracy_mid:.1%}")
        # evaluate_model leaves the model in eval mode; switch back.
        model.train()

    # Final evaluation
    print("\n" + banner)
    print("Evaluating AFTER fine-tuning...")
    print(banner)
    accuracy_after, results_after = evaluate_model(model, tokenizer, device, num_samples=50)
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print("\nSample predictions (after):")
    for word, expected, predicted, correct in results_after[:15]:
        status = "✓" if correct else "✗"
        print(f" {status} '{word[:40]}' expected={expected}, got={predicted}")

    # Test on the classic examples
    print("\n" + banner)
    print("Testing on classic examples...")
    print(banner)

    classic_tests = [
        ("strawberry", 3),
        ("strrawberrrrry", 7),
        ("strrrrrawberrrrrrrrrry", 15),
        ("stawbey", 0),
    ]

    model.eval()
    with torch.no_grad():
        for word, expected in classic_tests:
            prompt = f"Input: {word}\nOutput:"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Shown verbatim (not parsed to int) so odd outputs stay visible.
            try:
                predicted = response.split("Output:")[-1].strip().split()[0]
            except IndexError:
                predicted = "N/A"

            print(f" Input: '{word}'")
            print(f" Expected: {expected}, Predicted: {predicted}")
            print()

    # Save the model
    output_dir = "strawberry-llama"
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Done!")

    print("\n" + banner)
    print("Summary")
    print(banner)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print(f"Improvement: {(accuracy_after - accuracy_before):.1%}")
371
+
372
+
373
+ if __name__ == "__main__":
374
+ main()