# Fine-tuned model that returns the number of Rs in a word like "strawberry".
#
# Example:
#   Prompt: strrawberrry
#   Response: 5

#!/usr/bin/env python3
"""
Fine-tune Llama-3.2-1B-Instruct to count Rs in 'strawberry' variants.
A fun exercise in overfitting to a simple task.
"""

import random

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm


def generate_strawberry_variant(target_r_count: int) -> str:
    """
    Generate a 'strawberry' variant containing exactly target_r_count Rs.

    The base word s-t-r-a-w-b-e-r-r-y has three R-bearing segments
    (st[r]awbe[r][r]y); we distribute the requested number of Rs across
    those three slots and rebuild the word.

    Args:
        target_r_count: Desired number of 'r' characters. Values < 1 yield
            the R-free word "stawbey".

    Returns:
        A word of the form st[r*a]awbe[r*b][r*c]y with a + b + c equal to
        target_r_count.
    """
    if target_r_count < 1:
        # Edge case: no Rs at all.
        return "stawbey"

    # Distribute the Rs over the three historical R positions.
    # The original code special-cased counts 1 and 2 with hand-written
    # words (and had two identical branches for count == 1); the random
    # slot distribution below covers those cases with the same guarantee.
    slots = [0, 0, 0]
    if target_r_count >= 3:
        # Give each slot at least one R so the word still resembles
        # "strawberry" for larger counts.
        slots = [1, 1, 1]
        remaining = target_r_count - 3
    else:
        remaining = target_r_count

    # Scatter the remaining Rs uniformly at random across the slots.
    for _ in range(remaining):
        slots[random.randint(0, 2)] += 1

    # Build the word: st[r*slots[0]]awbe[r*slots[1]][r*slots[2]]y
    return "st" + "r" * slots[0] + "awbe" + "r" * slots[1] + "r" * slots[2] + "y"


def create_dataset_samples(num_samples: int = 10000, max_r_count: int = 100) -> list[tuple[str, int]]:
    """
    Generate (word, r_count) training pairs with varied R counts.

    Counts are biased towards the low end but span the full [1, max_r_count]
    range: 30% of samples draw from 1-10, 30% from 1-30, and the remaining
    40% from 1-max_r_count.

    Args:
        num_samples: Number of pairs to generate.
        max_r_count: Upper bound on the R count for the widest tier.

    Returns:
        List of (word, actual_r_count) tuples.
    """
    samples = []

    for _ in range(num_samples):
        # Roll once and reuse it: the original code called random.random()
        # again in the elif, which skewed the tier split to ~30%/42%/28%
        # instead of the intended 30%/30%/40%.
        roll = random.random()
        if roll < 0.3:
            r_count = random.randint(1, 10)
        elif roll < 0.6:
            r_count = random.randint(1, 30)
        else:
            r_count = random.randint(1, max_r_count)

        word = generate_strawberry_variant(r_count)
        # Label with the verified count rather than trusting the generator.
        actual_count = word.lower().count('r')
        samples.append((word, actual_count))

    return samples


class StrawberryDataset(Dataset):
    """Dataset for the R-counting task.

    Each item is a causal-LM training example built from the template
    "Input: {word}\nOutput: {count}". Loss is computed only on the answer
    tokens: prompt positions and padding positions are masked with -100.
    """

    def __init__(self, samples: list[tuple[str, int]], tokenizer, max_length: int = 128):
        # samples: (word, r_count) pairs.
        # tokenizer: HF-style callable returning "input_ids"/"attention_mask"
        # tensors; assumed to have pad/eos tokens configured by the caller.
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        word, count = self.samples[idx]

        # We want the model to learn to complete the count after "Output:".
        prompt = f"Input: {word}\nOutput:"
        full_text = f"Input: {word}\nOutput: {count}"

        # Tokenize the full example (padded to a fixed length) and the
        # prompt alone (to find how many leading tokens to mask).
        full_encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        prompt_encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )

        input_ids = full_encoding["input_ids"].squeeze(0)
        attention_mask = full_encoding["attention_mask"].squeeze(0)

        # Labels: -100 on prompt tokens so no loss is taken on them.
        # NOTE(review): assumes the prompt's tokenization is a strict prefix
        # of the full text's — true for this template, but worth confirming
        # if the format changes.
        labels = input_ids.clone()
        prompt_length = prompt_encoding["input_ids"].shape[1]
        labels[:prompt_length] = -100

        # BUG FIX: also mask padding positions. Previously every padded slot
        # kept its pad-token id as a label, so the model was trained to emit
        # pad/eos for most of the sequence.
        pad_mask = attention_mask == 0
        labels[pad_mask] = -100
        if pad_mask.any():
            # Keep the first pad slot as a target (pad == eos here) so the
            # model still learns to stop right after the answer.
            first_pad = int(pad_mask.nonzero()[0])
            labels[first_pad] = input_ids[first_pad]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }


