|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import os |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
from edgeeda.config import load_config, Config |
|
|
from edgeeda.utils import seed_everything, ensure_dir |
|
|
from edgeeda.store import TrialStore, TrialRecord |
|
|
from edgeeda.orfs.runner import ORFSRunner |
|
|
from edgeeda.orfs.metrics import find_best_metadata_json, load_json |
|
|
from edgeeda.reward import compute_reward |
|
|
from edgeeda.viz import export_trials, make_plots |
|
|
|
|
|
from edgeeda.agents.random_search import RandomSearchAgent |
|
|
from edgeeda.agents.successive_halving import SuccessiveHalvingAgent |
|
|
from edgeeda.agents.surrogate_ucb import SurrogateUCBAgent |
|
|
|
|
|
|
|
|
# Registry of tuning agents: maps the name accepted in ``cfg.tuning.agent``
# to the class that implements it. ``_select_agent`` looks names up here.
AGENTS = dict(
    random=RandomSearchAgent,
    successive_halving=SuccessiveHalvingAgent,
    surrogate_ucb=SurrogateUCBAgent,
)
|
|
|
|
|
|
|
|
def _select_agent(cfg: Config):
    """Build the tuning agent named by ``cfg.tuning.agent``.

    The name is looked up in ``AGENTS`` and the registered class is called
    with ``cfg`` as its only argument.

    Raises:
        ValueError: if the configured name is not a key of ``AGENTS``.
    """
    agent_name = cfg.tuning.agent
    if agent_name in AGENTS:
        agent_cls = AGENTS[agent_name]
        return agent_cls(cfg)
    raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENTS.keys())}")
|
|
|
|
|
|
|
|
def _setup_logging(cfg: Config) -> None:
    """Configure root logging to write to both a log file and the console.

    The log file is ``tuning.log`` inside ``cfg.experiment.out_dir``; that
    directory is passed to ``ensure_dir`` before the file handler is opened.
    """
    out_dir = cfg.experiment.out_dir
    ensure_dir(out_dir)
    log_file = os.path.join(out_dir, "tuning.log")

    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()

    # NOTE: logging.basicConfig does nothing if the root logger already has
    # handlers installed, in which case these handlers are not attached.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    logging.info(f"Logging initialized. Log file: {log_file}")
