#!/usr/bin/env python3
"""
Fine-tune Llama-3.2-1B-Instruct to count Rs in 'strawberry' variants.

A fun exercise in overfitting to a simple task: return the number of Rs in a
variant of the word "strawberry". Example: for the prompt word "strrawberrry"
the expected response is 5 (two Rs after "st" plus three after "awbe").
"""

import math
import random
import re

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

# Prompt format shared by training, evaluation and the classic tests. The model
# is trained to complete it with " <count>" followed by the EOS token.
PROMPT_TEMPLATE = "Input: {word}\nOutput:"

# Matches the integer at the start of the model's generated continuation.
_LEADING_INT_RE = re.compile(r"\s*(\d+)")


def build_prompt(word: str) -> str:
    """Return the prompt the model should complete with the R count of *word*."""
    return PROMPT_TEMPLATE.format(word=word)


def generate_strawberry_variant(target_r_count: int) -> str:
    """
    Generate a 'strawberry' variant with exactly target_r_count Rs.

    Base word: s-t-r-a-w-b-e-r-r-y (3 Rs at positions: str, err, rry).
    Counts 1 and 2 are drawn from small hand-picked tables. For 3 or more Rs,
    each of the three R slots gets at least one R and the rest are scattered
    at random. The last two slots are adjacent, so in the output they merge
    into a single run of Rs after "awbe".

    Args:
        target_r_count: Desired number of Rs. Values below 1 yield "stawbey".

    Returns:
        A lowercase word whose ``count("r")`` equals ``max(target_r_count, 0)``.
    """
    if target_r_count < 1:
        # Edge case: no Rs at all.
        return "stawbey"
    if target_r_count == 1:
        # "stawbery" is listed twice to keep its original 2-in-3 weight.
        return random.choice(["strawbey", "stawbery", "stawbery"])
    if target_r_count == 2:
        return random.choice(["strawbery", "stawberry", "strrawbey"])

    # 3+ Rs: one guaranteed R per slot, remaining Rs distributed at random.
    slots = [1, 1, 1]
    for _ in range(target_r_count - 3):
        slots[random.randint(0, 2)] += 1
    return "st" + "r" * slots[0] + "awbe" + "r" * slots[1] + "r" * slots[2] + "y"


def create_dataset_samples(num_samples: int = 10000, max_r_count: int = 100) -> list[tuple[str, int]]:
    """Generate ``(word, r_count)`` training samples with varied R counts.

    A single uniform draw picks the count range for each sample: 30% use
    1-10 Rs, 30% use 1-30 Rs and 40% use 1-``max_r_count`` Rs. This biases the
    data toward short words while still covering the full range. The two small
    ranges are capped at ``max_r_count``, so it is always the upper bound.

    Args:
        num_samples: Number of samples to generate.
        max_r_count: Largest R count to generate; must be at least 1.

    Returns:
        A list of ``(word, count)`` pairs with ``count == word.count("r")``.

    Raises:
        ValueError: If ``max_r_count`` is less than 1.
    """
    if max_r_count < 1:
        raise ValueError(f"max_r_count must be >= 1, got {max_r_count}")

    samples = []
    for _ in range(num_samples):
        # Draw once so the thresholds are cumulative (30% / 30% / 40%); a
        # fresh draw inside the elif would skew the split to 30% / 42% / 28%.
        bucket = random.random()
        if bucket < 0.3:
            upper = 10
        elif bucket < 0.6:
            upper = 30
        else:
            upper = max_r_count
        r_count = random.randint(1, min(upper, max_r_count))

        word = generate_strawberry_variant(r_count)
        # Verify the generator really produced the requested count before
        # using it as a label.
        actual_count = word.lower().count("r")
        if actual_count != r_count:
            raise RuntimeError(
                f"generate_strawberry_variant({r_count}) returned {word!r} with {actual_count} Rs"
            )
        samples.append((word, actual_count))
    return samples


