SamChYe's picture
Publish EdgeEDA agent
aa677e3 verified
from __future__ import annotations

import json
import os
import sqlite3
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

from edgeeda.utils import ensure_dir, now_ts
# DDL for the trial database. TrialStore runs it with executescript() each time
# it opens a connection; the IF NOT EXISTS clauses make re-running it against
# an existing database a no-op. knobs_json / metrics_json hold JSON text,
# created_ts is filled in by TrialStore at insert time, and reward,
# metrics_json and metadata_path are the only nullable data columns. The
# indexes match the `WHERE exp_name=?` filter used by TrialStore.fetch_all and
# lookups by (platform, design, variant).
SCHEMA = """
CREATE TABLE IF NOT EXISTS trials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
exp_name TEXT NOT NULL,
platform TEXT NOT NULL,
design TEXT NOT NULL,
variant TEXT NOT NULL,
fidelity TEXT NOT NULL,
knobs_json TEXT NOT NULL,
make_cmd TEXT NOT NULL,
return_code INTEGER NOT NULL,
runtime_sec REAL NOT NULL,
reward REAL,
metrics_json TEXT,
metadata_path TEXT,
created_ts REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_trials_exp ON trials(exp_name);
CREATE INDEX IF NOT EXISTS idx_trials_variant ON trials(platform, design, variant);
"""
@dataclass
class TrialRecord:
    """A single trial, shaped to match one row of the ``trials`` table.

    ``TrialStore.add`` serializes ``knobs`` (sorted keys) and ``metrics`` to
    JSON text when persisting the record. The row's ``created_ts`` is not part
    of the record; it is assigned when the row is inserted.
    """

    # --- which experiment / target this trial belongs to -------------------
    exp_name: str
    platform: str
    design: str
    variant: str
    fidelity: str

    # --- what was run ------------------------------------------------------
    knobs: Dict[str, Any]  # stored as JSON in knobs_json
    make_cmd: str

    # --- what came out -----------------------------------------------------
    return_code: int
    runtime_sec: float
    reward: Optional[float]  # None is stored as SQL NULL
    metrics: Optional[Dict[str, Any]]  # stored as JSON, or NULL when None
    metadata_path: Optional[str]
class TrialStore:
    """SQLite-backed persistence for :class:`TrialRecord` rows.

    On construction, ``ensure_dir`` is called on the database's parent
    directory, the database is opened, the journal is switched to WAL mode and
    ``SCHEMA`` is applied. The store may be used as a context manager; the
    connection is closed when the ``with`` block exits.
    """

    def __init__(self, db_path: str) -> None:
        """Open the SQLite database at ``db_path`` and ensure the schema exists.

        Args:
            db_path: Path of the database file. A bare file name (no directory
                component) resolves its parent directory to ``"."``.

        Raises:
            sqlite3.Error: If the database cannot be opened or the schema
                cannot be applied. A connection opened before the failure is
                closed before the error propagates.
        """
        # os.path.dirname understands the platform's separators and maps a
        # root-level path such as "/trials.db" to "/" rather than "" (the
        # previous rsplit("/") logic handed an empty string to ensure_dir).
        ensure_dir(os.path.dirname(db_path) or ".")
        self.conn = sqlite3.connect(db_path)
        try:
            self.conn.execute("PRAGMA journal_mode=WAL;")
            self.conn.executescript(SCHEMA)
            self.conn.commit()
        except Exception:
            # Don't leak the connection when setup fails; re-raise unchanged.
            self.conn.close()
            raise

    def add(self, r: TrialRecord) -> None:
        """Insert one trial and commit it.

        ``knobs`` is stored as JSON with sorted keys, ``metrics`` as JSON (or
        NULL when ``None``), and ``created_ts`` is taken from ``now_ts()`` at
        insert time.

        Args:
            r: The trial to persist.

        Raises:
            sqlite3.Error: If the INSERT fails; the transaction is rolled back
                so the connection is left usable.
            TypeError / ValueError: From ``json.dumps`` if ``knobs`` or
                ``metrics`` cannot be encoded; raised before any SQL runs.
        """
        params = (
            r.exp_name,
            r.platform,
            r.design,
            r.variant,
            r.fidelity,
            json.dumps(r.knobs, sort_keys=True),
            r.make_cmd,
            int(r.return_code),
            float(r.runtime_sec),
            None if r.reward is None else float(r.reward),
            None if r.metrics is None else json.dumps(r.metrics),
            r.metadata_path,
            now_ts(),
        )
        # The connection's context manager commits on success and rolls back
        # on an exception, so a failed insert never leaves a dangling
        # transaction on this long-lived connection.
        with self.conn:
            self.conn.execute(
                """
                INSERT INTO trials(
                exp_name, platform, design, variant, fidelity,
                knobs_json, make_cmd, return_code, runtime_sec,
                reward, metrics_json, metadata_path, created_ts
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                params,
            )

    def fetch_all(self, exp_name: str) -> list[tuple[Any, ...]]:
        """Return every trial stored under ``exp_name``, in insertion order.

        Each row is a plain tuple with the columns, in order: exp_name,
        platform, design, variant, fidelity, knobs_json, make_cmd,
        return_code, runtime_sec, reward, metrics_json, metadata_path,
        created_ts. The JSON columns are returned undecoded, as stored.
        Returns an empty list when the experiment has no trials.
        """
        cur = self.conn.execute(
            "SELECT exp_name, platform, design, variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec, reward, metrics_json, metadata_path, created_ts FROM trials WHERE exp_name=? ORDER BY id ASC",
            (exp_name,),
        )
        return cur.fetchall()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self) -> TrialStore:
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()