|
|
|
|
|
|
|
|
def cmd_tune(args: argparse.Namespace) -> None:
    """Run the tuning loop described by a config file and export a summary.

    Reads from ``args``:
        config:  path handed to ``load_config``.
        budget:  optional override for ``cfg.tuning.budget.total_actions``.
        timeout: forwarded as ``timeout_sec`` to every ``runner.run_make``.

    For each of ``total_actions`` iterations the selected agent proposes an
    action, the matching make target is run, a metadata target is run, the
    reward is computed from the metadata JSON (when one is found), and the
    trial is recorded in the store and reported back to the agent. After the
    loop the trials are exported to ``summary.csv`` in the experiment output
    directory.

    The trial store is always closed, even if agent selection or an
    iteration raises.

    Raises:
        RuntimeError: if neither ``cfg.experiment.orfs_flow_dir`` nor the
            ``ORFS_FLOW_DIR`` environment variable provides a flow directory.
    """
    cfg = load_config(args.config)
    if args.budget is not None:
        cfg.tuning.budget.total_actions = int(args.budget)

    seed_everything(cfg.experiment.seed)
    ensure_dir(cfg.experiment.out_dir)
    _setup_logging(cfg)

    logging.info(f"Starting tuning experiment: {cfg.experiment.name}")
    logging.info(f"Agent: {cfg.tuning.agent}, Budget: {cfg.tuning.budget.total_actions} actions")
    logging.info(f"Platform: {cfg.design.platform}, Design: {cfg.design.design}")

    # The config value wins; the environment variable is only a fallback.
    orfs_flow_dir = cfg.experiment.orfs_flow_dir or os.environ.get("ORFS_FLOW_DIR")
    if not orfs_flow_dir:
        raise RuntimeError("ORFS flow dir missing. Set experiment.orfs_flow_dir or export ORFS_FLOW_DIR=/path/to/ORFS/flow")

    logging.info(f"ORFS flow directory: {orfs_flow_dir}")

    runner = ORFSRunner(orfs_flow_dir)
    store = TrialStore(cfg.experiment.db_path)
    try:
        agent = _select_agent(cfg)

        # Only the last configured fidelity counts against max_expensive.
        expensive_set = set(cfg.flow.fidelities[-1:])
        expensive_used = 0

        for i in tqdm(range(cfg.tuning.budget.total_actions), desc="actions"):
            action = agent.propose()
            fidelity = action.fidelity

            # Expensive budget exhausted: rerun the same variant/knobs at the
            # first configured fidelity instead.
            if fidelity in expensive_set and expensive_used >= cfg.tuning.budget.max_expensive:
                fidelity = cfg.flow.fidelities[0]
                action = type(action)(variant=action.variant, fidelity=fidelity, knobs=action.knobs)

            # A fidelity without an explicit target mapping is used as the
            # make target name itself.
            make_target = cfg.flow.targets.get(fidelity, fidelity)
            logging.info(f"Action {i+1}/{cfg.tuning.budget.total_actions}: variant={action.variant}, "
                         f"fidelity={action.fidelity}, knobs={action.knobs}")

            logging.debug(f"Running: {make_target} for variant {action.variant}")
            rr = runner.run_make(
                target=make_target,
                design_config=cfg.design.design_config,
                flow_variant=action.variant,
                overrides={k: str(v) for k, v in action.knobs.items()},
                timeout_sec=args.timeout,
            )

            ok = (rr.return_code == 0)
            if not ok:
                logging.warning(f"Trial {i+1} failed: variant={action.variant}, return_code={rr.return_code}")
                logging.debug(f"Command: {rr.cmd}")
                if rr.stderr:
                    logging.debug(f"Stderr (last 500 chars): {rr.stderr[-500:]}")
            else:
                logging.info(f"Trial {i+1} succeeded: variant={action.variant}, runtime={rr.runtime_sec:.2f}s")

            # NOTE(review): counted whether or not the run succeeded; the
            # source's indentation was lost, so confirm this belongs at loop
            # level rather than inside the success branch.
            if fidelity in expensive_set:
                expensive_used += 1

            # Resolve the metadata make target, trying the configured
            # spellings in order; a bare "metadata" is mapped to
            # "metadata-generate".
            meta_target = (
                cfg.flow.targets.get("metadata_generate")
                or cfg.flow.targets.get("metadata-generate")
                or cfg.flow.targets.get("metadata", "metadata")
            )
            if meta_target == "metadata":
                meta_target = "metadata-generate"
            logging.debug(f"Generating metadata for variant {action.variant} using target={meta_target}")
            meta_result = runner.run_make(
                target=meta_target,
                design_config=cfg.design.design_config,
                flow_variant=action.variant,
                overrides={},
                timeout_sec=args.timeout,
            )
            # A failed metadata run is only logged; the lookup below still runs.
            if meta_result.return_code != 0:
                logging.warning(f"Metadata generation failed for variant {action.variant}: return_code={meta_result.return_code}")

            meta_path = find_best_metadata_json(
                orfs_flow_dir=orfs_flow_dir,
                platform=cfg.design.platform,
                design=cfg.design.design,
                variant=action.variant,
            )

            reward = None
            flat = None

            if meta_path:
                logging.debug(f"Found metadata at: {meta_path}")
                try:
                    mobj = load_json(meta_path)
                    reward, comps, flat = compute_reward(
                        metrics_obj=mobj,
                        wns_candidates=cfg.reward.wns_candidates,
                        area_candidates=cfg.reward.area_candidates,
                        power_candidates=cfg.reward.power_candidates,
                        weights=cfg.reward.weights,
                    )
                    if reward is not None:
                        logging.info(f"Computed reward for variant {action.variant}: {reward:.4f} "
                                     f"(WNS={comps.wns}, area={comps.area}, power={comps.power})")
                    else:
                        logging.warning(f"Reward computation returned None for variant {action.variant}")
                except Exception as e:
                    # Best effort: a broken metadata file must not abort the
                    # whole experiment; the trial is reported as not ok.
                    logging.error(f"Failed to compute reward for variant {action.variant}: {e}", exc_info=True)
                    ok = False
            else:
                logging.warning(f"Metadata not found for variant {action.variant} at "
                                f"reports/{cfg.design.platform}/{cfg.design.design}/{action.variant}/")

            store.add(
                TrialRecord(
                    exp_name=cfg.experiment.name,
                    platform=cfg.design.platform,
                    design=cfg.design.design,
                    variant=action.variant,
                    fidelity=action.fidelity,
                    knobs=action.knobs,
                    make_cmd=rr.cmd,
                    return_code=rr.return_code,
                    runtime_sec=rr.runtime_sec,
                    reward=reward,
                    metrics=flat,
                    metadata_path=meta_path,
                )
            )

            agent.observe(action, ok=ok, reward=reward, metrics_flat=flat)
    finally:
        # Release the store even when agent selection or an iteration raises.
        store.close()

    logging.info("Exporting trial summary...")
    df = export_trials(cfg.experiment.db_path)
    out_csv = os.path.join(cfg.experiment.out_dir, "summary.csv")
    df.to_csv(out_csv, index=False)

    total_trials = len(df)
    successful = len(df[df['return_code'] == 0])
    with_rewards = len(df[df['reward'].notna()])
    logging.info(f"Experiment complete: {total_trials} trials, {successful} successful, {with_rewards} with rewards")

    print(f"[done] wrote {out_csv}")
