Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """ | |
| Example demonstrating Gamma SSM's ability to learn copy and reverse tasks. | |
| Copy Task: Given [x₁, x₂, ..., xₗ, ∅, ∅, ..., ∅], predict [∅, ∅, ..., ∅, x₁, x₂, ..., xₗ] | |
| This tests the model's ability to hold information in memory and recall it after a delay. | |
| Reverse Task: Given [x₁, x₂, ..., xₗ], predict [xₗ, xₗ₋₁, ..., x₁] | |
| This tests the model's ability to process sequences bidirectionally. | |
| These are classical synthetic benchmarks for evaluating sequence models. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| from pathlib import Path | |
| import argparse | |
| from typing import Tuple, Dict, List | |
| import sys | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from gamma_space_model.modules.block import GammaSingleBlock | |
| # ============================================================================ | |
| # PHASE 1: TASK DATA GENERATION | |
| # ============================================================================ | |
def generate_copy_task(
    seq_len: int,
    vocab_size: int,
    batch_size: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build one batch of the delayed-copy task.

    Each input row holds ``seq_len`` random content tokens drawn uniformly from
    ``[1, vocab_size)`` followed by ``seq_len`` zeros; the matching target row
    holds ``seq_len`` zeros followed by the same content tokens. Token id 0 is
    the padding/blank symbol and is never drawn as content.

    Args:
        seq_len: Number of content tokens per example.
        vocab_size: Exclusive upper bound on token ids (0 is reserved for padding).
        batch_size: Number of examples to generate.
        device: Device on which the tensors are allocated.

    Returns:
        ``(inputs, targets)``, each a ``(batch_size, 2 * seq_len)`` int64 tensor.
    """
    payload = torch.randint(1, vocab_size, (batch_size, seq_len), device=device)
    blank = torch.zeros_like(payload)
    # The input shows the payload and then goes silent; the target stays
    # silent while the payload is shown and then replays it.
    inputs = torch.cat((payload, blank), dim=1)
    targets = torch.cat((blank, payload), dim=1)
    return inputs, targets
