Spaces:
Build error
Build error
"""
Configuration management for the Geometric Bayesian KAN.

Centralized configuration system built on dataclasses for type safety and
validation.
"""
from dataclasses import asdict, dataclass, field
from typing import List, Optional

import yaml
@dataclass
class GBKANConfig:
    """Configuration for the Geometric Bayesian KAN model.

    Raises:
        ValueError: From ``__post_init__`` when a field is out of range or
            ``pooling`` is not one of the supported modes.
    """

    # Architecture
    hidden_dim: int = 128
    num_layers: int = 4
    num_frequencies: int = 8
    max_l: int = 2  # Maximum spherical harmonic degree
    cutoff: float = 5.0  # Angstroms
    prior_std: float = 1.0

    # Input/Output
    node_feature_dim: int = 11  # Atomic number one-hot encoding
    output_dim: int = 1  # Single property prediction

    # Pooling
    pooling: str = "mean"  # "mean", "sum", or "max"

    def __post_init__(self):
        """Validate configuration.

        Explicit ``ValueError`` is used instead of ``assert`` so validation
        still runs when Python is started with ``-O``.
        """
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        for name in (
            "hidden_dim",
            "num_layers",
            "cutoff",
            "num_frequencies",
            "prior_std",
            "node_feature_dim",
            "output_dim",
        ):
            if not getattr(self, name) > 0:
                raise ValueError(f"{name} must be positive")
        # Degree 0 (scalars only) is a valid lower bound.
        if not self.max_l >= 0:
            raise ValueError("max_l must be non-negative")
        if self.pooling not in ("mean", "sum", "max"):
            raise ValueError(f"Invalid pooling type: {self.pooling!r}")
@dataclass
class TrainingConfig:
    """Configuration for training.

    Raises:
        ValueError: From ``__post_init__`` when an optimisation setting is
            out of range or ``kl_beta_schedule`` is not a supported schedule.
    """

    # Optimization
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5
    num_epochs: int = 500
    warmup_epochs: int = 10

    # KL annealing
    kl_beta_start: float = 0.0
    kl_beta_end: float = 1.0
    kl_beta_schedule: str = "linear"  # "linear", "cosine", or "none"
    kl_warmup_epochs: int = 0  # Number of epochs with beta=0 before annealing starts

    # Early stopping
    patience: int = 50
    min_delta: float = 1e-4

    # Checkpointing
    save_every: int = 10
    checkpoint_dir: str = "checkpoints"

    # Logging
    log_every: int = 10
    use_wandb: bool = False
    wandb_project: str = "geometric-bayesian-kan"

    def __post_init__(self):
        """Validate configuration (raises ``ValueError``, not ``assert``)."""
        # ``not x > 0`` / ``not x >= 0`` also reject NaN.
        for name in ("batch_size", "learning_rate", "num_epochs"):
            if not getattr(self, name) > 0:
                raise ValueError(f"{name} must be positive")
        for name in ("weight_decay", "warmup_epochs", "kl_warmup_epochs"):
            if not getattr(self, name) >= 0:
                raise ValueError(f"{name} must be non-negative")
        if self.kl_beta_schedule not in ("linear", "cosine", "none"):
            raise ValueError(
                f"Invalid kl_beta_schedule: {self.kl_beta_schedule!r}"
            )
@dataclass
class DataConfig:
    """Configuration for datasets.

    Raises:
        ValueError: From ``__post_init__`` when a split fraction lies outside
            [0, 1], the splits sum to more than 1, or ``max_num_atoms`` is set
            but not positive.
    """

    # Dataset
    dataset_name: str = "QM9"  # "QM9", "MD17", "PDBbind", etc.
    target_property: str = "homo"  # For QM9: homo, lumo, gap, etc.

    # Splits (fractions in [0, 1]; must not sum to more than 1)
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # Data augmentation
    use_rotation_augmentation: bool = True
    use_translation_augmentation: bool = False

    # Preprocessing
    normalize_targets: bool = True
    max_num_atoms: Optional[int] = None  # None disables the limit

    # Hugging Face
    hf_dataset_path: str = "yairschiff/qm9"
    cache_dir: Optional[str] = None

    def __post_init__(self):
        """Validate split fractions and preprocessing limits."""
        splits = {
            "train_split": self.train_split,
            "val_split": self.val_split,
            "test_split": self.test_split,
        }
        for name, value in splits.items():
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")
        # A total below 1 is allowed (e.g. using only a subset of the data).
        # The small tolerance absorbs float rounding in sums like 0.7 + 0.2 + 0.1.
        total = sum(splits.values())
        if total > 1.0 + 1e-6:
            raise ValueError(
                f"train/val/test splits must sum to at most 1, got {total}"
            )
        if self.max_num_atoms is not None and self.max_num_atoms <= 0:
            raise ValueError("max_num_atoms must be positive when set")
