# Published by SamChYe: "Publish EdgeEDA agent" (commit aa677e3, verified).
from __future__ import annotations
import argparse
import logging
import os
from typing import Any, Dict, Optional
from tqdm import tqdm
from edgeeda.config import load_config, Config
from edgeeda.utils import seed_everything, ensure_dir
from edgeeda.store import TrialStore, TrialRecord
from edgeeda.orfs.runner import ORFSRunner
from edgeeda.orfs.metrics import find_best_metadata_json, load_json
from edgeeda.reward import compute_reward
from edgeeda.viz import export_trials, make_plots
from edgeeda.agents.random_search import RandomSearchAgent
from edgeeda.agents.successive_halving import SuccessiveHalvingAgent
from edgeeda.agents.surrogate_ucb import SurrogateUCBAgent
# Registry mapping the ``tuning.agent`` config value to the agent class that
# implements it. ``_select_agent`` instantiates the chosen class with the config.
AGENTS = dict(
    random=RandomSearchAgent,
    successive_halving=SuccessiveHalvingAgent,
    surrogate_ucb=SurrogateUCBAgent,
)
def _select_agent(cfg: Config):
    """Instantiate the tuning agent named by ``cfg.tuning.agent``.

    Args:
        cfg: Loaded experiment configuration; it is also passed to the agent.

    Returns:
        A new instance of the class registered under that name in ``AGENTS``.

    Raises:
        ValueError: if the name is not registered in ``AGENTS``.
    """
    agent_name = cfg.tuning.agent
    try:
        agent_cls = AGENTS[agent_name]
    except KeyError:
        raise ValueError(
            f"Unknown agent: {agent_name}. Choose from {list(AGENTS.keys())}"
        ) from None
    return agent_cls(cfg)
def _setup_logging(cfg: Config) -> None:
    """Send INFO-level logs both to ``<out_dir>/tuning.log`` and to the console.

    The output directory is created if needed. Note that
    ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first call in a process takes effect.
    """
    out_dir = cfg.experiment.out_dir
    ensure_dir(out_dir)
    log_path = os.path.join(out_dir, "tuning.log")

    file_handler = logging.FileHandler(log_path)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[file_handler, console_handler],
    )
    logging.info(f"Logging initialized. Log file: {log_path}")
