|
|
from __future__ import annotations

import json
import os
import sqlite3
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

from edgeeda.utils import ensure_dir, now_ts
|
|
|
|
|
|
|
|
# DDL run by TrialStore every time it opens a database. Every statement uses
# IF NOT EXISTS, so re-running it against an existing file is a no-op.
#   - trials: one row per recorded trial. knobs_json / metrics_json hold JSON
#     text (metrics_json may be NULL); created_ts is filled from now_ts() at
#     insert time by TrialStore.add.
#   - idx_trials_exp backs the per-experiment filter used by
#     TrialStore.fetch_all (WHERE exp_name=?).
#   - idx_trials_variant indexes (platform, design, variant); presumably for
#     per-variant queries elsewhere in the project — confirm with callers.
SCHEMA = """
CREATE TABLE IF NOT EXISTS trials (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    exp_name TEXT NOT NULL,
    platform TEXT NOT NULL,
    design TEXT NOT NULL,
    variant TEXT NOT NULL,
    fidelity TEXT NOT NULL,

    knobs_json TEXT NOT NULL,
    make_cmd TEXT NOT NULL,
    return_code INTEGER NOT NULL,
    runtime_sec REAL NOT NULL,

    reward REAL,
    metrics_json TEXT,
    metadata_path TEXT,

    created_ts REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_trials_exp ON trials(exp_name);
CREATE INDEX IF NOT EXISTS idx_trials_variant ON trials(platform, design, variant);
"""
|
|
|
|
|
|
|
|
@dataclass
class TrialRecord:
    """One trial to be persisted by ``TrialStore.add``.

    Fields map one-to-one onto columns of the ``trials`` table. ``knobs`` and
    ``metrics`` are plain dicts that the store serialises to JSON, so their
    contents must be JSON-serialisable.

    The last three fields are optional and default to ``None`` (stored as
    SQL NULL). This lets callers omit them, for example for a failed run
    that produced no reward or metrics.
    """

    exp_name: str
    platform: str
    design: str
    variant: str
    fidelity: str
    # Parameter settings for this trial; stored as JSON with sorted keys.
    knobs: Dict[str, Any]
    # Command line that was executed for the trial.
    make_cmd: str
    # Presumably the exit status of make_cmd — confirm with the producer.
    return_code: int
    # Wall-clock duration; the "_sec" suffix suggests seconds.
    runtime_sec: float
    # Scalar objective for the trial, or None if it could not be computed.
    reward: Optional[float] = None
    # Arbitrary collected metrics, or None; stored as JSON text.
    metrics: Optional[Dict[str, Any]] = None
    # Path to an associated metadata file, if any.
    metadata_path: Optional[str] = None
|
|
|
|
|
|
|
|
class TrialStore:
    """Persist and query trial records in a SQLite database.

    Each :meth:`add` is committed immediately. The store can be used as a
    context manager so the connection is closed even if an exception escapes::

        with TrialStore("runs/trials.db") as store:
            store.add(record)
    """

    def __init__(self, db_path: str):
        """Open (creating if needed) the SQLite database at *db_path*.

        The parent directory is created via ``ensure_dir``, WAL journaling is
        requested, and the ``trials`` schema is applied (idempotently) and
        committed.

        Args:
            db_path: Path of the SQLite file. A bare filename (or
                ``":memory:"``) uses ``"."`` as its parent directory.

        Raises:
            sqlite3.Error: If the database cannot be opened or initialised.
                When initialisation fails, the connection is closed before the
                error propagates.
        """
        # dirname("/t.db") == "/" and dirname("t.db") == "", so fall back to
        # "." only for bare filenames; this also honours native separators.
        ensure_dir(os.path.dirname(db_path) or ".")
        self.conn = sqlite3.connect(db_path)
        try:
            self.conn.execute("PRAGMA journal_mode=WAL;")
            self.conn.executescript(SCHEMA)
            self.conn.commit()
        except sqlite3.Error:
            # Do not leak the handle if schema setup fails.
            self.conn.close()
            raise

    def add(self, r: TrialRecord) -> None:
        """Insert *r* as a new ``trials`` row and commit it.

        ``knobs`` is serialised to JSON with sorted keys, so equal dicts always
        produce the same text. ``metrics`` is serialised to JSON, or stored as
        NULL when it is ``None``. ``created_ts`` is taken from ``now_ts()`` at
        insertion time.

        Raises:
            TypeError: If ``knobs`` or ``metrics`` is not JSON-serialisable.
            sqlite3.Error: On database failure; the insert is rolled back.
        """
        row = (
            r.exp_name,
            r.platform,
            r.design,
            r.variant,
            r.fidelity,
            json.dumps(r.knobs, sort_keys=True),
            r.make_cmd,
            int(r.return_code),
            float(r.runtime_sec),
            None if r.reward is None else float(r.reward),
            None if r.metrics is None else json.dumps(r.metrics),
            r.metadata_path,
            now_ts(),
        )
        # The connection context manager commits on success and rolls back on
        # error, so a failed insert never lingers in an open transaction.
        with self.conn:
            self.conn.execute(
                """
                INSERT INTO trials(
                    exp_name, platform, design, variant, fidelity,
                    knobs_json, make_cmd, return_code, runtime_sec,
                    reward, metrics_json, metadata_path, created_ts
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                row,
            )

    def fetch_all(self, exp_name: str) -> list[tuple[Any, ...]]:
        """Return every row recorded for *exp_name*, in insertion (id) order.

        Rows are raw tuples in this column order: exp_name, platform, design,
        variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec,
        reward, metrics_json, metadata_path, created_ts. The JSON columns are
        returned as text (``metrics_json`` may be ``None``) and are not
        decoded here.
        """
        cur = self.conn.execute(
            "SELECT exp_name, platform, design, variant, fidelity, knobs_json, "
            "make_cmd, return_code, runtime_sec, reward, metrics_json, "
            "metadata_path, created_ts FROM trials WHERE exp_name=? ORDER BY id ASC",
            (exp_name,),
        )
        return cur.fetchall()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self) -> TrialStore:
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()
|
|
|