# TaoNet-mini-T2 / code / Taotern_SSM / examples / copy_reverse_example.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit 388fd6e (verified).
"""
Example demonstrating Gamma SSM's ability to learn copy and reverse tasks.
Copy Task: Given [x₁, x₂, ..., xₗ, ∅, ∅, ..., ∅], predict [∅, ∅, ..., ∅, x₁, x₂, ..., xₗ]
This tests the model's ability to hold information in memory and recall it after a delay.
Reverse Task: Given [x₁, x₂, ..., xₗ], predict [xₗ, xₗ₋₁, ..., x₁]
This tests the model's ability to process sequences bidirectionally.
These are classical synthetic benchmarks for evaluating sequence models.
"""
import argparse
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from gamma_space_model.modules.block import GammaSingleBlock  # noqa: E402
# ============================================================================
# PHASE 1: TASK DATA GENERATION
# ============================================================================
def generate_copy_task(
    seq_len: int,
    vocab_size: int,
    batch_size: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build one batch of the delayed-recall (copy) task.

    Each row of ``inputs`` is ``seq_len`` random tokens followed by ``seq_len``
    zeros. The matching row of ``targets`` is ``seq_len`` zeros followed by the
    same tokens, so the model has to reproduce the prefix after a delay.

    Args:
        seq_len: Number of content tokens per sequence.
        vocab_size: Exclusive upper bound on token ids. Ids are drawn from
            ``[1, vocab_size)`` because 0 is reserved as the padding token.
        batch_size: Number of sequences to generate.
        device: Device on which both tensors are allocated.

    Returns:
        ``(inputs, targets)``: two int64 tensors of shape ``(batch_size, 2 * seq_len)``.
    """
    half_shape = (batch_size, seq_len)
    content = torch.randint(1, vocab_size, half_shape, device=device)
    # Same dtype (int64) and device as `content`.
    blank = content.new_zeros(half_shape)
    return (
        torch.cat((content, blank), dim=1),  # inputs:  [tokens, pad]
        torch.cat((blank, content), dim=1),  # targets: [pad, tokens]
    )
def generate_reverse_task(
    seq_len: int,
    vocab_size: int,
    batch_size: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build one batch of the sequence-reversal task.

    ``targets[b]`` is ``inputs[b]`` read back to front. There is no padding
    half, so the target at position ``t`` is the input at position
    ``seq_len - 1 - t`` (0-based).

    NOTE(review): if the SSM block is causal, the early target positions depend
    on inputs the model has not seen yet. Confirm whether that is intended.

    Args:
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound on token ids. Ids are drawn from
            ``[1, vocab_size)``.
        batch_size: Number of sequences to generate.
        device: Device on which both tensors are allocated.

    Returns:
        ``(inputs, targets)``: two int64 tensors of shape ``(batch_size, seq_len)``.
    """
    source = torch.randint(1, vocab_size, (batch_size, seq_len), device=device)
    reversed_source = source.flip(1)
    return source, reversed_source
class CopyReverseDataset(Dataset):
    """
    In-memory dataset of pre-generated copy or reverse task samples.

    Every sample is generated once in ``__init__`` and stored in two tensors
    (``self.inputs`` / ``self.targets``) on ``device``; ``__getitem__`` only
    indexes them.

    NOTE(review): when ``device`` is a CUDA device the stored tensors live on
    the GPU, so a multi-worker ``DataLoader`` may not work with this dataset.
    Confirm before raising ``num_workers``.
    """

    # Maximum number of rows requested from the task generator per call.
    _GEN_CHUNK = 256

    def __init__(
        self,
        task_type: str,
        seq_len: int,
        vocab_size: int,
        num_samples: int,
        device: torch.device,
    ):
        """
        Args:
            task_type: 'copy' or 'reverse'
            seq_len: Number of content tokens per sample. Copy samples are
                ``2 * seq_len`` long because of their padding half.
            vocab_size: Exclusive upper bound on token ids (0 is padding).
            num_samples: Number of samples to generate; must be >= 0.
            device: Device on which the sample tensors are stored.

        Raises:
            ValueError: if ``task_type`` is unknown or ``num_samples`` is negative.
        """
        generators = {
            'copy': generate_copy_task,
            'reverse': generate_reverse_task,
        }
        if task_type not in generators:
            raise ValueError(f"Unknown task type: {task_type}")
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        self.task_type = task_type
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.device = device

        gen_fn = generators[task_type]
        # Generate in chunks of at most _GEN_CHUNK rows. When num_samples == 0,
        # a single zero-row call still yields tensors with the correct trailing
        # shape, dtype and device. The old code divided the range by a zero step.
        chunk_sizes = [
            min(self._GEN_CHUNK, num_samples - start)
            for start in range(0, num_samples, self._GEN_CHUNK)
        ] or [0]
        chunks = [gen_fn(seq_len, vocab_size, size, device) for size in chunk_sizes]
        self.inputs = torch.cat([inp for inp, _ in chunks], dim=0)
        self.targets = torch.cat([tgt for _, tgt in chunks], dim=0)

    def __len__(self):
        """Number of stored samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``."""
        return self.inputs[idx], self.targets[idx]
# ============================================================================
# PHASE 2: MODEL ARCHITECTURE
# ============================================================================
class CopyReverseModel(nn.Module):
    """
    Token-level sequence model for the copy/reverse tasks.

    Architecture::

        token indices              (batch, seq_len)
          -> nn.Embedding          learned dense vectors (not one-hot), d_model wide
          -> GammaSingleBlock      SSM sequence mixing
          -> MLP head              Linear(d_model, d_model) -> ReLU -> Dropout
                                   -> Linear(d_model, vocab_size)
          -> logits                (batch, seq_len, vocab_size)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 32,
        hidden_dim: int = 128,
        prenorm: bool = True,
        dropout: float = 0.0,
        delta_t: float = 0.005,
    ):
        """
        Args:
            vocab_size: Number of vocabulary tokens (size of embedding and logits).
            d_model: Model dimension (embedding size and SSM input/output size).
            hidden_dim: SSM hidden state dimension, passed to GammaSingleBlock.
            prenorm: Whether to use prenorm in the SSM block.
            dropout: Dropout rate, used by both the SSM block and the output head.
            delta_t: Step size passed to GammaSingleBlock. The default keeps the
                value that used to be hard-coded here.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim

        # Learned lookup table: token index -> d_model-dimensional vector.
        self.embedding = nn.Embedding(vocab_size, d_model)

        # SSM block for sequence processing.
        self.ssm_block = GammaSingleBlock(
            d_model=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            prenorm=prenorm,
            dropout=dropout,
        )

        # Per-position MLP projection from d_model features to vocab logits.
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, vocab_size),
        )

    def forward(
        self,
        token_indices: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Map token indices to per-position vocabulary logits.

        Args:
            token_indices: (batch, seq_len) tensor of token indices.
            state: Optional initial state, forwarded unchanged to the SSM block.
            mask: Optional mask, forwarded unchanged to the SSM block.

        Returns:
            logits: (batch, seq_len, vocab_size) output logits.
            final_state: Final state as returned by GammaSingleBlock.
                Presumably (batch, hidden_dim); TODO confirm against the block.
        """
        embedded = self.embedding(token_indices)  # (batch, seq_len, d_model)
        ssm_output, final_state = self.ssm_block(embedded, state=state, mask=mask)
        logits = self.output_proj(ssm_output)  # (batch, seq_len, vocab_size)
        return logits, final_state
# ============================================================================
# PHASE 3: TRAINING & EVALUATION
# ============================================================================
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> Tuple[float, float]:
    """
    Make one pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    If ``optimizer`` is given, every batch is back-propagated, gradients are
    clipped to a global norm of 1.0, and an optimizer step is taken. Otherwise
    the pass runs with autograd disabled. The caller must switch the model
    between ``train()`` and ``eval()``.

    ``criterion`` is expected to average over all tokens of a batch, as
    ``nn.CrossEntropyLoss`` does with its default ``reduction='mean'``. The
    returned loss re-weights each batch by its token count, so a smaller final
    batch does not skew the epoch mean.

    Raises:
        ValueError: if ``loader`` yields no target tokens.
    """
    training = optimizer is not None
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits, _ = model(batch_inputs)
            # Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab) for the loss.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                batch_targets.reshape(-1),
            )
            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            num_tokens = batch_targets.numel()
            loss_sum += loss.item() * num_tokens
            # Accuracy uses the logits computed before this batch's optimizer step.
            correct += (logits.argmax(dim=-1) == batch_targets).sum().item()
            total += num_tokens
    if total == 0:
        raise ValueError("DataLoader yielded no target tokens; cannot compute epoch metrics")
    return loss_sum / total, 100 * correct / total


def train_on_task(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    task_type: str,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    device: torch.device = torch.device('cpu'),
    verbose: bool = True,
) -> Dict[str, List[float]]:
    """
    Train a model on the copy or reverse task with Adam and cross-entropy.

    Each epoch runs one optimisation pass over ``train_loader`` (model in
    ``train()`` mode), then one gradient-free pass over ``val_loader``
    (``eval()`` mode).

    Args:
        model: Module whose forward returns ``(logits, state)`` with logits of
            shape (batch, seq_len, vocab_size), e.g. CopyReverseModel.
        train_loader: Training data loader yielding ``(inputs, targets)``.
        val_loader: Validation data loader yielding ``(inputs, targets)``.
        task_type: 'copy' or 'reverse'; only used in log messages.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for Adam.
        device: Device to train on. The model is moved there.
        verbose: Whether to print metrics every 10 epochs.

    Returns:
        Dictionary with 'train_loss', 'val_loss', 'train_acc' and 'val_acc'
        lists, one entry per epoch. Losses are per-token means; accuracies are
        percentages in [0, 100].

    Raises:
        ValueError: if either loader yields no target tokens.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    history: Dict[str, List[float]] = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
    }

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

        model.eval()
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        if verbose and (epoch + 1) % 10 == 0:
            print(
                f"[{task_type.upper()}] Epoch {epoch+1:3d}/{num_epochs} | "
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:6.2f}% | "
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:6.2f}%"
            )

    return history