def generate_reverse_task(
    seq_len: int,
    vocab_size: int,
    batch_size: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build one batch of the sequence-reversal task.

    Each input row holds ``seq_len`` random tokens drawn uniformly from
    ``[1, vocab_size)``; the target row is the same tokens in reverse order,
    so ``targets[:, t] == inputs[:, seq_len - 1 - t]``.

    NOTE(review): inputs and targets are position-aligned (there is no delay
    segment as in the copy task). If the model is strictly causal it cannot
    have seen the token it must emit at position t for t < (seq_len - 1) / 2;
    confirm whether GammaSingleBlock is bidirectional before reading reverse
    accuracy as a memory benchmark.

    Args:
        seq_len: Number of tokens per example.
        vocab_size: Exclusive upper bound on token ids (tokens are 1..vocab_size-1).
        batch_size: Number of examples to generate.
        device: Device on which the tensors are allocated.

    Returns:
        ``(inputs, targets)``, each a ``(batch_size, seq_len)`` int64 tensor.
    """
    source = torch.randint(1, vocab_size, (batch_size, seq_len), device=device)
    # Mirror every row along the time axis.
    mirrored = source.flip(1)
    return source, mirrored
class CopyReverseDataset(Dataset):
    """In-memory dataset of pre-generated copy or reverse examples.

    All examples are generated once at construction time, in chunks of at most
    ``_CHUNK_SIZE`` rows, and stored as two stacked tensors on ``device``.
    Indexing returns an ``(input, target)`` pair of 1-D token tensors.
    """

    # Maximum number of examples produced per call to the generator function.
    _CHUNK_SIZE = 256

    def __init__(
        self,
        task_type: str,
        seq_len: int,
        vocab_size: int,
        num_samples: int,
        device: torch.device,
    ):
        """
        Args:
            task_type: 'copy' or 'reverse'.
            seq_len: Number of content tokens per example.
            vocab_size: Exclusive upper bound on token ids (0 is padding).
            num_samples: Number of examples to generate; 0 yields an empty
                (but correctly shaped) dataset.
            device: Device the examples are generated and stored on.

        Raises:
            ValueError: If ``task_type`` is unknown or ``num_samples`` is negative.
        """
        if task_type == 'copy':
            gen_fn = generate_copy_task
        elif task_type == 'reverse':
            gen_fn = generate_reverse_task
        else:
            raise ValueError(f"Unknown task type: {task_type}")
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        self.task_type = task_type
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.device = device

        # Sizes of the generator calls; they sum to num_samples. A single
        # zero-sized call when num_samples == 0 keeps the stored tensors
        # correctly shaped (e.g. (0, 2 * seq_len) for copy) instead of failing.
        chunk_sizes = [
            min(self._CHUNK_SIZE, num_samples - start)
            for start in range(0, num_samples, self._CHUNK_SIZE)
        ] or [0]
        chunks = [gen_fn(seq_len, vocab_size, size, device) for size in chunk_sizes]

        self.inputs = torch.cat([inputs for inputs, _ in chunks], dim=0)
        self.targets = torch.cat([targets for _, targets in chunks], dim=0)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]
| # ============================================================================ | |
| # PHASE 2: MODEL ARCHITECTURE | |
| # ============================================================================ | |
class CopyReverseModel(nn.Module):
    """
    Token-level sequence model for the copy/reverse tasks.

    Architecture:
        input (token indices)
        → nn.Embedding lookup (learned dense vectors, not one-hot)
        → GammaSingleBlock (SSM)
        → MLP head: Linear → ReLU → Dropout → Linear
        → output logits (vocab_size)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 32,
        hidden_dim: int = 128,
        prenorm: bool = True,
        dropout: float = 0.0,
        delta_t: float = 0.005,
    ):
        """
        Args:
            vocab_size: Number of token ids; sets both the embedding table size
                and the number of output logits.
            d_model: Embedding width; also the feature width the MLP head
                expects from the SSM block's output.
            hidden_dim: SSM hidden state dimension, forwarded to GammaSingleBlock.
            prenorm: Whether to use prenorm in the SSM block.
            dropout: Dropout rate used by the SSM block and the output head.
            delta_t: Step size forwarded to GammaSingleBlock (default 0.005).
                NOTE(review): presumably the SSM discretization step — confirm
                against GammaSingleBlock.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.delta_t = delta_t

        # Learned lookup table mapping each token id to a d_model vector.
        self.embedding = nn.Embedding(vocab_size, d_model)

        # SSM block for sequence processing.
        self.ssm_block = GammaSingleBlock(
            d_model=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            prenorm=prenorm,
            dropout=dropout,
        )

        # Two-layer MLP mapping each d_model feature vector to vocab logits.
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, vocab_size),
        )

    def forward(
        self,
        token_indices: torch.Tensor,
        state: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embed tokens, run the SSM block, and project to vocabulary logits.

        Args:
            token_indices: (batch, seq_len) tensor of token indices.
            state: Optional initial state, forwarded unchanged to the SSM block.
            mask: Optional padding mask, forwarded unchanged to the SSM block.

        Returns:
            logits: (batch, seq_len, vocab_size) output logits.
            final_state: Final state returned by GammaSingleBlock
                (NOTE(review): presumably (batch, hidden_dim) — confirm).
        """
        embedded = self.embedding(token_indices)  # (batch, seq_len, d_model)
        ssm_output, final_state = self.ssm_block(embedded, state=state, mask=mask)
        logits = self.output_proj(ssm_output)  # (batch, seq_len, vocab_size)
        return logits, final_state
