""" Example demonstrating Gamma SSM's ability to learn copy and reverse tasks. Copy Task: Given [x₁, x₂, ..., xₗ, ∅, ∅, ..., ∅], predict [∅, ∅, ..., ∅, x₁, x₂, ..., xₗ] This tests the model's ability to hold information in memory and recall it after a delay. Reverse Task: Given [x₁, x₂, ..., xₗ], predict [xₗ, xₗ₋₁, ..., x₁] This tests the model's ability to process sequences bidirectionally. These are classical synthetic benchmarks for evaluating sequence models. """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import numpy as np from pathlib import Path import argparse from typing import Tuple, Dict, List import sys # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) from gamma_space_model.modules.block import GammaSingleBlock # ============================================================================ # PHASE 1: TASK DATA GENERATION # ============================================================================ def generate_copy_task( seq_len: int, vocab_size: int, batch_size: int, device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generate copy task data. 
Input: [token₁, token₂, ..., tokenₗ, 0, 0, ..., 0] (length: 2*seq_len) Target: [0, 0, ..., 0, token₁, token₂, ..., tokenₗ] (length: 2*seq_len) Args: seq_len: Length of the token sequence vocab_size: Number of possible token values (excluding 0 for padding) batch_size: Batch size device: Device to create tensors on Returns: inputs: (batch_size, 2*seq_len) tensor of input token indices targets: (batch_size, 2*seq_len) tensor of target token indices """ # Generate random tokens (1 to vocab_size-1; 0 is reserved for padding) tokens = torch.randint(1, vocab_size, (batch_size, seq_len), device=device) # Create inputs: [tokens, pad] padding = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device) inputs = torch.cat([tokens, padding], dim=1) # Create targets: [pad, tokens] targets = torch.cat([padding, tokens], dim=1) return inputs, targets def generate_reverse_task( seq_len: int, vocab_size: int, batch_size: int, device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generate reverse task data. 
Input: [token₁, token₂, ..., tokenₗ] Target: [tokenₗ, tokenₗ₋₁, ..., token₁] Args: seq_len: Length of the token sequence vocab_size: Number of possible token values (1 to vocab_size-1) batch_size: Batch size device: Device to create tensors on Returns: inputs: (batch_size, seq_len) tensor of input token indices targets: (batch_size, seq_len) tensor of target token indices """ # Generate random tokens (1 to vocab_size-1) inputs = torch.randint(1, vocab_size, (batch_size, seq_len), device=device) # Create targets by reversing along sequence dimension targets = torch.flip(inputs, dims=[1]) return inputs, targets class CopyReverseDataset(Dataset): """Dataset for copy/reverse tasks.""" def __init__( self, task_type: str, seq_len: int, vocab_size: int, num_samples: int, device: torch.device, ): """ Args: task_type: 'copy' or 'reverse' seq_len: Sequence length vocab_size: Number of vocabulary tokens num_samples: Number of samples to generate device: Device to generate data on """ self.task_type = task_type self.seq_len = seq_len self.vocab_size = vocab_size self.device = device # Pre-generate all samples self.inputs = [] self.targets = [] if task_type == 'copy': gen_fn = generate_copy_task elif task_type == 'reverse': gen_fn = generate_reverse_task else: raise ValueError(f"Unknown task type: {task_type}") # Generate samples in batches batch_size = min(num_samples, 256) for i in range(0, num_samples, batch_size): current_batch_size = min(batch_size, num_samples - i) inputs, targets = gen_fn(seq_len, vocab_size, current_batch_size, device) self.inputs.append(inputs) self.targets.append(targets) self.inputs = torch.cat(self.inputs, dim=0) self.targets = torch.cat(self.targets, dim=0) def __len__(self): return len(self.inputs) def __getitem__(self, idx): return self.inputs[idx], self.targets[idx] # ============================================================================ # PHASE 2: MODEL ARCHITECTURE # 
class CopyReverseModel(nn.Module):
    """
    Model for copy/reverse tasks.

    Architecture:
        input (token indices)
        → learned embedding (nn.Embedding: vocab_size → d_model)
        → GammaSingleBlock (SSM)
        → MLP projection
        → output logits (vocab_size)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 32,
        hidden_dim: int = 128,
        prenorm: bool = True,
        dropout: float = 0.0,
        delta_t: float = 0.005,
    ):
        """
        Args:
            vocab_size: Number of token ids (including 0, used for padding).
            d_model: Model dimension (embedding size and SSM input dim).
            hidden_dim: SSM hidden state dimension.
            prenorm: Whether to use prenorm in the SSM block.
            dropout: Dropout rate used in the SSM block and the output MLP.
            delta_t: Passed as ``delta_t`` to GammaSingleBlock (presumably
                the SSM discretisation step — confirm). Defaults to 0.005,
                the value previously hard-coded here.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.delta_t = delta_t

        # Learned lookup table mapping token ids to d_model-dim vectors.
        self.embedding = nn.Embedding(vocab_size, d_model)

        # SSM block for sequence processing.
        self.ssm_block = GammaSingleBlock(
            d_model=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            prenorm=prenorm,
            dropout=dropout,
        )

        # Two-layer MLP head projecting SSM features to per-token vocab logits.
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, vocab_size),
        )

    def forward(
        self,
        token_indices: torch.Tensor,
        state: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            token_indices: (batch, seq_len) tensor of token indices.
            state: Optional initial state, passed through to the SSM block.
            mask: Optional mask, passed through to the SSM block.

        Returns:
            logits: (batch, seq_len, vocab_size) output logits.
            final_state: Final state returned by the SSM block (shape defined
                by GammaSingleBlock — presumably (batch, hidden_dim); confirm).
        """
        embedded = self.embedding(token_indices)  # (batch, seq_len, d_model)
        ssm_output, final_state = self.ssm_block(embedded, state=state, mask=mask)
        logits = self.output_proj(ssm_output)  # (batch, seq_len, vocab_size)
        return logits, final_state


# ============================================================================
# PHASE 3: TRAINING & EVALUATION
# ============================================================================


def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer = None,
    collect: bool = False,
) -> Tuple[float, float, torch.Tensor, torch.Tensor]:
    """
    Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Training mode computes gradients, clips them to norm 1.0 and steps the
    optimizer. Evaluation mode runs under ``model.eval()`` with gradients
    disabled.

    Returns:
        (mean_loss, accuracy_pct, predictions, targets):
        - mean_loss is averaged over every target token, so a smaller final
          batch is not over-weighted.
        - accuracy_pct is the token accuracy in [0, 100].
        - Loss and accuracy are NaN if the loader yields no tokens.
        - predictions/targets are CPU tensors concatenated over the pass when
          ``collect`` is True and at least one batch was seen, otherwise None.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0
    collected_preds: List[torch.Tensor] = []
    collected_targets: List[torch.Tensor] = []

    with torch.set_grad_enabled(training):
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits, _ = model(batch_inputs)
            vocab_size = logits.shape[-1]
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CrossEntropyLoss.
            loss = criterion(logits.reshape(-1, vocab_size), batch_targets.reshape(-1))

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            n_tokens = batch_targets.numel()
            # The callers' CrossEntropyLoss uses the default mean-over-tokens
            # reduction; re-weight by token count for an exact epoch average.
            loss_sum += loss.item() * n_tokens
            predictions = logits.argmax(dim=-1)
            correct += (predictions == batch_targets).sum().item()
            total += n_tokens

            if collect:
                collected_preds.append(predictions.cpu())
                collected_targets.append(batch_targets.cpu())

    if total == 0:
        mean_loss = accuracy = float('nan')
    else:
        mean_loss = loss_sum / total
        accuracy = 100 * correct / total

    if collect and collected_preds:
        return mean_loss, accuracy, torch.cat(collected_preds, dim=0), torch.cat(collected_targets, dim=0)
    return mean_loss, accuracy, None, None


def train_on_task(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    task_type: str,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    device: torch.device = torch.device('cpu'),
    verbose: bool = True,
) -> Dict[str, List[float]]:
    """
    Train model on copy or reverse task with Adam and cross-entropy loss.

    Args:
        model: CopyReverseModel instance (moved to ``device``).
        train_loader: Training data loader.
        val_loader: Validation data loader.
        task_type: 'copy' or 'reverse' (used for logging only).
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        device: Device to train on.
        verbose: Print metrics every 10 epochs and after the final epoch.

    Returns:
        Dictionary with per-epoch 'train_loss', 'val_loss', 'train_acc',
        'val_acc' lists (losses token-weighted, accuracies in percent).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
    }

    for epoch in range(num_epochs):
        train_loss, train_acc, _, _ = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        val_loss, val_acc, _, _ = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Always report the last epoch, even when num_epochs % 10 != 0.
        is_report_epoch = (epoch + 1) % 10 == 0 or epoch + 1 == num_epochs
        if verbose and is_report_epoch:
            print(
                f"[{task_type.upper()}] Epoch {epoch+1:3d}/{num_epochs} | "
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:6.2f}% | "
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:6.2f}%"
            )

    return history


def evaluate_on_task(
    model: nn.Module,
    test_loader: DataLoader,
    task_type: str,
    device: torch.device = torch.device('cpu'),
) -> Dict[str, float]:
    """
    Evaluate model on a test set.

    Args:
        model: CopyReverseModel instance (moved to ``device``).
        test_loader: Test data loader.
        task_type: 'copy' or 'reverse' (kept for interface/logging symmetry).
        device: Device to evaluate on.

    Returns:
        Dictionary with:
        - 'loss': token-weighted mean cross-entropy.
        - 'accuracy': token accuracy in percent over all positions.
        - 'per_position_acc': numpy array with the fraction correct at each
          sequence position (empty if the loader is empty).
        - 'nonpad_accuracy': accuracy in percent restricted to positions whose
          target is not the padding token 0. For the copy task this excludes
          the trivially predictable padding half; NaN if there are none.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    avg_loss, accuracy, all_predictions, all_targets = _run_epoch(
        model, test_loader, criterion, device, collect=True
    )

    if all_predictions is None:
        per_pos_acc = np.zeros(0)
        nonpad_accuracy = float('nan')
    else:
        matches = all_predictions == all_targets
        per_pos_acc = matches.float().mean(dim=0).numpy()
        nonpad_mask = all_targets != 0
        if nonpad_mask.any():
            nonpad_accuracy = 100 * matches[nonpad_mask].float().mean().item()
        else:
            nonpad_accuracy = float('nan')

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'per_position_acc': per_pos_acc,
        'nonpad_accuracy': nonpad_accuracy,
    }


