# src/config.py
from dataclasses import dataclass
from typing import Any, Dict, Optional

import yaml


@dataclass
class Config:
"""Single config class - no inheritance needed"""
# Experiment
experiment_name: str
seed: int = 42
# Dataset
dataset_name: str = "cifar10"
image_size: int = 64
num_channels: Optional[int] = None # If None, auto-determined based on dataset/latents
data_path: str = "./data"
download: bool = True
use_latents: bool = False # Whether to use VAE latents instead of raw images
latent_data_path: Optional[str] = None # Path to latent dataset JSON
split_strategy: str = "global" # "global" or "per_cluster"
preclustered_data_path: Optional[str] = None # Path to pre-clustered data
train_ratio: float = 0.95 # Train/val split ratio
# Clustering (None for monolithic)
    clustering_method: Optional[str] = None # "manual" or "kmeans"; DINO is not supported as an on-the-fly clustering method
num_clusters: int = 1
manual_mapping: Optional[Dict[int, int]] = None
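    # e.g. {0: 0, 1: 0, 2: 1} (assumed semantics: dataset class id -> cluster id)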
# Model
num_experts: int = 1 # 1 = monolithic, >1 = DDM
expert_architecture: str = "unet" # "unet", "dit", "simple_cnn"
router_architecture: str = "none" # "vit", "cnn", "dit", "none"
router_pretrained: bool = True
clip_tokenizer_name: str = "openai/clip-vit-large-patch14"
# Training
batch_size: int = 32
num_epochs: int = 20
learning_rate: float = 1e-4
optimizer: str = "adamw"
mixed_precision: bool = True
num_gpus: int = 1
distributed: bool = False
train_router_jointly: bool = False
    weight_decay: float = 0.0
use_lr_scheduler: bool = True
warmup_steps: int = 0 # Learning rate warmup steps
warmup_factor: float = 0.1 # Learning rate warmup factor
grad_accum_steps: int = 1
use_amp: bool = True
imagenet_pretrain_checkpoint: Optional[str] = None
# Cluster imbalance handling
use_class_weights: bool = False # Enable class weighting for imbalanced clusters
weight_smoothing: float = 0.0 # Weight smoothing factor (0.0-1.0)
# New dataset training options
new_dataset_learning_rate: Optional[float] = None
reset_optimizer: bool = True
reset_scheduler: bool = True
reset_epoch: bool = True
reset_ema: bool = False
# Decentralized training
expert_parallel: bool = False
target_expert_id: int = 0
target_gpu_id: int = 0
# FID evaluation
compute_fid: bool = False
fid_every: int = 5000
fid_num_samples: int = 5000
fid_batch_size: int = 50
# EMA parameters
use_ema: bool = True
ema_decay: float = 0.9999
ema_update_every: int = 1
# Heterogeneous objectives
expert_objectives: Optional[Dict[int, str]] = None # {expert_id: "ddpm"|"fm"|"rf"}
default_objective: str = "fm" # Default if expert_objectives not specified
    # Schedule configuration
schedule_type: str = "linear_interp" # Default for backward compatibility
expert_schedule_types: Optional[Dict[int, str]] = None # Per-expert schedules for Strategy B
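    # e.g. expert_schedule_types: {0: "cosine", 1: "linear_beta"}; values are
    # validated against the schedule set in __post_init__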
    # Consistency loss
use_consistency_loss: bool = False
consistency_loss_weight: float = 0.1
    # Model parameters (flexible dicts; set to {} in __post_init__ if left as None)
    expert_params: Optional[Dict[str, Any]] = None
    router_params: Optional[Dict[str, Any]] = None
    video_config: Optional[Dict[str, Any]] = None # Video-specific parameters (temporal_frames, latent_height, etc.)
# Inference
sampling_strategy: str = "top1" # "top1", "topk", "full", "monolithic"
num_inference_steps: int = 50
    # Diffusion settings (defaults match the standard DDPM linear beta schedule)
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "linear"
    max_text_length: int = 77 # CLIP tokenizer context length
# Paths
checkpoint_dir: str = "./outputs/checkpoints"
log_dir: str = "./outputs/logs"

    def __post_init__(self) -> None:
# Set defaults for missing fields
if self.expert_params is None:
self.expert_params = {}
if self.router_params is None:
self.router_params = {}
if self.video_config is None:
self.video_config = {}
# Auto-determine num_channels if not explicitly set
if self.num_channels is None:
if self.use_latents:
self.num_channels = 4 # VAE latent channels
elif self.dataset_name in ["mnist", "fashionmnist"]:
self.num_channels = 1
else:
self.num_channels = 3
# Initialize and validate expert_objectives
valid_objectives = {"ddpm", "fm", "rf"}
# Validate default_objective
if self.default_objective not in valid_objectives:
raise ValueError(f"default_objective must be one of {valid_objectives}, got {self.default_objective}")
# Initialize expert_objectives if None
if self.expert_objectives is None:
self.expert_objectives = {i: self.default_objective for i in range(self.num_experts)}
else:
# Validate all objective types
for expert_id, obj_type in self.expert_objectives.items():
if obj_type not in valid_objectives:
raise ValueError(f"Expert {expert_id} has invalid objective '{obj_type}'. Must be one of {valid_objectives}")
# Ensure all expert IDs have objectives assigned
for expert_id in range(self.num_experts):
if expert_id not in self.expert_objectives:
self.expert_objectives[expert_id] = self.default_objective
        # Validate schedule types
valid_schedules = {"cosine", "linear_beta", "linear_interp"}
# Validate default schedule_type
if self.schedule_type not in valid_schedules:
raise ValueError(f"schedule_type must be one of {valid_schedules}, got {self.schedule_type}")
# Validate expert_schedule_types if provided
if self.expert_schedule_types is not None:
for expert_id, sched_type in self.expert_schedule_types.items():
if sched_type not in valid_schedules:
raise ValueError(f"Expert {expert_id} has invalid schedule '{sched_type}'. Must be one of {valid_schedules}")
@classmethod
def from_yaml(cls, config_path: str) -> 'Config':
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
# Set defaults for missing fields
config_dict.setdefault('expert_params', {})
config_dict.setdefault('router_params', {})
# If num_experts is not specified, default to num_clusters (or 1 if num_clusters is not set)
if 'num_experts' not in config_dict:
num_clusters = config_dict.get('num_clusters', 1)
config_dict['num_experts'] = max(1, num_clusters)
return cls(**config_dict)
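
    # A minimal YAML accepted by from_yaml. The file name and values here are
    # hypothetical; keys mirror the dataclass fields, and anything omitted
    # falls back to its default:
    #
    #   # configs/cifar10_moe.yaml
    #   experiment_name: cifar10_moe
    #   dataset_name: cifar10
    #   clustering_method: kmeans
    #   num_clusters: 4
    #   expert_architecture: unet
    #   router_architecture: vit
    #
    # With num_experts omitted, from_yaml defaults it to num_clusters (4).
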
@property
def is_monolithic(self) -> bool:
return self.num_experts == 1

    @property
def num_classes(self) -> int:
dataset_classes = {
"mnist": 10, "fashionmnist": 10,
"cifar10": 10, "cifar100": 100,
"celeba": 0, # No class conditioning
"butterfly": 1, # Single class for butterflies
"laion": 0 # No class conditioning for LAION
}
        return dataset_classes.get(self.dataset_name, 10) # unknown datasets fall back to 10 classes


def load_config(config_path: str) -> Config:
    """Simple config loader"""
    return Config.from_yaml(config_path)
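

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; "demo" and the parameter choices
    # are made up for this example, not taken from a real experiment).
    cfg = Config(experiment_name="demo", dataset_name="mnist", num_experts=2)
    print(cfg.is_monolithic)     # False: two experts
    print(cfg.num_channels)      # 1, auto-set for MNIST in __post_init__
    print(cfg.expert_objectives) # {0: 'fm', 1: 'fm'} from default_objective
    # Loading from YAML instead (path is hypothetical):
    # cfg = load_config("configs/cifar10_moe.yaml")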