# SamChYe's picture
# Publish EdgeEDA agent
# aa677e3 verified
from __future__ import annotations

import math
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import yaml
@dataclass
class KnobSpec:
    """Search range for a single tunable knob.

    ``type`` names the value kind (``"int"`` or ``"float"``); ``min`` and
    ``max`` bound the range. ``_validate_config`` rejects specs where
    ``min >= max`` or the type is anything else.
    """

    type: str  # one of "int" / "float"
    min: float  # lower bound of the range
    max: float  # upper bound of the range
@dataclass
class RewardSpec:
    """Reward weighting plus the metric-key candidate lists it draws on.

    The three ``*_candidates`` lists come from the ``reward.keys`` section of
    the YAML file. Presumably each lists alternative metric names to look up
    for WNS / area / power; confirm against the reward computation.
    """

    weights: Dict[str, float]  # weight per reward term, keyed by name
    wns_candidates: List[str]  # candidate key names for WNS
    area_candidates: List[str]  # candidate key names for area
    power_candidates: List[str]  # candidate key names for power
@dataclass
class BudgetSpec:
    """Action budget for a tuning run.

    ``_validate_config`` requires ``total_actions > 0`` and
    ``0 <= max_expensive <= total_actions``. What counts as an "expensive"
    action is decided by the agent (presumably higher-fidelity flow runs;
    TODO confirm).
    """

    total_actions: int  # overall number of actions allowed
    max_expensive: int  # cap on expensive actions within that total
@dataclass
class ExperimentSpec:
    """Experiment identity plus where its results are stored.

    When loaded through ``load_config``, ``seed`` defaults to 0, ``db_path``
    to ``"runs/experiment.sqlite"``, ``out_dir`` to ``"runs"``, and
    ``orfs_flow_dir`` to ``None``.
    """

    name: str  # experiment name
    seed: int  # random seed
    db_path: str  # path of the SQLite results database
    out_dir: str  # directory for run outputs
    orfs_flow_dir: Optional[str]  # ORFS flow directory, if configured
@dataclass
class DesignSpec:
    """The design under tuning and its platform."""

    platform: str  # target platform name
    design: str  # design name
    # Path to the design's config file, interpreted relative to the ORFS flow
    # directory (so it cannot be resolved until that directory is known).
    design_config: str
@dataclass
class FlowSpec:
    """Flow fidelity levels and their string-to-string target mapping.

    ``_validate_config`` requires ``fidelities`` to be non-empty. What the
    keys and values of ``targets`` denote is defined by the flow runner
    (presumably fidelity -> flow target name; TODO confirm).
    """

    fidelities: List[str]  # ordered list of fidelity names
    targets: Dict[str, str]  # string-to-string target mapping
@dataclass
class TuningSpec:
    """Which agent tunes, with what budget, over which knobs."""

    agent: str  # agent identifier
    budget: BudgetSpec  # action budget
    knobs: Dict[str, KnobSpec]  # knob name -> search range
@dataclass
class Config:
    """Top-level configuration: one section per part of the YAML file."""

    experiment: ExperimentSpec  # from the "experiment" section
    design: DesignSpec  # from the "design" section
    flow: FlowSpec  # from the "flow" section
    tuning: TuningSpec  # from the "tuning" section
    reward: RewardSpec  # from the "reward" section