def evaluate_on_task(
    model: nn.Module,
    test_loader: DataLoader,
    task_type: str,
    device: torch.device = torch.device('cpu'),
) -> Dict[str, Union[float, np.ndarray]]:
    """
    Evaluate a model on a test set without updating it.

    Args:
        model: Module whose forward returns ``(logits, state)`` with logits of
            shape (batch, seq_len, vocab_size), e.g. CopyReverseModel.
        test_loader: Test data loader yielding ``(inputs, targets)``. All
            batches must share the same sequence length.
        task_type: 'copy' or 'reverse'. Currently unused; kept for interface
            compatibility.
        device: Device to evaluate on. The model is moved there.

    Returns:
        Dictionary with:
            'loss': per-token mean cross-entropy over the whole set.
            'accuracy': token accuracy as a percentage in [0, 100].
            'per_position_acc': numpy array of shape (seq_len,) holding the
                fraction in [0, 1] (not a percentage) of sequences predicted
                correctly at each position.

    Raises:
        ValueError: if ``test_loader`` yields no target tokens.
    """
    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()

    loss_sum = 0.0
    total_correct = 0
    total_tokens = 0
    num_sequences = 0
    # Running count of correct predictions at each position, kept on CPU.
    # This avoids storing every prediction.
    per_pos_hits = None

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits, _ = model(batch_inputs)
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                batch_targets.reshape(-1),
            )
            hits = logits.argmax(dim=-1) == batch_targets  # (batch, seq_len) bool

            num_tokens = batch_targets.numel()
            # The criterion averages over the batch's tokens; weight by count so
            # a smaller final batch does not skew the overall mean.
            loss_sum += loss.item() * num_tokens
            total_correct += hits.sum().item()
            total_tokens += num_tokens
            num_sequences += batch_targets.size(0)

            batch_hits = hits.sum(dim=0).cpu()
            per_pos_hits = batch_hits if per_pos_hits is None else per_pos_hits + batch_hits

    if total_tokens == 0:
        raise ValueError("test_loader yielded no target tokens; cannot compute metrics")

    return {
        'loss': loss_sum / total_tokens,
        'accuracy': 100 * total_correct / total_tokens,
        'per_position_acc': (per_pos_hits.float() / num_sequences).numpy(),
    }