|
|
|
|
|
|
|
|
def cmd_analyze(args: argparse.Namespace) -> None:
    """Export recorded trials to CSV and render plots.

    Loads the trials from ``args.db`` via ``export_trials``, writes them to
    ``trials.csv`` inside ``args.out`` (created through ``ensure_dir``) and
    passes the same data and directory to ``make_plots``.
    """
    trials = export_trials(args.db)

    out_dir = args.out
    ensure_dir(out_dir)

    csv_path = os.path.join(out_dir, "trials.csv")
    trials.to_csv(csv_path, index=False)

    make_plots(trials, out_dir)
    print(f"[done] wrote plots to {out_dir}")
|
|
|
|
|
|
|
|
def main(argv: Optional[list] = None) -> None:
    """Parse command-line arguments and dispatch to the chosen sub-command.

    Sub-commands:
        tune:    ``--config`` (required), ``--budget``, ``--timeout``;
                 handled by ``cmd_tune``.
        analyze: ``--db`` and ``--out`` (both required); handled by
                 ``cmd_analyze``.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so calling ``main()`` behaves as
            before; passing a list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(prog="edgeeda")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    tune_parser = subparsers.add_parser("tune", help="Run agentic tuning loop on ORFS")
    tune_parser.add_argument("--config", required=True, help="YAML config")
    tune_parser.add_argument("--budget", type=int, default=None, help="Override total_actions")
    tune_parser.add_argument("--timeout", type=int, default=None, help="Timeout per make run (sec)")
    tune_parser.set_defaults(func=cmd_tune)

    analyze_parser = subparsers.add_parser("analyze", help="Export CSV + plots")
    analyze_parser.add_argument("--db", required=True, help="SQLite db path")
    analyze_parser.add_argument("--out", required=True, help="Output directory for plots")
    analyze_parser.set_defaults(func=cmd_analyze)

    args = parser.parse_args(argv)
    # Each sub-parser stored its handler under ``func`` via set_defaults.
    args.func(args)
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|