def _value_or(section: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``section[key]``, or *default* when the key is absent or null."""
    value = section.get(key)
    return default if value is None else value


def load_config(path: str) -> Config:
    """Load an experiment configuration from a YAML file and validate it.

    Args:
        path: Path of the YAML file (read as UTF-8).

    Returns:
        A populated ``Config`` that has passed ``_validate_config``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        KeyError: If a required section or key is missing.
        ValueError: If the document is not a mapping, a reward weight is not
            numeric, or ``_validate_config`` rejects the result.
    """
    with open(path, "r", encoding="utf-8") as f:
        d = yaml.safe_load(f)
    # An empty file loads as None (and a scalar/list top level is equally
    # unusable); fail with a clear message instead of a TypeError below.
    if not isinstance(d, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping at the top "
            f"level, got {type(d).__name__}"
        )

    exp = d["experiment"]
    design = d["design"]
    flow = d["flow"]
    tuning = d["tuning"]
    reward = d["reward"]
    reward_keys = reward["keys"]
    budget = tuning["budget"]

    # Keys written with no value (e.g. `knobs:`) load as None. Treat them as
    # empty so that _validate_config reports "cannot be empty" rather than
    # crashing here with an AttributeError/TypeError.
    knobs: Dict[str, KnobSpec] = {}
    for k, ks in (tuning["knobs"] or {}).items():
        knobs[k] = KnobSpec(
            type=str(ks["type"]),
            min=float(ks["min"]),
            max=float(ks["max"]),
        )

    # Coerce weights to float, as is done for knob bounds. PyYAML (YAML 1.1)
    # loads exponent-only literals such as `1e-3` as *strings*, which would
    # otherwise leak into reward arithmetic.
    weights: Dict[str, float] = {}
    for key, raw in dict(reward["weights"] or {}).items():
        try:
            weights[key] = float(raw)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"reward.weights.{key} must be a number, got {raw!r}"
            ) from err

    orfs_flow_dir = exp.get("orfs_flow_dir")

    cfg = Config(
        experiment=ExperimentSpec(
            name=str(exp["name"]),
            # An explicit null falls back to the default instead of producing
            # int(None) or the literal path "None".
            seed=int(_value_or(exp, "seed", 0)),
            db_path=str(_value_or(exp, "db_path", "runs/experiment.sqlite")),
            out_dir=str(_value_or(exp, "out_dir", "runs")),
            orfs_flow_dir=None if orfs_flow_dir is None else str(orfs_flow_dir),
        ),
        design=DesignSpec(
            platform=str(design["platform"]),
            design=str(design["design"]),
            design_config=str(design["design_config"]),
        ),
        flow=FlowSpec(
            fidelities=list(flow["fidelities"] or []),
            targets=dict(flow["targets"]),
        ),
        tuning=TuningSpec(
            agent=str(tuning["agent"]),
            budget=BudgetSpec(
                total_actions=int(budget["total_actions"]),
                max_expensive=int(budget["max_expensive"]),
            ),
            knobs=knobs,
        ),
        reward=RewardSpec(
            weights=weights,
            wns_candidates=list(reward_keys["wns_candidates"] or []),
            area_candidates=list(reward_keys["area_candidates"] or []),
            power_candidates=list(reward_keys["power_candidates"] or []),
        ),
    )
    # Validate configuration
    _validate_config(cfg)
    return cfg
def _validate_config(cfg: Config) -> None:
    """Validate configuration values.

    Args:
        cfg: A parsed configuration (as built by ``load_config``).

    Raises:
        ValueError: On the first violated constraint:
            - the action budget is non-positive;
            - ``max_expensive`` lies outside ``[0, total_actions]``;
            - there are no fidelities or no knobs;
            - a knob has non-finite or inverted bounds, or an unknown type;
            - reward weights are empty, non-numeric or non-finite;
            - every reward candidate list is empty.
    """
    budget = cfg.tuning.budget
    # Validate budget
    if budget.total_actions <= 0:
        raise ValueError(f"total_actions must be > 0, got {budget.total_actions}")
    if budget.max_expensive < 0:
        raise ValueError(f"max_expensive must be >= 0, got {budget.max_expensive}")
    if budget.max_expensive > budget.total_actions:
        raise ValueError(
            f"max_expensive ({budget.max_expensive}) cannot exceed "
            f"total_actions ({budget.total_actions})"
        )

    # Validate fidelities
    if not cfg.flow.fidelities:
        raise ValueError("flow.fidelities cannot be empty")

    # Validate knobs
    if not cfg.tuning.knobs:
        raise ValueError("tuning.knobs cannot be empty")
    for name, spec in cfg.tuning.knobs.items():
        # NaN compares False against everything, so `min >= max` alone would
        # accept NaN bounds; infinite bounds cannot define a search range.
        if not (math.isfinite(spec.min) and math.isfinite(spec.max)):
            raise ValueError(
                f"Knob {name}: min and max must be finite, "
                f"got min={spec.min}, max={spec.max}"
            )
        if spec.min >= spec.max:
            raise ValueError(f"Knob {name}: min ({spec.min}) must be < max ({spec.max})")
        if spec.type not in ("int", "float"):
            raise ValueError(f"Knob {name}: type must be 'int' or 'float', got '{spec.type}'")

    # Validate reward weights
    if not cfg.reward.weights:
        raise ValueError("reward.weights cannot be empty")
    for key, weight in cfg.reward.weights.items():
        # Reject strings (PyYAML loads `1e-3` as str) and bools (a subclass of
        # int, almost certainly a typo), then infinities/NaN.
        if (
            isinstance(weight, bool)
            or not isinstance(weight, (int, float))
            or not math.isfinite(weight)
        ):
            raise ValueError(f"reward.weights.{key} must be a finite number, got {weight!r}")

    # Validate reward candidates
    if not (
        cfg.reward.wns_candidates
        or cfg.reward.area_candidates
        or cfg.reward.power_candidates
    ):
        raise ValueError("At least one reward candidate list must be non-empty")
    # Note: design_config validation happens later when ORFS dir is known