def evaluate_model(model, tokenizer, device, num_samples: int = 50):
    """Measure exact-match accuracy on freshly generated samples.

    Puts the model in eval mode, greedily decodes each "Input/Output"
    prompt, parses the first token after "Output:" as the prediction,
    and compares it against the true R count.

    Returns:
        (accuracy, results) where results is a list of
        (word, expected, predicted, is_correct) tuples.
    """
    model.eval()
    eval_pairs = create_dataset_samples(num_samples, max_r_count=100)

    hits = 0
    results = []
    with torch.no_grad():
        for word, expected_count in eval_pairs:
            encoded = tokenizer(f"Input: {word}\nOutput:", return_tensors="pt").to(device)

            generated = model.generate(
                **encoded,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

            decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
            # Parse the first whitespace-separated token after "Output:";
            # anything unparseable counts as a wrong answer (-1).
            try:
                predicted = int(decoded.split("Output:")[-1].strip().split()[0])
            except (ValueError, IndexError):
                predicted = -1

            is_hit = predicted == expected_count
            hits += int(is_hit)
            results.append((word, expected_count, predicted, is_hit))

    return hits / num_samples, results


def main():
    """Run the full fine-tuning experiment end to end.

    Loads Llama-3.2-1B-Instruct, generates synthetic R-counting data,
    evaluates before training, trains with gradient accumulation,
    evaluates after each epoch and at the end, spot-checks classic
    examples, and saves the resulting checkpoint.
    """
    # Configuration
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    num_train_samples = 15000
    num_epochs = 3
    batch_size = 8
    learning_rate = 2e-5
    max_r_count = 100
    gradient_accumulation_steps = 4

    print("=" * 60)
    print("Fine-tuning Llama-3.2-1B-Instruct to count Rs in strawberry")
    print("=" * 60)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"\nLoading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Load model; bf16 + device_map only make sense on CUDA.
    print(f"Loading model from {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None
    )

    if not torch.cuda.is_available():
        model = model.to(device)

    # Generate training data
    print(f"\nGenerating {num_train_samples} training samples...")
    train_samples = create_dataset_samples(num_train_samples, max_r_count)

    # Show some examples
    print("\nSample training data:")
    for i in range(5):
        word, count = train_samples[i]
        print(f" '{word}' -> {count}")

    # Create dataset and dataloader
    train_dataset = StrawberryDataset(train_samples, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )

    # Evaluate before training
    print("\n" + "=" * 60)
    print("Evaluating BEFORE fine-tuning...")
    print("=" * 60)
    accuracy_before, results_before = evaluate_model(model, tokenizer, device, num_samples=30)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print("\nSample predictions (before):")
    for word, expected, predicted, correct in results_before[:10]:
        status = "✓" if correct else "✗"
        print(f" {status} '{word[:30]}...' expected={expected}, got={predicted}")

    # Setup optimizer and scheduler.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Ceiling division: a trailing partial accumulation window still costs
    # one optimizer step (flushed below), so count it in the schedule.
    steps_per_epoch = -(-len(train_loader) // gradient_accumulation_steps)
    total_steps = steps_per_epoch * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # Training loop
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    model.train()
    global_step = 0

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # Scale so the accumulated gradient matches an effective batch
            # of batch_size * gradient_accumulation_steps.
            loss = outputs.loss / gradient_accumulation_steps
            loss.backward()

            epoch_loss += outputs.loss.item()
            num_batches += 1

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            progress_bar.set_postfix({"loss": f"{epoch_loss / num_batches:.4f}"})

        # BUG FIX: flush a trailing partial accumulation window. Previously,
        # when len(train_loader) wasn't divisible by
        # gradient_accumulation_steps, the leftover gradients were never
        # stepped and silently leaked into the next epoch's first update.
        if num_batches % gradient_accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

        # Mid-training evaluation (evaluate_model switches to eval mode,
        # so restore train mode afterwards).
        print(f"\nMid-training evaluation after epoch {epoch + 1}:")
        accuracy_mid, _ = evaluate_model(model, tokenizer, device, num_samples=30)
        print(f"Accuracy: {accuracy_mid:.1%}")
        model.train()

    # Final evaluation
    print("\n" + "=" * 60)
    print("Evaluating AFTER fine-tuning...")
    print("=" * 60)
    accuracy_after, results_after = evaluate_model(model, tokenizer, device, num_samples=50)
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print("\nSample predictions (after):")
    for word, expected, predicted, correct in results_after[:15]:
        status = "✓" if correct else "✗"
        print(f" {status} '{word[:40]}' expected={expected}, got={predicted}")

    # Test on the classic examples
    print("\n" + "=" * 60)
    print("Testing on classic examples...")
    print("=" * 60)

    classic_tests = [
        ("strawberry", 3),
        ("strrawberrrrry", 7),
        ("strrrrrawberrrrrrrrrry", 15),
        ("stawbey", 0),
    ]

    model.eval()
    with torch.no_grad():
        for word, expected in classic_tests:
            prompt = f"Input: {word}\nOutput:"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep the raw token here (no int()) so unparseable output is
            # still printed verbatim.
            try:
                predicted = response.split("Output:")[-1].strip().split()[0]
            except IndexError:
                predicted = "N/A"

            print(f" Input: '{word}'")
            print(f" Expected: {expected}, Predicted: {predicted}")
            print()

    # Save the model
    output_dir = "strawberry-llama"
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Done!")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print(f"Improvement: {(accuracy_after - accuracy_before):.1%}")


# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()