def visualize_predictions(
    model: nn.Module,
    task_type: str,
    seq_len: int,
    vocab_size: int,
    num_examples: int = 3,
    device: torch.device = torch.device('cpu'),
):
    """
    Print input / target / prediction for freshly generated task samples.

    Args:
        model: Module whose forward returns ``(logits, state)``, e.g.
            CopyReverseModel. It is moved to ``device`` and put in eval mode.
        task_type: 'copy' or 'reverse'
        seq_len: Sequence length passed to the task generator.
        vocab_size: Vocabulary size passed to the task generator.
        num_examples: Number of examples to generate and show.
        device: Device to use.

    Raises:
        ValueError: if ``task_type`` is not 'copy' or 'reverse'. Previously
            any other value silently ran the reverse task.
    """
    generators = {
        'copy': generate_copy_task,
        'reverse': generate_reverse_task,
    }
    if task_type not in generators:
        raise ValueError(f"Unknown task type: {task_type}")

    # Generated samples live on `device`, so the model must live there too.
    model = model.to(device)
    model.eval()
    inputs, targets = generators[task_type](seq_len, vocab_size, num_examples, device)

    with torch.no_grad():
        logits, _ = model(inputs)
        predictions = logits.argmax(dim=-1)

    print(f"\n{'='*80}")
    print(f"Sample Predictions for {task_type.upper()} Task (seq_len={seq_len}, vocab_size={vocab_size})")
    print(f"{'='*80}")

    for idx in range(num_examples):
        print(f"\nExample {idx + 1}:")
        print(f" Input: {inputs[idx].cpu().tolist()}")
        print(f" Target: {targets[idx].cpu().tolist()}")
        print(f" Predicted: {predictions[idx].cpu().tolist()}")

        # Per-example token accuracy; undefined (nan) for a zero-length sequence.
        length = len(targets[idx])
        correct = (predictions[idx] == targets[idx]).sum().item()
        acc = 100 * correct / length if length else float('nan')
        print(f" Accuracy: {acc:.2f}%")