@dataclass
class UncertaintyConfig:
    """Configuration for uncertainty quantification.

    Raises:
        ValueError: From ``__post_init__`` when a count is out of range,
            ``confidence_level`` is not strictly between 0 and 1, or
            ``acquisition_function`` is not one of the supported names.
    """

    # Monte Carlo sampling
    num_mc_samples: int = 50

    # Calibration
    calibration_bins: int = 10
    confidence_level: float = 0.95

    # Active learning
    acquisition_function: str = "max_uncertainty"  # "max_uncertainty", "expected_improvement", "thompson"
    query_budget: int = 100

    def __post_init__(self):
        """Validate sampling, calibration and active-learning settings."""
        for name in ("num_mc_samples", "calibration_bins"):
            if not getattr(self, name) > 0:
                raise ValueError(f"{name} must be positive")
        # Chained comparison also rejects NaN.
        if not 0.0 < self.confidence_level < 1.0:
            raise ValueError("confidence_level must be strictly between 0 and 1")
        if self.acquisition_function not in (
            "max_uncertainty",
            "expected_improvement",
            "thompson",
        ):
            raise ValueError(
                f"Invalid acquisition_function: {self.acquisition_function!r}"
            )
        if self.query_budget < 0:
            raise ValueError("query_budget must be non-negative")
@dataclass
class ExperimentConfig:
    """Complete experiment configuration.

    Bundles the model, training, data and uncertainty sub-configurations with
    hardware and reproducibility settings. The whole configuration can be read
    with :meth:`from_yaml` and written with :meth:`to_yaml`.
    """

    model: GBKANConfig = field(default_factory=GBKANConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    uncertainty: UncertaintyConfig = field(default_factory=UncertaintyConfig)

    # Hardware
    device: str = "cuda"
    num_workers: int = 4
    seed: int = 42

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "ExperimentConfig":
        """Load configuration from a YAML file.

        The following all fall back to the dataclass defaults:

        * missing sections and sections left empty (``model:`` with no body);
        * missing top-level keys;
        * an empty file.

        Unrecognised top-level keys are ignored.

        Args:
            yaml_path: Path to the YAML file to read.

        Returns:
            A new ``ExperimentConfig`` built from the file's contents.

        Raises:
            TypeError: If the document is not a mapping, or a section holds a
                key the matching sub-config dataclass does not define.
            Any error raised while constructing a sub-config propagates.
        """
        # safe_load only constructs plain Python objects, so a config file
        # cannot instantiate arbitrary classes.
        with open(yaml_path, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)

        # An empty document loads as None; treat it as "use all defaults".
        if config_dict is None:
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {yaml_path!r}, "
                f"got {type(config_dict).__name__}"
            )

        # Forward only the hardware keys that are present, so the defaults
        # are defined once (in the field declarations above).
        hardware = {
            key: config_dict[key]
            for key in ("device", "num_workers", "seed")
            if key in config_dict
        }
        # ``or {}`` also covers a section that is present but null.
        return cls(
            model=GBKANConfig(**(config_dict.get("model") or {})),
            training=TrainingConfig(**(config_dict.get("training") or {})),
            data=DataConfig(**(config_dict.get("data") or {})),
            uncertainty=UncertaintyConfig(**(config_dict.get("uncertainty") or {})),
            **hardware,
        )

    def to_yaml(self, yaml_path: str) -> None:
        """Save configuration to a YAML file.

        The layout written here (one mapping per sub-config plus the
        top-level hardware keys) is the layout :meth:`from_yaml` reads.

        Args:
            yaml_path: Destination path; an existing file is overwritten.
        """
        # asdict recursively converts nested dataclasses to plain dicts
        # (copies, not live references), which yaml.dump writes as plain
        # YAML without Python-specific tags.
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(asdict(self), f, default_flow_style=False)
# Ready-made configurations for common experiments.
# These are shared module-level objects: build a fresh ExperimentConfig (or
# copy.deepcopy one of these) instead of mutating them in place.

# QM9: predict the HOMO energy using the default model and training settings.
DEFAULT_QM9_CONFIG = ExperimentConfig(
    data=DataConfig(
        dataset_name="QM9",
        target_property="homo",
        hf_dataset_path="yairschiff/qm9",
    ),
)

# MD17: energy target with a deeper model.
# NOTE(review): hf_dataset_path is not overridden, so this config inherits the
# DataConfig default ("yairschiff/qm9"). Confirm the intended MD17 data source.
DEFAULT_MD17_CONFIG = ExperimentConfig(
    model=GBKANConfig(num_layers=6),  # Deeper for force prediction
    data=DataConfig(dataset_name="MD17", target_property="energy"),
)