# src/config.py
import yaml
from typing import Dict, Any, Optional
from dataclasses import dataclass


@dataclass
class Config:
    """Single config class - no inheritance needed"""

    # Experiment
    experiment_name: str
    seed: int = 42

    # Dataset
    dataset_name: str = "cifar10"
    image_size: int = 64
    num_channels: Optional[int] = None  # If None, auto-determined based on dataset/latents
    data_path: str = "./data"
    download: bool = True
    use_latents: bool = False  # Whether to use VAE latents instead of raw images
    latent_data_path: Optional[str] = None  # Path to latent dataset JSON
    split_strategy: str = "global"  # "global" or "per_cluster"
    preclustered_data_path: Optional[str] = None  # Path to pre-clustered data
    train_ratio: float = 0.95  # Train/val split ratio

    # Clustering (None for monolithic)
    clustering_method: Optional[str] = None  # "manual" or "kmeans" (DINO is not supported as an on-the-fly clustering method)
    num_clusters: int = 1
    manual_mapping: Optional[Dict[int, int]] = None

    # Model
    num_experts: int = 1  # 1 = monolithic, >1 = DDM
    expert_architecture: str = "unet"  # "unet", "dit", "simple_cnn"
    router_architecture: str = "none"  # "vit", "cnn", "dit", "none"
    router_pretrained: bool = True
    clip_tokenizer_name: str = "openai/clip-vit-large-patch14"

    # Training
    batch_size: int = 32
    num_epochs: int = 20
    learning_rate: float = 1e-4
    optimizer: str = "adamw"
    mixed_precision: bool = True
    num_gpus: int = 1
    distributed: bool = False
    train_router_jointly: bool = False
    weight_decay: float = 0.0
    use_lr_scheduler: bool = True
    warmup_steps: int = 0  # Learning rate warmup steps
    warmup_factor: float = 0.1  # Learning rate warmup factor
    grad_accum_steps: int = 1
    use_amp: bool = True
    imagenet_pretrain_checkpoint: Optional[str] = None

    # Cluster imbalance handling
    use_class_weights: bool = False  # Enable class weighting for imbalanced clusters
    weight_smoothing: float = 0.0  # Weight smoothing factor (0.0-1.0)

    # New dataset training options
    new_dataset_learning_rate: Optional[float] = None
    reset_optimizer: bool = True
    reset_scheduler: bool = True
    reset_epoch: bool = True
    reset_ema: bool = False

    # Decentralized training
    expert_parallel: bool = False
    target_expert_id: int = 0
    target_gpu_id: int = 0

    # FID evaluation
    compute_fid: bool = False
    fid_every: int = 5000
    fid_num_samples: int = 5000
    fid_batch_size: int = 50

    # EMA parameters
    use_ema: bool = True
    ema_decay: float = 0.9999
    ema_update_every: int = 1

    # Heterogeneous objectives
    expert_objectives: Optional[Dict[int, str]] = None  # {expert_id: "ddpm"|"fm"|"rf"}
    default_objective: str = "fm"  # Default if expert_objectives not specified

    # Schedule configuration (NEW)
    schedule_type: str = "linear_interp"  # Default for backward compatibility
    expert_schedule_types: Optional[Dict[int, str]] = None  # Per-expert schedules for Strategy B

    # Consistency loss (NEW)
    use_consistency_loss: bool = False
    consistency_loss_weight: float = 0.1

    # Model parameters (flexible dicts; left as None here and filled in __post_init__
    # to avoid a mutable default)
    expert_params: Optional[Dict[str, Any]] = None
    router_params: Optional[Dict[str, Any]] = None
    video_config: Optional[Dict[str, Any]] = None  # Video-specific parameters (temporal_frames, latent_height, etc.)
    # Inference
    sampling_strategy: str = "top1"  # "top1", "topk", "full", "monolithic"
    num_inference_steps: int = 50

    # Diffusion settings
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "linear"
    max_text_length: int = 77

    # Paths
    checkpoint_dir: str = "./outputs/checkpoints"
    log_dir: str = "./outputs/logs"

    def __post_init__(self) -> None:
        # Set defaults for missing fields
        if self.expert_params is None:
            self.expert_params = {}
        if self.router_params is None:
            self.router_params = {}
        if self.video_config is None:
            self.video_config = {}

        # Auto-determine num_channels if not explicitly set
        if self.num_channels is None:
            if self.use_latents:
                self.num_channels = 4  # VAE latent channels
            elif self.dataset_name in ["mnist", "fashionmnist"]:
                self.num_channels = 1
            else:
                self.num_channels = 3

        # Initialize and validate expert_objectives
        valid_objectives = {"ddpm", "fm", "rf"}

        # Validate default_objective
        if self.default_objective not in valid_objectives:
            raise ValueError(
                f"default_objective must be one of {valid_objectives}, got {self.default_objective}"
            )

        # Initialize expert_objectives if None
        if self.expert_objectives is None:
            self.expert_objectives = {i: self.default_objective for i in range(self.num_experts)}
        else:
            # Validate all objective types
            for expert_id, obj_type in self.expert_objectives.items():
                if obj_type not in valid_objectives:
                    raise ValueError(
                        f"Expert {expert_id} has invalid objective '{obj_type}'. "
                        f"Must be one of {valid_objectives}"
                    )
            # Ensure all expert IDs have objectives assigned
            for expert_id in range(self.num_experts):
                if expert_id not in self.expert_objectives:
                    self.expert_objectives[expert_id] = self.default_objective

        # Validate schedule types (NEW)
        valid_schedules = {"cosine", "linear_beta", "linear_interp"}

        # Validate default schedule_type
        if self.schedule_type not in valid_schedules:
            raise ValueError(
                f"schedule_type must be one of {valid_schedules}, got {self.schedule_type}"
            )

        # Validate expert_schedule_types if provided
        if self.expert_schedule_types is not None:
            for expert_id, sched_type in self.expert_schedule_types.items():
                if sched_type not in valid_schedules:
                    raise ValueError(
                        f"Expert {expert_id} has invalid schedule '{sched_type}'. "
                        f"Must be one of {valid_schedules}"
                    )

    @classmethod
    def from_yaml(cls, config_path: str) -> 'Config':
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)

        # Set defaults for missing fields
        config_dict.setdefault('expert_params', {})
        config_dict.setdefault('router_params', {})

        # If num_experts is not specified, default to num_clusters (or 1 if num_clusters is not set)
        if 'num_experts' not in config_dict:
            num_clusters = config_dict.get('num_clusters', 1)
            config_dict['num_experts'] = max(1, num_clusters)

        return cls(**config_dict)

    @property
    def is_monolithic(self) -> bool:
        return self.num_experts == 1

    @property
    def num_classes(self) -> int:
        dataset_classes = {
            "mnist": 10,
            "fashionmnist": 10,
            "cifar10": 10,
            "cifar100": 100,
            "celeba": 0,  # No class conditioning
            "butterfly": 1,  # Single class for butterflies
            "laion": 0,  # No class conditioning for LAION
        }
        return dataset_classes.get(self.dataset_name, 10)


def load_config(config_path: str) -> Config:
    """Simple config loader"""
    return Config.from_yaml(config_path)
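

# Usage sketch (illustrative, not part of the original module): constructs a
# Config directly, with hypothetical field values, to show how __post_init__
# fills in derived defaults. A YAML file with the same keys, passed to
# load_config(), would behave identically; note that from_yaml() also derives
# num_experts from num_clusters when num_experts is omitted.
if __name__ == "__main__":
    cfg = Config(
        experiment_name="demo",  # hypothetical experiment name
        num_experts=2,
        expert_objectives={0: "ddpm"},  # expert 1 falls back to default_objective ("fm")
        expert_schedule_types={0: "cosine", 1: "linear_interp"},
    )
    print(cfg.is_monolithic)      # False: two experts
    print(cfg.expert_objectives)  # {0: 'ddpm', 1: 'fm'}
    print(cfg.num_channels)       # 3: auto-set for raw cifar10 images
    print(cfg.num_classes)        # 10: cifar10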