class StrawberryDataset(Dataset):
    """Dataset for the R-counting task.

    Each item is a right-padded causal-LM example laid out as
    ``<prompt> <count><eos><pad>...``. Labels are -100 everywhere except the
    answer tokens and the EOS, so the loss covers only what the model must
    generate, including when to stop.
    """

    def __init__(self, samples: list[tuple[str, int]], tokenizer, max_length: int = 128):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        word, count = self.samples[idx]
        tok = self.tokenizer

        # Tokenize prompt and answer separately. The label boundary is then
        # exact, and the prompt tokens match what generate() sees at eval time.
        prompt_ids = tok(build_prompt(word))["input_ids"]
        answer_ids = tok(f" {count}", add_special_tokens=False)["input_ids"]
        if tok.eos_token_id is not None:
            # Teach the model to stop right after the number.
            answer_ids.append(tok.eos_token_id)

        input_ids = prompt_ids + answer_ids
        if len(input_ids) > self.max_length:
            # Truncation would cut off the answer and leave nothing to learn.
            raise ValueError(
                f"Sample {idx} ({word!r}) needs {len(input_ids)} tokens but "
                f"max_length={self.max_length}; increase max_length."
            )

        labels = [-100] * len(prompt_ids) + answer_ids
        attention_mask = [1] * len(input_ids)

        # Right-pad to a fixed length. Padding is ignored by attention and loss.
        pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
        pad_len = self.max_length - len(input_ids)
        input_ids += [pad_id] * pad_len
        attention_mask += [0] * pad_len
        labels += [-100] * pad_len

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


