from dataclasses import dataclass
@dataclass
class ModelArgs:
    """Hyperparameter container for the multi-model setup.

    Every field has a default, so ``ModelArgs()`` yields the baseline
    configuration; individual values can be overridden by keyword, e.g.
    ``ModelArgs(dim=512, n_heads=8)``.
    """

    # General parameters
    max_seq_len: int = 256  # Maximum sequence length (sized for longer multi-model inputs)
    vocab_size: int = 50000  # Shared vocabulary size for all models (adjust if different)
    dim: int = 1024  # Main model dimensionality
    n_layers: int = 16  # Number of layers in the main model
    n_heads: int = 16  # Number of attention heads
    dropout: float = 0.1  # Dropout probability

    # Model-specific parameters
    # NOTE(review): the `tranny_*` prefix is kept only so existing callers keep
    # working; consider a coordinated rename to a neutral name.
    tranny_dim: int = 768  # Dimensionality of the first sub-model
    tranny_n_layers: int = 12  # Number of layers in the first sub-model
    claudeson_dim: int = 512  # Dimensionality of the 'claudeson_clone' model
    claudeson_n_layers: int = 8  # Number of layers in the 'claudeson_clone' model

    # Multi-modal fusion parameters
    fusion_dim: int = 1536  # Dimensionality of the fused multi-modal representation
    fusion_n_layers: int = 4  # Number of fusion layers

    # Training parameters
    learning_rate: float = 1e-4  # Learning rate
    weight_decay: float = 0.01  # Weight decay (L2 regularization)
    batch_size: int = 8  # Kept small to limit memory footprint
    num_epochs: int = 10  # Number of training epochs
    # ... (other parameters as needed)