File size: 4,421 Bytes
aa677e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import yaml
@dataclass
class KnobSpec:
    """Kind and numeric bounds of one entry under ``tuning.knobs``.

    ``load_config`` builds one instance per knob, and ``_validate_config``
    requires ``min < max`` and a ``type`` of ``"int"`` or ``"float"``.
    """

    # Kind of value the knob takes: "int" | "float".
    type: str
    # Lower bound of the knob range (stored as a float even for int knobs).
    min: float
    # Upper bound of the knob range (stored as a float even for int knobs).
    max: float
@dataclass
class RewardSpec:
    """Contents of the ``reward`` section of the configuration.

    ``load_config`` reads ``weights`` from ``reward.weights`` and the three
    candidate lists from ``reward.keys``.
    """

    # Weight per reward term; must be non-empty per _validate_config.
    weights: Dict[str, float]
    # NOTE(review): presumably alternative metric names to try for each
    # quantity (WNS / area / power) -- confirm against the reward code.
    wns_candidates: List[str]
    area_candidates: List[str]
    power_candidates: List[str]
@dataclass
class BudgetSpec:
    """Action budget read from ``tuning.budget``.

    ``_validate_config`` enforces ``total_actions > 0`` and
    ``0 <= max_expensive <= total_actions``.
    """

    # Overall number of actions allowed.
    total_actions: int
    # Cap on "expensive" actions.
    # NOTE(review): what counts as expensive is decided elsewhere -- confirm.
    max_expensive: int
@dataclass
class ExperimentSpec:
    """Contents of the ``experiment`` section of the configuration.

    Only ``name`` is required in the YAML; ``load_config`` supplies the
    defaults noted on the other fields.
    """

    # Experiment name (required key).
    name: str
    # Seed value; load_config defaults it to 0.
    seed: int
    # Database path; load_config defaults it to "runs/experiment.sqlite".
    db_path: str
    # Output directory; load_config defaults it to "runs".
    out_dir: str
    # ORFS flow directory, or None when the key is absent from the YAML.
    orfs_flow_dir: Optional[str]
@dataclass
class DesignSpec:
    """Contents of the ``design`` section; all three keys are required."""

    # Value of design.platform, stored as a string.
    platform: str
    # Value of design.design, stored as a string.
    design: str
    # Path to the design config, relative to ORFS flow dir.
    design_config: str
@dataclass
class FlowSpec:
    """Contents of the ``flow`` section of the configuration."""

    # Fidelity names; _validate_config rejects an empty list.
    fidelities: List[str]
    # String-to-string mapping copied from flow.targets.
    # NOTE(review): presumably fidelity name -> flow target -- confirm.
    targets: Dict[str, str]
@dataclass
class TuningSpec:
    """Contents of the ``tuning`` section of the configuration."""

    # Value of tuning.agent, stored as a string.
    agent: str
    # Action budget built from tuning.budget.
    budget: BudgetSpec
    # Knob name -> spec; _validate_config rejects an empty mapping.
    knobs: Dict[str, KnobSpec]
@dataclass
class Config:
    """Whole configuration: one field per top-level YAML section.

    Instances are produced by ``load_config``.
    """

    # From the "experiment" section.
    experiment: ExperimentSpec
    # From the "design" section.
    design: DesignSpec
    # From the "flow" section.
    flow: FlowSpec
    # From the "tuning" section.
    tuning: TuningSpec
    # From the "reward" section.
    reward: RewardSpec
def _coerce_list(value: Any) -> List[Any]:
    """Turn a YAML value into a list without exploding scalars.

    ``None`` (a key written with no value) becomes an empty list and a bare
    string becomes a one-element list; any other iterable is copied.
    """
    if value is None:
        return []
    if isinstance(value, str):
        # list("synth") would yield ['s', 'y', 'n', 't', 'h'].
        return [value]
    return list(value)