def visualize_predictions(
    model: nn.Module,
    task_type: str,
    seq_len: int,
    vocab_size: int,
    num_examples: int = 3,
    device: torch.device = torch.device('cpu'),
):
    """
    Print inputs, targets and predictions for freshly generated samples.

    Args:
        model: CopyReverseModel instance (moved to ``device``).
        task_type: 'copy' or 'reverse'.
        seq_len: Sequence length.
        vocab_size: Vocabulary size (including padding token 0).
        num_examples: Number of examples to show.
        device: Device to use.

    Raises:
        ValueError: On an unknown task_type (previously any non-'copy' value
            was silently treated as 'reverse').
    """
    if task_type == 'copy':
        gen_fn = generate_copy_task
    elif task_type == 'reverse':
        gen_fn = generate_reverse_task
    else:
        raise ValueError(f"Unknown task type: {task_type}")

    model = model.to(device)
    model.eval()
    inputs, targets = gen_fn(seq_len, vocab_size, num_examples, device)

    with torch.no_grad():
        logits, _ = model(inputs)
        predictions = logits.argmax(dim=-1)

    print(f"\n{'='*80}")
    print(f"Sample Predictions for {task_type.upper()} Task (seq_len={seq_len}, vocab_size={vocab_size})")
    print(f"{'='*80}")

    for idx in range(num_examples):
        print(f"\nExample {idx + 1}:")
        print(f"  Input:     {inputs[idx].cpu().tolist()}")
        print(f"  Target:    {targets[idx].cpu().tolist()}")
        print(f"  Predicted: {predictions[idx].cpu().tolist()}")

        correct = (predictions[idx] == targets[idx]).sum().item()
        acc = 100 * correct / len(targets[idx])
        print(f"  Accuracy:  {acc:.2f}%")


