from __future__ import annotations import json import sqlite3 from dataclasses import asdict, dataclass from typing import Any, Dict, Optional from edgeeda.utils import ensure_dir, now_ts SCHEMA = """ CREATE TABLE IF NOT EXISTS trials ( id INTEGER PRIMARY KEY AUTOINCREMENT, exp_name TEXT NOT NULL, platform TEXT NOT NULL, design TEXT NOT NULL, variant TEXT NOT NULL, fidelity TEXT NOT NULL, knobs_json TEXT NOT NULL, make_cmd TEXT NOT NULL, return_code INTEGER NOT NULL, runtime_sec REAL NOT NULL, reward REAL, metrics_json TEXT, metadata_path TEXT, created_ts REAL NOT NULL ); CREATE INDEX IF NOT EXISTS idx_trials_exp ON trials(exp_name); CREATE INDEX IF NOT EXISTS idx_trials_variant ON trials(platform, design, variant); """ @dataclass class TrialRecord: exp_name: str platform: str design: str variant: str fidelity: str knobs: Dict[str, Any] make_cmd: str return_code: int runtime_sec: float reward: Optional[float] metrics: Optional[Dict[str, Any]] metadata_path: Optional[str] class TrialStore: def __init__(self, db_path: str): ensure_dir(db_path.rsplit("/", 1)[0] if "/" in db_path else ".") self.conn = sqlite3.connect(db_path) self.conn.execute("PRAGMA journal_mode=WAL;") self.conn.executescript(SCHEMA) self.conn.commit() def add(self, r: TrialRecord) -> None: self.conn.execute( """ INSERT INTO trials( exp_name, platform, design, variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec, reward, metrics_json, metadata_path, created_ts ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( r.exp_name, r.platform, r.design, r.variant, r.fidelity, json.dumps(r.knobs, sort_keys=True), r.make_cmd, int(r.return_code), float(r.runtime_sec), None if r.reward is None else float(r.reward), None if r.metrics is None else json.dumps(r.metrics), r.metadata_path, now_ts(), ), ) self.conn.commit() def fetch_all(self, exp_name: str): cur = self.conn.execute( "SELECT exp_name, platform, design, variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec, reward, metrics_json, metadata_path, created_ts FROM trials WHERE exp_name=? ORDER BY id ASC", (exp_name,), ) return cur.fetchall() def close(self) -> None: self.conn.close()