| # ============================================================================ | |
| # PHASE 3: TRAINING & EVALUATION | |
| # ============================================================================ | |
def train_on_task(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    task_type: str,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    device: torch.device = torch.device('cpu'),
    verbose: bool = True,
) -> Dict[str, List[float]]:
    """
    Train ``model`` with Adam and token-level cross-entropy, validating each epoch.

    Reported losses are the mean cross-entropy per target token over the whole
    epoch: batches are weighted by their token count, so a short final batch
    does not skew the average. Accuracies are the percentage of target
    positions (padding positions included) whose argmax prediction matches.
    An empty loader yields ``nan`` for that split's loss and accuracy.

    Args:
        model: Module whose forward returns ``(logits, state)`` with logits of
            shape (batch, seq_len, vocab_size); moved to ``device``.
        train_loader: Training data loader yielding (inputs, targets).
        val_loader: Validation data loader yielding (inputs, targets).
        task_type: 'copy' or 'reverse' (only used in log lines).
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        device: Device to train on.
        verbose: Whether to print metrics every 10 epochs.

    Returns:
        Dictionary with 'train_loss', 'val_loss', 'train_acc', 'val_acc' lists,
        one entry per epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    history: Dict[str, List[float]] = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
    }

    def _forward_batch(batch_inputs, batch_targets):
        """Return (mean loss tensor, #correct positions, #target tokens) for one batch."""
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        logits, _ = model(batch_inputs)
        # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for CrossEntropyLoss.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            batch_targets.reshape(-1),
        )
        correct = (logits.argmax(dim=-1) == batch_targets).sum().item()
        return loss, correct, batch_targets.numel()

    def _summarize(loss_sum, correct, total):
        """Turn accumulated sums into (mean loss per token, accuracy in percent)."""
        if total == 0:
            return float('nan'), float('nan')
        return loss_sum / total, 100 * correct / total

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        loss_sum, correct_sum, token_sum = 0.0, 0, 0
        for batch_inputs, batch_targets in train_loader:
            loss, correct, n_tokens = _forward_batch(batch_inputs, batch_targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # ``loss`` is a per-token mean; rescale it to a sum so batches of
            # different sizes contribute in proportion to their token count.
            loss_sum += loss.item() * n_tokens
            correct_sum += correct
            token_sum += n_tokens
        train_loss, train_acc = _summarize(loss_sum, correct_sum, token_sum)

        # Validation phase
        model.eval()
        loss_sum, correct_sum, token_sum = 0.0, 0, 0
        with torch.no_grad():
            for batch_inputs, batch_targets in val_loader:
                loss, correct, n_tokens = _forward_batch(batch_inputs, batch_targets)
                loss_sum += loss.item() * n_tokens
                correct_sum += correct
                token_sum += n_tokens
        val_loss, val_acc = _summarize(loss_sum, correct_sum, token_sum)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        if verbose and (epoch + 1) % 10 == 0:
            print(
                f"[{task_type.upper()}] Epoch {epoch+1:3d}/{num_epochs} | "
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:6.2f}% | "
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:6.2f}%"
            )

    return history
def evaluate_on_task(
    model: nn.Module,
    test_loader: DataLoader,
    task_type: str,
    device: torch.device = torch.device('cpu'),
) -> Dict[str, float]:
    """
    Evaluate ``model`` on a test set without updating it.

    Args:
        model: Module whose forward returns ``(logits, state)`` with logits of
            shape (batch, seq_len, vocab_size); moved to ``device`` and put in
            eval mode.
        test_loader: Test data loader yielding (inputs, targets).
        task_type: 'copy' or 'reverse'; currently unused, kept for API symmetry.
        device: Device to evaluate on.

    Returns:
        Dictionary with:
            'loss': mean cross-entropy per target token, weighted by token count
                across batches (``nan`` if the loader yields no tokens);
            'accuracy': percentage of target positions predicted exactly
                (``nan`` if the loader yields no tokens);
            'per_position_acc': numpy array giving, for each sequence position,
                the fraction (0-1, not percent) of examples predicted correctly
                there (empty if the loader is empty).
    """
    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()

    loss_sum = 0.0  # sum over tokens: batch mean loss x batch token count
    hit_masks = []  # per-batch boolean "prediction == target" masks, on CPU

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits, _ = model(batch_inputs)
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                batch_targets.reshape(-1),
            )
            loss_sum += loss.item() * batch_targets.numel()
            hit_masks.append((logits.argmax(dim=-1) == batch_targets).cpu())

    if hit_masks:
        hits = torch.cat(hit_masks, dim=0)  # (num_examples, seq_len)
    else:
        hits = torch.zeros((0, 0), dtype=torch.bool)

    total_tokens = hits.numel()
    if total_tokens == 0:
        avg_loss = accuracy = float('nan')
    else:
        avg_loss = loss_sum / total_tokens
        accuracy = 100 * hits.sum().item() / total_tokens

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'per_position_acc': hits.float().mean(dim=0).numpy(),
    }
