import yaml
from typing import Dict, Any, Optional
from dataclasses import dataclass


@dataclass
class Config:
    """Single config class - no inheritance needed."""

    # Experiment
    experiment_name: str
    seed: int = 42

    # Data
    dataset_name: str = "cifar10"
    image_size: int = 64
    num_channels: Optional[int] = None  # inferred in __post_init__ when None
    data_path: str = "./data"
    download: bool = True
    use_latents: bool = False
    latent_data_path: Optional[str] = None
    split_strategy: str = "global"
    preclustered_data_path: Optional[str] = None
    train_ratio: float = 0.95

    # Clustering
    clustering_method: Optional[str] = None
    num_clusters: int = 1
    manual_mapping: Optional[Dict[int, int]] = None

    # Model
    num_experts: int = 1
    expert_architecture: str = "unet"
    router_architecture: str = "none"
    router_pretrained: bool = True
    clip_tokenizer_name: str = "openai/clip-vit-large-patch14"

    # Training
    batch_size: int = 32
    num_epochs: int = 20
    learning_rate: float = 1e-4
    optimizer: str = "adamw"
    mixed_precision: bool = True
    num_gpus: int = 1
    distributed: bool = False
    train_router_jointly: bool = False
    weight_decay: float = 0.0
    use_lr_scheduler: bool = True
    warmup_steps: int = 0
    warmup_factor: float = 0.1
    grad_accum_steps: int = 1
    use_amp: bool = True
    imagenet_pretrain_checkpoint: Optional[str] = None

    # Class weighting
    use_class_weights: bool = False
    weight_smoothing: float = 0.0

    # Resuming on a new dataset
    new_dataset_learning_rate: Optional[float] = None
    reset_optimizer: bool = True
    reset_scheduler: bool = True
    reset_epoch: bool = True
    reset_ema: bool = False

    # Expert-parallel training
    expert_parallel: bool = False
    target_expert_id: int = 0
    target_gpu_id: int = 0

    # FID evaluation
    compute_fid: bool = False
    fid_every: int = 5000
    fid_num_samples: int = 5000
    fid_batch_size: int = 50

    # EMA
    use_ema: bool = True
    ema_decay: float = 0.9999
    ema_update_every: int = 1

    # Per-expert training objectives ("ddpm", "fm", or "rf")
    expert_objectives: Optional[Dict[int, str]] = None
    default_objective: str = "fm"

    # Noise schedules ("cosine", "linear_beta", or "linear_interp")
    schedule_type: str = "linear_interp"
    expert_schedule_types: Optional[Dict[int, str]] = None

    # Consistency loss
    use_consistency_loss: bool = False
    consistency_loss_weight: float = 0.1

    # Nested parameter dicts; annotated Optional because dataclass fields
    # cannot take mutable defaults, then filled with {} in __post_init__
    expert_params: Optional[Dict[str, Any]] = None
    router_params: Optional[Dict[str, Any]] = None
    video_config: Optional[Dict[str, Any]] = None

    # Sampling
    sampling_strategy: str = "top1"
    num_inference_steps: int = 50

    # DDPM beta schedule and text conditioning
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "linear"
    max_text_length: int = 77

    # Output paths
    checkpoint_dir: str = "./outputs/checkpoints"
    log_dir: str = "./outputs/logs"

    def __post_init__(self) -> None:
        # Replace the None placeholders with fresh dicts.
        if self.expert_params is None:
            self.expert_params = {}
        if self.router_params is None:
            self.router_params = {}
        if self.video_config is None:
            self.video_config = {}

        # Infer channel count: latents are 4-channel, grayscale datasets
        # are 1-channel, everything else is RGB.
        if self.num_channels is None:
            if self.use_latents:
                self.num_channels = 4
            elif self.dataset_name in ["mnist", "fashionmnist"]:
                self.num_channels = 1
            else:
                self.num_channels = 3

        # Validate the default training objective.
        valid_objectives = {"ddpm", "fm", "rf"}
        if self.default_objective not in valid_objectives:
            raise ValueError(
                f"default_objective must be one of {valid_objectives}, "
                f"got {self.default_objective}"
            )

        # Build or validate the per-expert objective mapping.
        if self.expert_objectives is None:
            self.expert_objectives = {
                i: self.default_objective for i in range(self.num_experts)
            }
        else:
            for expert_id, obj_type in self.expert_objectives.items():
                if obj_type not in valid_objectives:
                    raise ValueError(
                        f"Expert {expert_id} has invalid objective "
                        f"'{obj_type}'. Must be one of {valid_objectives}"
                    )
            # Backfill the default objective for experts not listed.
            for expert_id in range(self.num_experts):
                if expert_id not in self.expert_objectives:
                    self.expert_objectives[expert_id] = self.default_objective

        # Validate the global and per-expert noise schedules.
        valid_schedules = {"cosine", "linear_beta", "linear_interp"}
        if self.schedule_type not in valid_schedules:
            raise ValueError(
                f"schedule_type must be one of {valid_schedules}, "
                f"got {self.schedule_type}"
            )
        if self.expert_schedule_types is not None:
            for expert_id, sched_type in self.expert_schedule_types.items():
                if sched_type not in valid_schedules:
                    raise ValueError(
                        f"Expert {expert_id} has invalid schedule "
                        f"'{sched_type}'. Must be one of {valid_schedules}"
                    )

    @classmethod
    def from_yaml(cls, config_path: str) -> "Config":
        """Build a Config from a YAML file."""
        with open(config_path, "r") as f:
            config_dict = yaml.safe_load(f)

        config_dict.setdefault("expert_params", {})
        config_dict.setdefault("router_params", {})

        # Default num_experts to the cluster count when unspecified.
        if "num_experts" not in config_dict:
            num_clusters = config_dict.get("num_clusters", 1)
            config_dict["num_experts"] = max(1, num_clusters)

        return cls(**config_dict)
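
    # Illustrative YAML shape for from_yaml (a sketch; the name and values
    # below are assumptions, not a file shipped with this repo):
    #
    #   experiment_name: moe_cifar10
    #   num_clusters: 4        # num_experts then defaults to max(1, 4)
    #   expert_objectives:
    #     0: fm                # yaml.safe_load parses bare integer keys as ints
    #     1: ddpm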

    @property
    def is_monolithic(self) -> bool:
        return self.num_experts == 1

    @property
    def num_classes(self) -> int:
        # 0 means the dataset carries no class labels; unknown datasets
        # fall back to 10.
        dataset_classes = {
            "mnist": 10, "fashionmnist": 10,
            "cifar10": 10, "cifar100": 100,
            "celeba": 0,
            "butterfly": 1,
            "laion": 0,
        }
        return dataset_classes.get(self.dataset_name, 10)


def load_config(config_path: str) -> Config:
    """Simple config loader."""
    return Config.from_yaml(config_path)
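

# A minimal usage sketch (the argument values here are illustrative, not
# defaults taken from any project config):
if __name__ == "__main__":
    cfg = Config(experiment_name="demo", num_clusters=4, num_experts=4)
    assert cfg.num_experts == 4
    # __post_init__ fills every expert with default_objective ("fm")
    assert cfg.expert_objectives == {0: "fm", 1: "fm", 2: "fm", 3: "fm"}
    # cifar10 is RGB, so the channel count is inferred as 3
    assert cfg.num_channels == 3
    print(cfg.is_monolithic)  # False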