from dataclasses import dataclass


@dataclass
class ModelArgs:
    """Hyperparameter bundle for the multi-model / multi-modal training setup.

    Groups architecture sizes for the shared backbone, the two sub-models
    ('tranny' and 'claudeson_clone'), the fusion stack, and the optimizer
    settings into a single flat configuration object.
    """

    # --- Shared backbone ---
    max_seq_len: int = 256      # maximum sequence length (sized for combined multi-model input)
    vocab_size: int = 50000     # vocabulary shared across all models; adjust if they differ
    dim: int = 1024             # hidden width of the backbone
    n_layers: int = 16          # transformer depth of the backbone
    n_heads: int = 16           # attention heads (must divide dim evenly — TODO confirm at use site)
    dropout: float = 0.1        # dropout probability applied throughout

    # --- 'tranny' sub-model ---
    tranny_dim: int = 768       # hidden width of the 'tranny' model
    tranny_n_layers: int = 12   # depth of the 'tranny' model

    # --- 'claudeson_clone' sub-model ---
    claudeson_dim: int = 512    # hidden width of the 'claudeson_clone' model
    claudeson_n_layers: int = 8 # depth of the 'claudeson_clone' model

    # --- Multi-modal fusion ---
    fusion_dim: int = 1536      # width of the fused joint representation
    fusion_n_layers: int = 4    # number of fusion layers

    # --- Optimization ---
    learning_rate: float = 1e-4 # base learning rate
    weight_decay: float = 0.01  # L2 regularization strength
    batch_size: int = 8         # kept small: fused model has a large memory footprint
    num_epochs: int = 10        # total training epochs