def visualize_predictions(
    model: nn.Module,
    task_type: str,
    seq_len: int,
    vocab_size: int,
    num_examples: int = 3,
    device: torch.device = torch.device('cpu'),
):
    """
    Print freshly generated examples next to the model's greedy predictions.

    New examples are sampled for ``task_type`` on ``device``. For each one,
    the input, target, argmax prediction and per-example token accuracy are
    printed; accuracy counts all positions, padding included.

    Args:
        model: Module whose forward returns ``(logits, state)``; moved to
            ``device`` and put in eval mode.
        task_type: 'copy' or 'reverse'.
        seq_len: Sequence length passed to the task generator.
        vocab_size: Vocabulary size passed to the task generator.
        num_examples: Number of examples to show.
        device: Device to generate data and run the model on.

    Raises:
        ValueError: If ``task_type`` is not 'copy' or 'reverse'.
    """
    # Validate before doing any work so a typo fails loudly instead of
    # silently running a different task.
    if task_type == 'copy':
        gen_fn = generate_copy_task
    elif task_type == 'reverse':
        gen_fn = generate_reverse_task
    else:
        raise ValueError(f"Unknown task type: {task_type}")

    # Keep the model on the same device as the generated examples.
    model = model.to(device)
    model.eval()

    inputs, targets = gen_fn(seq_len, vocab_size, num_examples, device)
    with torch.no_grad():
        logits, _ = model(inputs)
    predictions = logits.argmax(dim=-1)

    print(f"\n{'='*80}")
    print(f"Sample Predictions for {task_type.upper()} Task (seq_len={seq_len}, vocab_size={vocab_size})")
    print(f"{'='*80}")

    for idx in range(num_examples):
        print(f"\nExample {idx + 1}:")
        print(f" Input: {inputs[idx].cpu().tolist()}")
        print(f" Target: {targets[idx].cpu().tolist()}")
        print(f" Predicted: {predictions[idx].cpu().tolist()}")

        # Per-example accuracy; guarded for zero-length sequences.
        n_positions = targets[idx].numel()
        correct = (predictions[idx] == targets[idx]).sum().item()
        acc = 100 * correct / n_positions if n_positions else float('nan')
        print(f" Accuracy: {acc:.2f}%")
| # ============================================================================ | |
| # PHASE 4: MAIN ENTRY POINT | |
| # ============================================================================ | |
def _print_banner(title: str, leading_newlines: int = 1) -> None:
    """Print ``title`` between two 80-character '=' rules, after blank line(s)."""
    rule = '=' * 80
    print('\n' * leading_newlines + rule)
    print(title)
    print(rule)


def _build_task_loaders(task_type: str, args: argparse.Namespace, device: torch.device):
    """Return (train, val, test) DataLoaders for ``task_type``; only train is shuffled."""
    loaders = []
    for num_samples, shuffle in (
        (args.train_samples, True),
        (args.val_samples, False),
        (args.test_samples, False),
    ):
        dataset = CopyReverseDataset(
            task_type, args.seq_len, args.vocab_size, num_samples, device
        )
        loaders.append(DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle))
    return tuple(loaders)


def _train_task_model(
    task_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    args: argparse.Namespace,
    device: torch.device,
) -> nn.Module:
    """Build a fresh CopyReverseModel from ``args`` and train it on one task."""
    model = CopyReverseModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        hidden_dim=args.hidden_dim,
        prenorm=args.prenorm,
        dropout=args.dropout,
    )
    train_on_task(
        model,
        train_loader,
        val_loader,
        task_type,
        num_epochs=args.num_epochs,
        learning_rate=args.lr,
        device=device,
        verbose=True,
    )
    return model


def _print_test_results(label: str, metrics: dict) -> None:
    """Print the loss/accuracy summary returned by evaluate_on_task."""
    print(f"\n{label} Task Test Results:")
    print(f" Loss: {metrics['loss']:.4f}")
    print(f" Accuracy: {metrics['accuracy']:.2f}%")