@torch.no_grad()
def _generate_answer(model, tokenizer, device, word: str, max_new_tokens: int = 10) -> str:
    """Greedy-decode the model's completion of the prompt for *word* (new tokens only)."""
    inputs = tokenizer(build_prompt(word), return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the generated continuation. Decoding the whole sequence and
    # splitting on "Output:" would pick up any extra example the model invents.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


def _parse_count(text: str) -> int:
    """Return the integer at the start of *text*, or -1 if there is none."""
    match = _LEADING_INT_RE.match(text)
    return int(match.group(1)) if match else -1


def _shorten(word: str, width: int) -> str:
    """Return *word*, cut to *width* characters plus '...' only when it is longer."""
    return word if len(word) <= width else word[:width] + "..."


def evaluate_model(model, tokenizer, device, num_samples: int = 50, max_r_count: int = 100):
    """Evaluate the model on freshly generated random samples.

    Samples come from the same generator as the training data, so they may
    overlap with it.

    Args:
        model: Causal LM to evaluate; it is left in eval mode.
        tokenizer: Tokenizer matching *model*.
        device: Device the prompt tensors are moved to.
        num_samples: Number of samples to evaluate.
        max_r_count: Largest R count in the generated samples.

    Returns:
        ``(accuracy, results)`` where ``results`` is a list of
        ``(word, expected_count, predicted, is_correct)``. ``predicted`` is -1
        when the model's answer does not start with an integer. ``accuracy``
        is 0.0 when ``num_samples`` is 0.
    """
    model.eval()
    results = []
    for word, expected_count in create_dataset_samples(num_samples, max_r_count=max_r_count):
        predicted = _parse_count(_generate_answer(model, tokenizer, device, word))
        results.append((word, expected_count, predicted, predicted == expected_count))

    correct = sum(1 for *_, is_correct in results if is_correct)
    accuracy = correct / len(results) if results else 0.0
    return accuracy, results


def main():
    # Configuration
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    num_train_samples = 15000
    num_epochs = 3
    batch_size = 8
    learning_rate = 2e-5
    max_r_count = 100
    gradient_accumulation_steps = 4
    seed = 42

    # Make data generation and shuffling reproducible.
    random.seed(seed)
    torch.manual_seed(seed)

    print("=" * 60)
    print("Fine-tuning Llama-3.2-1B-Instruct to count Rs in strawberry")
    print("=" * 60)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"\nLoading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    # NOTE(review): AdamW updates applied directly to bf16 weights at lr=2e-5
    # may be largely lost to bf16 rounding; consider float32 weights with
    # torch.autocast if memory allows.
    print(f"Loading model from {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if not torch.cuda.is_available():
        model = model.to(device)

    # Generate training data
    print(f"\nGenerating {num_train_samples} training samples...")
    train_samples = create_dataset_samples(num_train_samples, max_r_count)

    # Show some examples
    print("\nSample training data:")
    for word, count in train_samples[:5]:
        print(f" '{word}' -> {count}")

    # Create dataset and dataloader
    train_dataset = StrawberryDataset(train_samples, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Evaluate before training
    print("\n" + "=" * 60)
    print("Evaluating BEFORE fine-tuning...")
    print("=" * 60)
    accuracy_before, results_before = evaluate_model(
        model, tokenizer, device, num_samples=30, max_r_count=max_r_count
    )
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print("\nSample predictions (before):")
    for word, expected, predicted, correct in results_before[:10]:
        status = "✓" if correct else "✗"
        print(f" {status} '{_shorten(word, 30)}' expected={expected}, got={predicted}")

    # Setup optimizer and scheduler. The last accumulation group of an epoch
    # may be partial, so each epoch performs ceil(batches / accum) updates.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    batches_per_epoch = len(train_loader)
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    total_steps = updates_per_epoch * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )
    # Size and start index of the trailing partial accumulation group (0 if none).
    tail_size = batches_per_epoch % gradient_accumulation_steps
    tail_start = batches_per_epoch - tail_size

    # Training loop
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    model.train()
    optimizer.zero_grad()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # Average the loss over the micro-batches actually in this group.
            in_tail = tail_size and batch_idx >= tail_start
            group_size = tail_size if in_tail else gradient_accumulation_steps
            (outputs.loss / group_size).backward()

            epoch_loss += outputs.loss.item()
            num_batches += 1

            # Step at every full group and at the end of the epoch, so leftover
            # gradients never leak into the next epoch or get dropped.
            group_done = (batch_idx + 1) % gradient_accumulation_steps == 0
            epoch_done = batch_idx + 1 == batches_per_epoch
            if group_done or epoch_done:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            progress_bar.set_postfix({"loss": f"{epoch_loss / num_batches:.4f}"})

        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

        # Mid-training evaluation
        print(f"\nMid-training evaluation after epoch {epoch + 1}:")
        accuracy_mid, _ = evaluate_model(
            model, tokenizer, device, num_samples=30, max_r_count=max_r_count
        )
        print(f"Accuracy: {accuracy_mid:.1%}")
        model.train()

    # Final evaluation
    print("\n" + "=" * 60)
    print("Evaluating AFTER fine-tuning...")
    print("=" * 60)
    accuracy_after, results_after = evaluate_model(
        model, tokenizer, device, num_samples=50, max_r_count=max_r_count
    )
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print("\nSample predictions (after):")
    for word, expected, predicted, correct in results_after[:15]:
        status = "✓" if correct else "✗"
        print(f" {status} '{_shorten(word, 40)}' expected={expected}, got={predicted}")

    # Test on the classic examples. Expected counts are derived from the word
    # itself so a hand-counting mistake cannot produce a wrong label.
    print("\n" + "=" * 60)
    print("Testing on classic examples...")
    print("=" * 60)
    classic_words = [
        "strawberry",
        "strrawberrrrry",
        "strrrrrawberrrrrrrrrry",
        "stawbey",
    ]

    model.eval()
    for word in classic_words:
        expected = word.count("r")
        predicted = _parse_count(_generate_answer(model, tokenizer, device, word))
        shown = predicted if predicted >= 0 else "N/A"
        print(f" Input: '{word}'")
        print(f" Expected: {expected}, Predicted: {shown}")
        print()

    # Save the model
    output_dir = "strawberry-llama"
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Done!")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print(f"Improvement: {(accuracy_after - accuracy_before):.1%}")


if __name__ == "__main__":
    main()