def load_config(path: str) -> Config:
    """Load, coerce and validate an experiment configuration from YAML.

    Args:
        path: Location of the YAML file to read (opened as UTF-8).

    Returns:
        A populated ``Config`` that has passed ``_validate_config``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        KeyError: If a required section or key is absent.
        ValueError: If the document is not a mapping, a value cannot be
            converted to its numeric type, or validation fails.
    """
    with open(path, "r", encoding="utf-8") as f:
        d = yaml.safe_load(f)

    # safe_load returns None for an empty document and may return a list or
    # scalar; fail with a clear message instead of a TypeError on d[...].
    if not isinstance(d, dict):
        raise ValueError(
            f"{path}: top level of the config must be a mapping, "
            f"got {type(d).__name__}"
        )

    exp = d["experiment"]
    design = d["design"]
    flow = d["flow"]
    tuning = d["tuning"]
    reward = d["reward"]

    # `or {}` lets a key written with no value reach the validator's
    # "cannot be empty" check instead of crashing on None.items().
    knobs: Dict[str, KnobSpec] = {
        name: KnobSpec(
            type=str(raw["type"]),
            min=float(raw["min"]),
            max=float(raw["max"]),
        )
        for name, raw in (tuning["knobs"] or {}).items()
    }

    # Optional path: keep None as None, otherwise store it as a string like
    # the other path fields.
    orfs_flow_dir = exp.get("orfs_flow_dir", None)
    if orfs_flow_dir is not None:
        orfs_flow_dir = str(orfs_flow_dir)

    cfg = Config(
        experiment=ExperimentSpec(
            name=str(exp["name"]),
            seed=int(exp.get("seed", 0)),
            db_path=str(exp.get("db_path", "runs/experiment.sqlite")),
            out_dir=str(exp.get("out_dir", "runs")),
            orfs_flow_dir=orfs_flow_dir,
        ),
        design=DesignSpec(
            platform=str(design["platform"]),
            design=str(design["design"]),
            design_config=str(design["design_config"]),
        ),
        flow=FlowSpec(
            fidelities=_coerce_list(flow["fidelities"]),
            targets=dict(flow["targets"]),
        ),
        tuning=TuningSpec(
            agent=str(tuning["agent"]),
            budget=BudgetSpec(
                total_actions=int(tuning["budget"]["total_actions"]),
                max_expensive=int(tuning["budget"]["max_expensive"]),
            ),
            knobs=knobs,
        ),
        reward=RewardSpec(
            weights=dict(reward["weights"] or {}),
            wns_candidates=_coerce_list(reward["keys"]["wns_candidates"]),
            area_candidates=_coerce_list(reward["keys"]["area_candidates"]),
            power_candidates=_coerce_list(reward["keys"]["power_candidates"]),
        ),
    )

    # Range and emptiness checks live in one place; raises ValueError.
    _validate_config(cfg)
    return cfg
def _validate_config(cfg: Config) -> None:
    """Reject configurations whose values are missing or out of range.

    Checks, in order: the action budget, the fidelity list, every knob's
    range and type, the reward weights, and the reward candidate lists.

    Args:
        cfg: Configuration to check; it is only read, never modified.

    Raises:
        ValueError: On the first violated constraint, with a message naming
            the offending field and value.
    """
    budget = cfg.tuning.budget
    reward = cfg.reward

    # Validate budget
    if budget.total_actions <= 0:
        raise ValueError(f"total_actions must be > 0, got {budget.total_actions}")
    if budget.max_expensive < 0:
        raise ValueError(f"max_expensive must be >= 0, got {budget.max_expensive}")
    if budget.max_expensive > budget.total_actions:
        raise ValueError(
            f"max_expensive ({budget.max_expensive}) cannot exceed "
            f"total_actions ({budget.total_actions})"
        )

    # Validate fidelities
    if not cfg.flow.fidelities:
        raise ValueError("flow.fidelities cannot be empty")

    # Validate knobs
    if not cfg.tuning.knobs:
        raise ValueError("tuning.knobs cannot be empty")
    for name, spec in cfg.tuning.knobs.items():
        # Written as `not (min < max)` rather than `min >= max` so that a NaN
        # bound (YAML `.nan`) is rejected: every comparison with NaN is
        # False, which would let `min >= max` wave it through.
        if not (spec.min < spec.max):
            raise ValueError(f"Knob {name}: min ({spec.min}) must be < max ({spec.max})")
        if spec.type not in ("int", "float"):
            raise ValueError(f"Knob {name}: type must be 'int' or 'float', got '{spec.type}'")

    # Validate reward weights
    if not reward.weights:
        raise ValueError("reward.weights cannot be empty")

    # Validate reward candidates: only fail when all three lists are empty.
    if not (reward.wns_candidates or reward.area_candidates or reward.power_candidates):
        raise ValueError("At least one reward candidate list must be non-empty")

    # Note: design_config validation happens later when ORFS dir is known
|