"""Configuration dataclasses for the compact AI model, its interleaved
thinking mechanism, training, and the API server."""

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelConfig:
    """Configuration for the compact AI model."""

    model_size: str = "small"  # one of: "tiny", "small", "medium"
    dim: int = 512
    layers: int = 12
    heads: int = 8
    vocab_size: int = 32000
    max_seq_len: int = 4096
    dropout: float = 0.1
    quantization: Optional[str] = None  # None, "4bit", or "8bit"
    flash_attention: bool = True
    rope_scaling: bool = True


@dataclass
class InterleavedThinkingConfig:
    """Configuration for the interleaved thinking mechanism."""

    max_reasoning_paths: int = 3
    reasoning_depth: int = 4  # base depth per path; adaptive when dynamic_depth is True
    early_stop_threshold: float = 0.85
    token_budget: int = 512
    memory_compression: bool = True
    dynamic_depth: bool = True
    confidence_weight: float = 0.7
    diversity_weight: float = 0.3
    # Parameters for enhanced thinking
    uncertainty_estimation: bool = True
    hierarchical_paths: bool = True
    num_hierarchy_levels: int = 3
    attention_fusion: bool = True
    task_specific_thresholds: bool = True
    path_specialization: bool = True
    adaptive_compression: bool = True
    visualization_enabled: bool = True


@dataclass
class TrainingConfig:
    """Configuration for training."""

    learning_rate: float = 1e-4
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    num_epochs: int = 10
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500
    mixed_precision: bool = True
    gradient_checkpointing: bool = False


@dataclass
class APIConfig:
    """Configuration for the API server."""

    host: str = "0.0.0.0"
    port: int = 8000
    workers: int = 1
    log_level: str = "info"
    cache_size: int = 1000
    max_concurrent_requests: int = 10
    request_timeout: int = 300
    enable_cors: bool = True


@dataclass
class Config:
    """Main configuration class."""

    model: ModelConfig
    thinking: InterleavedThinkingConfig
    training: TrainingConfig
    api: APIConfig

    @classmethod
    def get_balanced_config(cls) -> "Config":
        """Get a balanced configuration for production use."""
        return cls(
            model=ModelConfig(),
            thinking=InterleavedThinkingConfig(),
            training=TrainingConfig(),
            api=APIConfig(),
        )
    @classmethod
    def get_tiny_config(cls) -> "Config":
        """Get a tiny configuration for quick testing."""
        return cls(
            model=ModelConfig(
                model_size="tiny",
                dim=256,
                layers=8,
                heads=8,
            ),
            thinking=InterleavedThinkingConfig(
                max_reasoning_paths=2,
                reasoning_depth=3,
                token_budget=256,
                hierarchical_paths=False,
                attention_fusion=False,
                path_specialization=False,
            ),
            training=TrainingConfig(
                batch_size=4,
                num_epochs=5,
            ),
            api=APIConfig(),
        )
    @classmethod
    def get_large_config(cls) -> "Config":
        """Get a larger configuration for better performance."""
        return cls(
            model=ModelConfig(
                model_size="medium",
                dim=768,
                layers=16,
                heads=12,
            ),
            thinking=InterleavedThinkingConfig(
                max_reasoning_paths=4,
                reasoning_depth=6,
                token_budget=1024,
                hierarchical_paths=True,
                attention_fusion=True,
                path_specialization=True,
                adaptive_compression=True,
            ),
            training=TrainingConfig(
                batch_size=16,
                num_epochs=20,
            ),
            api=APIConfig(workers=4),
        )


def load_config_from_dict(config_dict: Dict[str, Any]) -> Config:
    """Load a configuration from a (possibly partial) dictionary."""
    model_config = ModelConfig(**config_dict.get("model", {}))
    thinking_config = InterleavedThinkingConfig(**config_dict.get("thinking", {}))
    training_config = TrainingConfig(**config_dict.get("training", {}))
    api_config = APIConfig(**config_dict.get("api", {}))
    return Config(
        model=model_config,
        thinking=thinking_config,
        training=training_config,
        api=api_config,
    )


def save_config_to_dict(config: Config) -> Dict[str, Any]:
    """Save a configuration to a nested dictionary."""
    # asdict recursively copies the dataclass fields, so mutating the
    # returned dict cannot accidentally mutate the live Config instance
    # (returning the raw __dict__ of each sub-config would alias it).
    return asdict(config)
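

# A minimal usage sketch (not part of the original module): builds the
# balanced config, round-trips it through a plain dict, and shows that
# partial override dicts work. The override values below are illustrative.
if __name__ == "__main__":
    # Build the default production configuration via the factory classmethod.
    config = Config.get_balanced_config()

    # Round-trip through a plain dict, e.g. before writing to JSON/YAML.
    config_dict = save_config_to_dict(config)
    restored = load_config_from_dict(config_dict)
    assert restored == config  # dataclasses compare field-by-field

    # Partial dicts also work: omitted sections and fields keep their defaults.
    custom = load_config_from_dict({"model": {"dim": 1024}, "api": {"port": 9000}})
    print(custom.model.dim, custom.api.port)  # 1024 9000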