def main(args):
    """Train one model per task (copy, reverse), evaluate both and compare them.

    Args:
        args: Parsed command-line namespace produced by the argument parser
            defined under ``if __name__ == '__main__'``.
    """
    # Setup device: fall back to CPU when CUDA is requested but unavailable,
    # and say so instead of switching silently.
    cuda_available = torch.cuda.is_available()
    if args.device != 'cpu' and not cuda_available:
        print(f"Requested device '{args.device}' is not available; falling back to CPU.")
    device = torch.device(args.device if cuda_available or args.device == 'cpu' else 'cpu')
    print(f"Using device: {device}")

    print("\nCreating datasets...")
    print(f" Sequence length: {args.seq_len}")
    print(f" Vocabulary size: {args.vocab_size}")
    print(f" Batch size: {args.batch_size}")

    _print_banner("COPY TASK")
    copy_train_loader, copy_val_loader, copy_test_loader = _build_task_loaders(
        'copy', args, device
    )

    _print_banner("REVERSE TASK")
    rev_train_loader, rev_val_loader, rev_test_loader = _build_task_loaders(
        'reverse', args, device
    )

    print("\nTraining COPY task model...")
    copy_model = _train_task_model('copy', copy_train_loader, copy_val_loader, args, device)

    print("\n\nTraining REVERSE task model...")
    reverse_model = _train_task_model('reverse', rev_train_loader, rev_val_loader, args, device)

    _print_banner("EVALUATION ON TEST SET", leading_newlines=2)
    copy_eval = evaluate_on_task(copy_model, copy_test_loader, 'copy', device)
    _print_test_results("COPY", copy_eval)
    rev_eval = evaluate_on_task(reverse_model, rev_test_loader, 'reverse', device)
    _print_test_results("REVERSE", rev_eval)

    _print_banner("COMPARISON")
    copy_acc, rev_acc = copy_eval['accuracy'], rev_eval['accuracy']
    if copy_acc > rev_acc:
        relation = 'higher than'
    elif copy_acc < rev_acc:
        relation = 'lower than'
    elif copy_acc == rev_acc:
        relation = 'equal to'
    else:
        # At least one accuracy is NaN (e.g. an empty test set).
        relation = 'not comparable with'
    print(f"Copy accuracy is {relation} reverse")
    print(f"Difference: {abs(copy_acc - rev_acc):.2f} percentage points")

    if args.visualize:
        visualize_predictions(copy_model, 'copy', args.seq_len, args.vocab_size, device=device)
        visualize_predictions(reverse_model, 'reverse', args.seq_len, args.vocab_size, device=device)

    rule = '=' * 80
    print(f"\n{rule}")
    print("Training complete!")
    print(f"{rule}\n")
def str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Test Gamma SSM on copy and reverse synthetic tasks.'
    )

    # Task parameters
    parser.add_argument(
        '--seq-len',
        type=int,
        default=20,
        help='Sequence length for tasks (default: 20)',
    )
    parser.add_argument(
        '--vocab-size',
        type=int,
        default=8,
        help='Vocabulary size (including 0 for padding, default: 8)',
    )

    # Dataset parameters
    parser.add_argument(
        '--train-samples',
        type=int,
        default=500,
        help='Number of training samples (default: 500)',
    )
    parser.add_argument(
        '--val-samples',
        type=int,
        default=100,
        help='Number of validation samples (default: 100)',
    )
    parser.add_argument(
        '--test-samples',
        type=int,
        default=100,
        help='Number of test samples (default: 100)',
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=32,
        help='Batch size (default: 32)',
    )

    # Model parameters
    parser.add_argument(
        '--d-model',
        type=int,
        default=32,
        help='Model dimension (default: 32)',
    )
    parser.add_argument(
        '--hidden-dim',
        type=int,
        default=128,
        help='SSM hidden dimension (default: 128)',
    )
    parser.add_argument(
        '--prenorm',
        type=str_to_bool,
        nargs='?',
        const=True,  # a bare "--prenorm" means True
        default=True,
        help='Use prenorm in SSM block: true/false (default: True)',
    )
    parser.add_argument(
        '--dropout',
        type=float,
        default=0.0,
        help='Dropout rate (default: 0.0)',
    )

    # Training parameters
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=500,
        help='Number of training epochs (default: 500)',
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        help='Learning rate (default: 0.001)',
    )

    # Other parameters
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cpu', 'cuda'],
        help='Device to use (default: cuda)',
    )
    parser.add_argument(
        '--visualize',
        action='store_true',
        help='Visualize predictions on sample data',
    )

    args = parser.parse_args()
    main(args)