# ============================================================================
# PHASE 4: MAIN ENTRY POINT
# ============================================================================


def _str2bool(value) -> bool:
    """
    Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    argparse's ``type=bool`` turns every non-empty string (including "False")
    into True, so a real parser is needed.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(args):
    """Build datasets, train one model per task, evaluate and report results."""
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU.")
    device = torch.device(args.device if torch.cuda.is_available() or args.device == 'cpu' else 'cpu')
    print(f"Using device: {device}")

    print("\nCreating datasets...")
    print(f"  Sequence length: {args.seq_len}")
    print(f"  Vocabulary size: {args.vocab_size}")
    print(f"  Batch size: {args.batch_size}")

    # Copy task
    print(f"\n{'='*80}")
    print("COPY TASK")
    print(f"{'='*80}")

    copy_train_ds = CopyReverseDataset(
        'copy', args.seq_len, args.vocab_size, args.train_samples, device
    )
    copy_val_ds = CopyReverseDataset(
        'copy', args.seq_len, args.vocab_size, args.val_samples, device
    )
    copy_test_ds = CopyReverseDataset(
        'copy', args.seq_len, args.vocab_size, args.test_samples, device
    )

    copy_train_loader = DataLoader(copy_train_ds, batch_size=args.batch_size, shuffle=True)
    copy_val_loader = DataLoader(copy_val_ds, batch_size=args.batch_size, shuffle=False)
    copy_test_loader = DataLoader(copy_test_ds, batch_size=args.batch_size, shuffle=False)

    # Reverse task
    print(f"\n{'='*80}")
    print("REVERSE TASK")
    print(f"{'='*80}")

    rev_train_ds = CopyReverseDataset(
        'reverse', args.seq_len, args.vocab_size, args.train_samples, device
    )
    rev_val_ds = CopyReverseDataset(
        'reverse', args.seq_len, args.vocab_size, args.val_samples, device
    )
    rev_test_ds = CopyReverseDataset(
        'reverse', args.seq_len, args.vocab_size, args.test_samples, device
    )

    rev_train_loader = DataLoader(rev_train_ds, batch_size=args.batch_size, shuffle=True)
    rev_val_loader = DataLoader(rev_val_ds, batch_size=args.batch_size, shuffle=False)
    rev_test_loader = DataLoader(rev_test_ds, batch_size=args.batch_size, shuffle=False)

    # Train copy model
    print("\nTraining COPY task model...")
    copy_model = CopyReverseModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        hidden_dim=args.hidden_dim,
        prenorm=args.prenorm,
        dropout=args.dropout,
    )
    train_on_task(
        copy_model,
        copy_train_loader,
        copy_val_loader,
        'copy',
        num_epochs=args.num_epochs,
        learning_rate=args.lr,
        device=device,
        verbose=True,
    )

    # Train reverse model
    print("\n\nTraining REVERSE task model...")
    reverse_model = CopyReverseModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        hidden_dim=args.hidden_dim,
        prenorm=args.prenorm,
        dropout=args.dropout,
    )
    train_on_task(
        reverse_model,
        rev_train_loader,
        rev_val_loader,
        'reverse',
        num_epochs=args.num_epochs,
        learning_rate=args.lr,
        device=device,
        verbose=True,
    )

    # Evaluate on test sets
    print(f"\n\n{'='*80}")
    print("EVALUATION ON TEST SET")
    print(f"{'='*80}")

    copy_eval = evaluate_on_task(copy_model, copy_test_loader, 'copy', device)
    print("\nCOPY Task Test Results:")
    print(f"  Loss: {copy_eval['loss']:.4f}")
    print(f"  Accuracy: {copy_eval['accuracy']:.2f}%")
    print(f"  Non-pad accuracy: {copy_eval['nonpad_accuracy']:.2f}%")

    rev_eval = evaluate_on_task(reverse_model, rev_test_loader, 'reverse', device)
    print("\nREVERSE Task Test Results:")
    print(f"  Loss: {rev_eval['loss']:.4f}")
    print(f"  Accuracy: {rev_eval['accuracy']:.2f}%")
    print(f"  Non-pad accuracy: {rev_eval['nonpad_accuracy']:.2f}%")

    # Print comparison. Overall copy accuracy includes the padding half of the
    # targets (always 0), so the non-pad figures are the fairer comparison.
    print(f"\n{'='*80}")
    print("COMPARISON")
    print(f"{'='*80}")
    diff = copy_eval['accuracy'] - rev_eval['accuracy']
    if diff > 0:
        relation = 'higher than'
    elif diff < 0:
        relation = 'lower than'
    else:
        relation = 'equal to'
    print(f"Copy accuracy is {relation} reverse")
    print(f"Difference: {abs(diff):.2f} percentage points")

    if args.visualize:
        visualize_predictions(copy_model, 'copy', args.seq_len, args.vocab_size, device=device)
        visualize_predictions(reverse_model, 'reverse', args.seq_len, args.vocab_size, device=device)

    print(f"\n{'='*80}")
    print("Training complete!")
    print(f"{'='*80}\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Test Gamma SSM on copy and reverse synthetic tasks.'
    )

    # Task parameters
    parser.add_argument(
        '--seq-len',
        type=int,
        default=20,
        help='Sequence length for tasks (default: 20)',
    )
    parser.add_argument(
        '--vocab-size',
        type=int,
        default=8,
        help='Vocabulary size (including 0 for padding, default: 8)',
    )

    # Dataset parameters
    parser.add_argument(
        '--train-samples',
        type=int,
        default=500,
        help='Number of training samples (default: 500)',
    )
    parser.add_argument(
        '--val-samples',
        type=int,
        default=100,
        help='Number of validation samples (default: 100)',
    )
    parser.add_argument(
        '--test-samples',
        type=int,
        default=100,
        help='Number of test samples (default: 100)',
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=32,
        help='Batch size (default: 32)',
    )

    # Model parameters
    parser.add_argument(
        '--d-model',
        type=int,
        default=32,
        help='Model dimension (default: 32)',
    )
    parser.add_argument(
        '--hidden-dim',
        type=int,
        default=128,
        help='SSM hidden dimension (default: 128)',
    )
    parser.add_argument(
        '--prenorm',
        type=_str2bool,
        nargs='?',
        const=True,
        default=True,
        help='Use prenorm in SSM block: true/false (default: True)',
    )
    parser.add_argument(
        '--dropout',
        type=float,
        default=0.0,
        help='Dropout rate (default: 0.0)',
    )

    # Training parameters
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=500,
        help='Number of training epochs (default: 500)',
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        help='Learning rate (default: 0.001)',
    )

    # Other parameters
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cpu', 'cuda'],
        help='Device to use (default: cuda)',
    )
    parser.add_argument(
        '--visualize',
        action='store_true',
        help='Visualize predictions on sample data',
    )

    args = parser.parse_args()
    main(args)