def cmd_tune(args: argparse.Namespace) -> None:
cfg = load_config(args.config)
if args.budget is not None:
cfg.tuning.budget.total_actions = int(args.budget)
seed_everything(cfg.experiment.seed)
ensure_dir(cfg.experiment.out_dir)
_setup_logging(cfg)
logging.info(f"Starting tuning experiment: {cfg.experiment.name}")
logging.info(f"Agent: {cfg.tuning.agent}, Budget: {cfg.tuning.budget.total_actions} actions")
logging.info(f"Platform: {cfg.design.platform}, Design: {cfg.design.design}")
orfs_flow_dir = cfg.experiment.orfs_flow_dir or os.environ.get("ORFS_FLOW_DIR")
if not orfs_flow_dir:
raise RuntimeError("ORFS flow dir missing. Set experiment.orfs_flow_dir or export ORFS_FLOW_DIR=/path/to/ORFS/flow")
logging.info(f"ORFS flow directory: {orfs_flow_dir}")
runner = ORFSRunner(orfs_flow_dir)
store = TrialStore(cfg.experiment.db_path)
agent = _select_agent(cfg)
expensive_set = set(cfg.flow.fidelities[-1:]) # last stage treated as expensive
expensive_used = 0
for i in tqdm(range(cfg.tuning.budget.total_actions), desc="actions"):
action = agent.propose()
fidelity = action.fidelity
# enforce max expensive budget
if fidelity in expensive_set and expensive_used >= cfg.tuning.budget.max_expensive:
# downgrade to cheaper stage
fidelity = cfg.flow.fidelities[0]
action = type(action)(variant=action.variant, fidelity=fidelity, knobs=action.knobs)
make_target = cfg.flow.targets.get(fidelity, fidelity)
logging.info(f"Action {i+1}/{cfg.tuning.budget.total_actions}: variant={action.variant}, "
f"fidelity={action.fidelity}, knobs={action.knobs}")
# run ORFS make
logging.debug(f"Running: {make_target} for variant {action.variant}")
rr = runner.run_make(
target=make_target,
design_config=cfg.design.design_config,
flow_variant=action.variant,
overrides={k: str(v) for k, v in action.knobs.items()},
timeout_sec=args.timeout,
)
ok = (rr.return_code == 0)
if not ok:
logging.warning(f"Trial {i+1} failed: variant={action.variant}, return_code={rr.return_code}")
logging.debug(f"Command: {rr.cmd}")
if rr.stderr:
logging.debug(f"Stderr (last 500 chars): {rr.stderr[-500:]}")
else:
logging.info(f"Trial {i+1} succeeded: variant={action.variant}, runtime={rr.runtime_sec:.2f}s")
if fidelity in expensive_set:
expensive_used += 1
# always try to generate metadata JSON (avoid triggering full-flow when not needed)
meta_target = (
cfg.flow.targets.get("metadata_generate")
or cfg.flow.targets.get("metadata-generate")
or cfg.flow.targets.get("metadata", "metadata")
)
if meta_target == "metadata":
meta_target = "metadata-generate"
logging.debug(f"Generating metadata for variant {action.variant} using target={meta_target}")
meta_result = runner.run_make(
target=meta_target,
design_config=cfg.design.design_config,
flow_variant=action.variant,
overrides={},
timeout_sec=args.timeout,
)
if meta_result.return_code != 0:
logging.warning(f"Metadata generation failed for variant {action.variant}: return_code={meta_result.return_code}")
meta_path = find_best_metadata_json(
orfs_flow_dir=orfs_flow_dir,
platform=cfg.design.platform,
design=cfg.design.design,
variant=action.variant,
)
reward = None
flat = None
if meta_path:
logging.debug(f"Found metadata at: {meta_path}")
try:
mobj = load_json(meta_path)
reward, comps, flat = compute_reward(
metrics_obj=mobj,
wns_candidates=cfg.reward.wns_candidates,
area_candidates=cfg.reward.area_candidates,
power_candidates=cfg.reward.power_candidates,
weights=cfg.reward.weights,
)
if reward is not None:
logging.info(f"Computed reward for variant {action.variant}: {reward:.4f} "
f"(WNS={comps.wns}, area={comps.area}, power={comps.power})")
else:
logging.warning(f"Reward computation returned None for variant {action.variant}")
except Exception as e:
logging.error(f"Failed to compute reward for variant {action.variant}: {e}", exc_info=True)
ok = False
else:
logging.warning(f"Metadata not found for variant {action.variant} at "
f"reports/{cfg.design.platform}/{cfg.design.design}/{action.variant}/")
store.add(
TrialRecord(
exp_name=cfg.experiment.name,
platform=cfg.design.platform,
design=cfg.design.design,
variant=action.variant,
fidelity=action.fidelity,
knobs=action.knobs,
make_cmd=rr.cmd,
return_code=rr.return_code,
runtime_sec=rr.runtime_sec,
reward=reward,
metrics=flat,
metadata_path=meta_path,
)
)
agent.observe(action, ok=ok, reward=reward, metrics_flat=flat)
store.close()
# Export summary
logging.info("Exporting trial summary...")
df = export_trials(cfg.experiment.db_path)
out_csv = os.path.join(cfg.experiment.out_dir, "summary.csv")
df.to_csv(out_csv, index=False)
# Log summary statistics
total_trials = len(df)
successful = len(df[df['return_code'] == 0])
with_rewards = len(df[df['reward'].notna()])
logging.info(f"Experiment complete: {total_trials} trials, {successful} successful, {with_rewards} with rewards")
print(f"[done] wrote {out_csv}")
def cmd_analyze(args: argparse.Namespace) -> None:
    """Export every trial in ``args.db`` to ``<args.out>/trials.csv`` and draw plots there.

    Args:
        args: Parsed CLI namespace with ``db`` (trial database path) and
            ``out`` (output directory; created if missing).
    """
    out_dir = args.out
    trials = export_trials(args.db)
    ensure_dir(out_dir)

    csv_path = os.path.join(out_dir, "trials.csv")
    trials.to_csv(csv_path, index=False)
    make_plots(trials, out_dir)
    print(f"[done] wrote plots to {out_dir}")
def main(argv: Optional[list[str]] = None) -> None:
    """Entry point for the ``edgeeda`` command-line interface.

    Subcommands:
        tune     Run the agentic tuning loop on ORFS (``cmd_tune``).
        analyze  Export a CSV and plots from a trial database (``cmd_analyze``).

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before. Passing
            a list lets the CLI be driven programmatically, e.g. from tests.
    """
    parser = argparse.ArgumentParser(prog="edgeeda")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    tune_parser = subparsers.add_parser("tune", help="Run agentic tuning loop on ORFS")
    tune_parser.add_argument("--config", required=True, help="YAML config")
    tune_parser.add_argument("--budget", type=int, default=None, help="Override total_actions")
    tune_parser.add_argument("--timeout", type=int, default=None, help="Timeout per make run (sec)")
    tune_parser.set_defaults(func=cmd_tune)

    analyze_parser = subparsers.add_parser("analyze", help="Export CSV + plots")
    analyze_parser.add_argument("--db", required=True, help="SQLite db path")
    analyze_parser.add_argument("--out", required=True, help="Output directory for plots")
    analyze_parser.set_defaults(func=cmd_analyze)

    args = parser.parse_args(argv)
    args.func(args)
# Run the CLI only when this module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()