from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional
|
|
|
|
@dataclass
class ModelConfig:
    """Architecture and runtime options for the compact AI model.

    Field order matters: the generated ``__init__`` accepts these fields
    positionally in exactly the order declared below.
    """

    # Size label. The presets in this module use "small" (this default),
    # "tiny" and "medium".
    model_size: str = "small"

    # Network shape. NOTE(review): presumably ``dim`` must be divisible by
    # ``heads`` (it is for every preset here) -- confirm in the model code.
    dim: int = 512
    layers: int = 12
    heads: int = 8

    # Vocabulary and sequence limits.
    vocab_size: int = 32000
    max_seq_len: int = 4096

    dropout: float = 0.1

    # Name of a quantization scheme; None presumably means "not quantized".
    quantization: Optional[str] = None

    # Feature switches, both enabled by default.
    flash_attention: bool = True
    rope_scaling: bool = True
|
|
|
|
@dataclass
class InterleavedThinkingConfig:
    """Settings for the interleaved thinking mechanism.

    Fields fall into two groups: reasoning-path shape, budget and scoring,
    followed by optional feature switches. Declaration order is preserved
    because the generated ``__init__`` accepts fields positionally.
    """

    # --- Reasoning-path shape, budget and scoring --------------------------
    max_reasoning_paths: int = 3
    reasoning_depth: int = 4
    early_stop_threshold: float = 0.85
    token_budget: int = 512
    memory_compression: bool = True
    dynamic_depth: bool = True
    # The two weights default to 0.7 + 0.3 == 1.0. NOTE(review): presumably
    # they are blended as a convex combination -- confirm at the use site.
    confidence_weight: float = 0.7
    diversity_weight: float = 0.3

    # --- Optional features (the tiny preset disables several of these) -----
    uncertainty_estimation: bool = True
    hierarchical_paths: bool = True
    # Presumably only consulted when ``hierarchical_paths`` is True.
    num_hierarchy_levels: int = 3
    attention_fusion: bool = True
    task_specific_thresholds: bool = True
    path_specialization: bool = True
    adaptive_compression: bool = True
    visualization_enabled: bool = True
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyper-parameters and bookkeeping intervals for training.

    Fields keep their declared order, so positional construction through the
    generated ``__init__`` is unchanged.
    """

    # Optimisation and schedule.
    learning_rate: float = 1e-4
    # NOTE(review): presumably the effective batch per optimiser step is
    # batch_size * gradient_accumulation_steps -- confirm in the trainer.
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    num_epochs: int = 10
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    # Gradient-clipping threshold (presumably a global norm -- confirm).
    max_grad_norm: float = 1.0

    # Logging / checkpoint / evaluation cadence. The unit (optimiser steps
    # vs. batches) is not visible here -- confirm where these are read.
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500

    # Precision / memory switches: mixed precision on, checkpointing off.
    mixed_precision: bool = True
    gradient_checkpointing: bool = False
|
|
|
|
@dataclass
class APIConfig:
    """Network, logging and concurrency settings for the API server."""

    # Bind address and port. "0.0.0.0" conventionally listens on every
    # interface; use "127.0.0.1" to accept local connections only.
    host: str = "0.0.0.0"
    port: int = 8000

    # The large preset raises the worker count to 4.
    workers: int = 1
    log_level: str = "info"

    # Capacity limits. NOTE(review): units for cache_size (entries?) and
    # request_timeout (seconds?) are not visible here -- confirm.
    cache_size: int = 1000
    max_concurrent_requests: int = 10
    request_timeout: int = 300

    enable_cors: bool = True
|
|
|
|
@dataclass
class Config:
    """Top-level configuration bundling the model, thinking, training and API
    sub-configs.

    The preset constructors ``get_balanced_config``, ``get_tiny_config`` and
    ``get_large_config`` each build brand-new sub-config objects on every
    call, so callers may mutate the returned instance freely.
    """

    model: ModelConfig
    thinking: InterleavedThinkingConfig
    training: TrainingConfig
    api: APIConfig

    # Every preset below must be a classmethod: it is called on the class
    # (``Config.get_tiny_config()``), and without the decorator that call
    # raises TypeError because no ``cls`` argument is supplied.

    @classmethod
    def get_balanced_config(cls) -> "Config":
        """Get a balanced configuration for production use.

        Every sub-config uses its dataclass defaults.
        """
        return cls(
            model=ModelConfig(),
            thinking=InterleavedThinkingConfig(),
            training=TrainingConfig(),
            api=APIConfig(),
        )

    @classmethod
    def get_tiny_config(cls) -> "Config":
        """Get a tiny configuration for quick testing.

        Shrinks the model and reasoning budget, disables hierarchical paths,
        attention fusion and path specialization, and trains with a smaller
        batch for fewer epochs. The API settings use their defaults.
        """
        return cls(
            model=ModelConfig(
                model_size="tiny",
                dim=256,
                layers=8,
                heads=8,
            ),
            thinking=InterleavedThinkingConfig(
                max_reasoning_paths=2,
                reasoning_depth=3,
                token_budget=256,
                hierarchical_paths=False,
                attention_fusion=False,
                path_specialization=False,
            ),
            training=TrainingConfig(
                batch_size=4,
                num_epochs=5,
            ),
            api=APIConfig(),
        )

    @classmethod
    def get_large_config(cls) -> "Config":
        """Get a larger configuration for better performance.

        Widens and deepens the model, raises the reasoning budget, trains
        with a bigger batch for more epochs, and runs 4 API workers.
        """
        return cls(
            # NOTE(review): labelled "medium" even though this preset is the
            # "large" one -- confirm the intended label.
            model=ModelConfig(
                model_size="medium",
                dim=768,
                layers=16,
                heads=12,
            ),
            # The boolean flags below match the defaults; they are spelled
            # out so this preset stays explicit if the defaults change.
            thinking=InterleavedThinkingConfig(
                max_reasoning_paths=4,
                reasoning_depth=6,
                token_budget=1024,
                hierarchical_paths=True,
                attention_fusion=True,
                path_specialization=True,
                adaptive_compression=True,
            ),
            training=TrainingConfig(
                batch_size=16,
                num_epochs=20,
            ),
            api=APIConfig(workers=4),
        )
|
|
|
|
def load_config_from_dict(config_dict: Dict[str, Any]) -> Config:
    """Load configuration from a dictionary.

    Args:
        config_dict: Mapping with optional ``"model"``, ``"thinking"``,
            ``"training"`` and ``"api"`` sections. Each section maps field
            names to values for the matching sub-config dataclass. A missing
            section, or one whose value is ``None``, falls back to that
            sub-config's defaults. Any other top-level keys are ignored.

    Returns:
        A new :class:`Config` built from the sections.

    Raises:
        TypeError: If a section contains a key that is not a field of the
            corresponding dataclass (raised by the dataclass constructor).
    """

    def section(name: str) -> Dict[str, Any]:
        # A key that is present but null (e.g. an empty ``model:`` entry in a
        # YAML file) arrives as None; treat it like a missing section instead
        # of crashing on ``**None``.
        value = config_dict.get(name)
        return {} if value is None else value

    return Config(
        model=ModelConfig(**section("model")),
        thinking=InterleavedThinkingConfig(**section("thinking")),
        training=TrainingConfig(**section("training")),
        api=APIConfig(**section("api")),
    )
|
|
|
|
def save_config_to_dict(config: Config) -> Dict[str, Any]:
    """Save configuration to a dictionary.

    Args:
        config: The configuration to serialise.

    Returns:
        A nested plain dict with one entry per sub-config (``"model"``,
        ``"thinking"``, ``"training"``, ``"api"``). It uses the same section
        keys that ``load_config_from_dict`` reads.

    ``dataclasses.asdict`` builds new dicts (recursively copying values), so
    the caller may edit the result without touching ``config``. Returning the
    instances' ``__dict__`` would hand out the live attribute storage, and
    any edit would silently mutate the configuration.
    """
    return {
        "model": asdict(config.model),
        "thinking": asdict(config.thinking),
        "training": asdict(config.training),
        "api": asdict(config.api),
    }