# atlas-1-demo/src/config.py
# Reverb's picture
# Upload folder using huggingface_hub
# 8eabce6 verified
"""
Configuration Management for Geometric Bayesian KAN
Centralized configuration system using dataclasses for type safety and validation.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import yaml
@dataclass
class GBKANConfig:
    """Configuration for Geometric Bayesian KAN model.

    Values are validated in ``__post_init__``; an invalid value raises
    ``AssertionError`` with a message naming the offending field.
    """

    # Architecture
    hidden_dim: int = 128
    num_layers: int = 4
    num_frequencies: int = 8
    max_l: int = 2  # Maximum spherical harmonic degree
    cutoff: float = 5.0  # Angstroms
    prior_std: float = 1.0

    # Input/Output
    node_feature_dim: int = 11  # Atomic number one-hot encoding
    output_dim: int = 1  # Single property prediction

    # Pooling
    pooling: str = "mean"  # "mean", "sum", or "max"

    def __post_init__(self):
        """Validate configuration.

        Raises:
            AssertionError: If ``hidden_dim``, ``num_layers`` or ``cutoff`` is
                not positive, or if ``pooling`` is not one of "mean", "sum",
                "max".
        """
        # Raise explicitly rather than via ``assert`` statements: asserts are
        # compiled out under ``python -O``, which would silently disable all
        # of these checks. AssertionError (and the exact messages) are kept
        # so callers that already catch it continue to work.
        # ``not (x > 0)`` is used instead of ``x <= 0`` so that NaN is still
        # rejected, exactly as the original ``assert x > 0`` did.
        if not (self.hidden_dim > 0):
            raise AssertionError("hidden_dim must be positive")
        if not (self.num_layers > 0):
            raise AssertionError("num_layers must be positive")
        if not (self.cutoff > 0):
            raise AssertionError("cutoff must be positive")
        if self.pooling not in ("mean", "sum", "max"):
            raise AssertionError("Invalid pooling type")
@dataclass
class TrainingConfig:
    """Settings that control a training run.

    Fields are grouped into optimization, KL annealing, early stopping,
    checkpointing and logging. Nothing is validated on construction.
    """

    # --- Optimization ---
    batch_size: int = 32
    learning_rate: float = 0.001
    weight_decay: float = 0.00001
    num_epochs: int = 500
    warmup_epochs: int = 10

    # --- KL annealing ---
    kl_beta_start: float = 0.0
    kl_beta_end: float = 1.0
    # Documented schedule names: "linear", "cosine", or "none"
    kl_beta_schedule: str = "linear"
    # Number of epochs with beta=0 before annealing starts
    kl_warmup_epochs: int = 0

    # --- Early stopping ---
    patience: int = 50
    min_delta: float = 0.0001

    # --- Checkpointing ---
    save_every: int = 10
    checkpoint_dir: str = "checkpoints"

    # --- Logging ---
    log_every: int = 10
    use_wandb: bool = False
    wandb_project: str = "geometric-bayesian-kan"
@dataclass
class DataConfig:
    """Dataset selection, split, augmentation and preprocessing options.

    Nothing is validated on construction; in particular the three split
    fractions are stored as given.
    """

    # --- Dataset ---
    # e.g. "QM9", "MD17", "PDBbind"
    dataset_name: str = "QM9"
    # For QM9: homo, lumo, gap, etc.
    target_property: str = "homo"

    # --- Splits ---
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # --- Data augmentation ---
    use_rotation_augmentation: bool = True
    use_translation_augmentation: bool = False

    # --- Preprocessing ---
    normalize_targets: bool = True
    max_num_atoms: Optional[int] = None

    # --- Hugging Face ---
    hf_dataset_path: str = "yairschiff/qm9"
    cache_dir: Optional[str] = None
@dataclass
class UncertaintyConfig:
    """Options for uncertainty quantification.

    Covers Monte Carlo sampling, calibration and active learning. Nothing is
    validated on construction.
    """

    # --- Monte Carlo sampling ---
    num_mc_samples: int = 50

    # --- Calibration ---
    calibration_bins: int = 10
    confidence_level: float = 0.95

    # --- Active learning ---
    # Documented choices: "max_uncertainty", "expected_improvement", "thompson"
    acquisition_function: str = "max_uncertainty"
    query_budget: int = 100
@dataclass
class ExperimentConfig:
    """Complete experiment configuration.

    Groups the model, training, data and uncertainty sub-configurations with
    the top-level ``device``, ``num_workers`` and ``seed`` settings, and can
    be loaded from / saved to a YAML file.
    """

    model: GBKANConfig = field(default_factory=GBKANConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    uncertainty: UncertaintyConfig = field(default_factory=UncertaintyConfig)

    # Hardware
    device: str = "cuda"
    num_workers: int = 4
    seed: int = 42

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "ExperimentConfig":
        """Load configuration from YAML file.

        Sections or keys missing from the file fall back to their defaults.
        An empty file, or a section key with no body, is treated the same as
        a missing one.

        Args:
            yaml_path: Path of the YAML file to read.

        Returns:
            A new ``ExperimentConfig`` built from the file contents.

        Raises:
            TypeError: If a section contains a key that is not a field of the
                corresponding sub-configuration class.
        """
        with open(yaml_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # ``safe_load`` returns None for an empty document; without this the
        # ``.get`` calls below would fail with AttributeError.
        if config_dict is None:
            config_dict = {}

        def section(name):
            # A key written as ``model:`` with nothing under it parses to
            # None, which cannot be **-expanded; treat it as "use defaults".
            values = config_dict.get(name)
            return {} if values is None else values

        return cls(
            model=GBKANConfig(**section('model')),
            training=TrainingConfig(**section('training')),
            data=DataConfig(**section('data')),
            uncertainty=UncertaintyConfig(**section('uncertainty')),
            device=config_dict.get('device', 'cuda'),
            num_workers=config_dict.get('num_workers', 4),
            seed=config_dict.get('seed', 42),
        )

    def to_yaml(self, yaml_path: str) -> None:
        """Save configuration to YAML file.

        Writes one mapping per sub-configuration (``model``, ``training``,
        ``data``, ``uncertainty``) plus the top-level ``device``,
        ``num_workers`` and ``seed`` keys, using the same layout that
        ``from_yaml`` reads.

        Args:
            yaml_path: Path of the YAML file to write; an existing file is
                overwritten.
        """
        config_dict = {
            'model': self.model.__dict__,
            'training': self.training.__dict__,
            'data': self.data.__dict__,
            'uncertainty': self.uncertainty.__dict__,
            'device': self.device,
            'num_workers': self.num_workers,
            'seed': self.seed,
        }
        # Explicit UTF-8 so the file does not depend on the platform's
        # locale encoding and matches what ``from_yaml`` reads.
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False)
# Ready-made experiment presets.

# QM9 preset: model, training and uncertainty settings stay at their defaults.
DEFAULT_QM9_CONFIG = ExperimentConfig(
    data=DataConfig(
        hf_dataset_path="yairschiff/qm9",
        target_property="homo",
        dataset_name="QM9",
    ),
)

# MD17 preset.
# NOTE(review): hf_dataset_path is not set here, so it keeps the DataConfig
# default ("yairschiff/qm9") -- confirm that is intended for MD17.
DEFAULT_MD17_CONFIG = ExperimentConfig(
    model=GBKANConfig(num_layers=6),  # Deeper for force prediction
    data=DataConfig(
        target_property="energy",
        dataset_name="MD17",
    ),
)