# ============================================================================
# PHASE 4: MAIN ENTRY POINT
# ============================================================================
def _print_banner(title: str, leading: str = "\n") -> None:
    """Print ``title`` framed by two 80-character '=' rules."""
    print(f"{leading}{'=' * 80}")
    print(title)
    print('=' * 80)


def _make_task_loaders(
    task_type: str,
    args: argparse.Namespace,
    device: torch.device,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build (train, val, test) loaders for ``task_type``; only train is shuffled."""
    splits = (
        (args.train_samples, True),
        (args.val_samples, False),
        (args.test_samples, False),
    )
    train_loader, val_loader, test_loader = (
        DataLoader(
            CopyReverseDataset(task_type, args.seq_len, args.vocab_size, num_samples, device),
            batch_size=args.batch_size,
            shuffle=shuffle,
        )
        for num_samples, shuffle in splits
    )
    return train_loader, val_loader, test_loader


def _build_model(args: argparse.Namespace) -> CopyReverseModel:
    """Instantiate a fresh CopyReverseModel from the command-line hyperparameters."""
    return CopyReverseModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        hidden_dim=args.hidden_dim,
        prenorm=args.prenorm,
        dropout=args.dropout,
    )


def main(args):
    """
    Train one model per task (copy, reverse), evaluate both on held-out data,
    print a comparison and, with ``--visualize``, show sample predictions.

    Args:
        args: Parsed command-line arguments (argparse.Namespace) carrying the
            task, dataset, model and training options.
    """
    # Any non-CPU request falls back to CPU when CUDA is unavailable, as before,
    # but the fallback is now reported instead of being silent.
    if args.device != 'cpu' and not torch.cuda.is_available():
        print(f"Requested device '{args.device}' is unavailable (no CUDA); falling back to CPU.")
        device = torch.device('cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    print("\nCreating datasets...")
    print(f" Sequence length: {args.seq_len}")
    print(f" Vocabulary size: {args.vocab_size}")
    print(f" Batch size: {args.batch_size}")

    _print_banner("COPY TASK")
    copy_train_loader, copy_val_loader, copy_test_loader = _make_task_loaders('copy', args, device)

    _print_banner("REVERSE TASK")
    rev_train_loader, rev_val_loader, rev_test_loader = _make_task_loaders('reverse', args, device)

    print("\nTraining COPY task model...")
    copy_model = _build_model(args)
    train_on_task(
        copy_model,
        copy_train_loader,
        copy_val_loader,
        'copy',
        num_epochs=args.num_epochs,
        learning_rate=args.lr,
        device=device,
        verbose=True,
    )

    print("\n\nTraining REVERSE task model...")
    reverse_model = _build_model(args)
    train_on_task(
        reverse_model,
        rev_train_loader,
        rev_val_loader,
        'reverse',
        num_epochs=args.num_epochs,
        learning_rate=args.lr,
        device=device,
        verbose=True,
    )

    _print_banner("EVALUATION ON TEST SET", leading="\n\n")

    copy_eval = evaluate_on_task(copy_model, copy_test_loader, 'copy', device)
    print("\nCOPY Task Test Results:")
    print(f" Loss: {copy_eval['loss']:.4f}")
    print(f" Accuracy: {copy_eval['accuracy']:.2f}%")

    rev_eval = evaluate_on_task(reverse_model, rev_test_loader, 'reverse', device)
    print("\nREVERSE Task Test Results:")
    print(f" Loss: {rev_eval['loss']:.4f}")
    print(f" Accuracy: {rev_eval['accuracy']:.2f}%")

    _print_banner("COMPARISON")
    copy_acc = copy_eval['accuracy']
    rev_acc = rev_eval['accuracy']
    # Three-way comparison. The old two-way version reported "lower" on ties.
    if copy_acc > rev_acc:
        relation = 'higher than'
    elif copy_acc < rev_acc:
        relation = 'lower than'
    else:
        relation = 'equal to'
    print(f"Copy accuracy is {relation} reverse")
    print(f"Difference: {abs(copy_acc - rev_acc):.2f} percentage points")

    if args.visualize:
        visualize_predictions(copy_model, 'copy', args.seq_len, args.vocab_size, device=device)
        visualize_predictions(reverse_model, 'reverse', args.seq_len, args.vocab_size, device=device)

    print(f"\n{'=' * 80}")
    print("Training complete!")
    print(f"{'=' * 80}\n")
if __name__ == '__main__':
    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean such as true/false, yes/no, 1/0 or on/off."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description='Test Gamma SSM on copy and reverse synthetic tasks.'
    )

    # Task parameters
    parser.add_argument(
        '--seq-len',
        type=int,
        default=20,
        help='Sequence length for tasks (default: 20)',
    )
    parser.add_argument(
        '--vocab-size',
        type=int,
        default=8,
        help='Vocabulary size (including 0 for padding, default: 8)',
    )

    # Dataset parameters
    parser.add_argument(
        '--train-samples',
        type=int,
        default=500,
        help='Number of training samples (default: 500)',
    )
    parser.add_argument(
        '--val-samples',
        type=int,
        default=100,
        help='Number of validation samples (default: 100)',
    )
    parser.add_argument(
        '--test-samples',
        type=int,
        default=100,
        help='Number of test samples (default: 100)',
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=32,
        help='Batch size (default: 32)',
    )

    # Model parameters
    parser.add_argument(
        '--d-model',
        type=int,
        default=32,
        help='Model dimension (default: 32)',
    )
    parser.add_argument(
        '--hidden-dim',
        type=int,
        default=128,
        help='SSM hidden dimension (default: 128)',
    )
    # type=bool would turn any non-empty string (including "False") into True,
    # so parse the value explicitly. A bare `--prenorm` means True.
    parser.add_argument(
        '--prenorm',
        type=_str2bool,
        nargs='?',
        const=True,
        default=True,
        help='Use prenorm in SSM block; accepts true/false (default: True)',
    )
    parser.add_argument(
        '--dropout',
        type=float,
        default=0.0,
        help='Dropout rate (default: 0.0)',
    )

    # Training parameters
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=500,
        help='Number of training epochs (default: 500)',
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        help='Learning rate (default: 0.001)',
    )

    # Other parameters
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cpu', 'cuda'],
        help='Device to use (default: cuda)',
    )
    parser.add_argument(
        '--visualize',
        action='store_true',
        help='Visualize predictions on sample data',
    )

    args = parser.parse_args()

    # Reject degenerate settings up front. Otherwise they fail deep inside
    # torch.randint (vocab-size < 2) or produce empty loaders.
    if args.vocab_size < 2:
        parser.error("--vocab-size must be >= 2 (token 0 is reserved for padding)")
    for name in ('seq_len', 'batch_size', 'train_samples', 'val_samples', 'test_samples'):
        if getattr(args, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be a positive integer")

    main(args)