File size: 2,706 Bytes
aa677e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from __future__ import annotations

import json
import os
import sqlite3
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

from edgeeda.utils import ensure_dir, now_ts

# DDL script for the trial database: one `trials` table (one row per recorded
# trial) plus two indexes, on exp_name and on (platform, design, variant).
# Every statement uses IF NOT EXISTS, so the script can be re-run against an
# existing database; TrialStore.__init__ executes it on every open.
# reward, metrics_json and metadata_path are the only nullable columns.
SCHEMA = """
CREATE TABLE IF NOT EXISTS trials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
exp_name TEXT NOT NULL,
platform TEXT NOT NULL,
design TEXT NOT NULL,
variant TEXT NOT NULL,
fidelity TEXT NOT NULL,
knobs_json TEXT NOT NULL,
make_cmd TEXT NOT NULL,
return_code INTEGER NOT NULL,
runtime_sec REAL NOT NULL,
reward REAL,
metrics_json TEXT,
metadata_path TEXT,
created_ts REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_trials_exp ON trials(exp_name);
CREATE INDEX IF NOT EXISTS idx_trials_variant ON trials(platform, design, variant);
"""
@dataclass
class TrialRecord:
    """In-memory form of one row destined for the ``trials`` table.

    The fields follow the table's column order, leaving out ``id`` and
    ``created_ts``, which are filled in when the row is inserted.  ``knobs``
    and ``metrics`` are plain dictionaries here; the table stores them as
    JSON text in ``knobs_json`` / ``metrics_json``.
    """

    # --- identification of the trial (all required) ---
    exp_name: str
    platform: str
    design: str
    variant: str
    fidelity: str

    # --- how the trial was run ---
    knobs: Dict[str, Any]
    make_cmd: str
    return_code: int  # presumably the exit status of make_cmd -- confirm with caller
    runtime_sec: float

    # --- outcome; each of these maps to a nullable column ---
    reward: Optional[float]
    metrics: Optional[Dict[str, Any]]
    metadata_path: Optional[str]
class TrialStore:
    """SQLite-backed store of trial results.

    Opening a store connects to the database file, switches it to WAL
    journaling and runs ``SCHEMA``.  The raw connection is exposed as
    ``self.conn``.  A store can be used as a context manager, in which case
    the connection is closed when the ``with`` block exits.
    """

    def __init__(self, db_path: str):
        """Open the database at ``db_path`` and make sure the schema exists.

        Args:
            db_path: Path of the SQLite file.  Its parent directory is passed
                to ``ensure_dir`` before connecting; a bare filename maps to
                the current directory.

        Raises:
            sqlite3.Error: If the journal mode or schema cannot be applied.
                The connection is closed before the error propagates.
        """
        # os.path.dirname honours the platform's path separators and returns
        # "/" (not "") for root-level files; it yields "" only for a bare
        # filename, which is mapped to the current directory.
        ensure_dir(os.path.dirname(db_path) or ".")
        self.conn = sqlite3.connect(db_path)
        try:
            self.conn.execute("PRAGMA journal_mode=WAL;")
            self.conn.executescript(SCHEMA)
            self.conn.commit()
        except sqlite3.Error:
            # Do not leak the handle when initialisation fails half-way.
            self.conn.close()
            raise

    def __enter__(self) -> TrialStore:
        """Return the store itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection; exceptions from the block are not suppressed."""
        self.close()

    def add(self, r: TrialRecord) -> None:
        """Insert one trial row and commit it.

        ``r.knobs`` is stored as JSON with sorted keys; ``r.metrics`` is
        stored as JSON, or NULL when it is ``None``.  ``created_ts`` is taken
        from ``now_ts()`` at insertion time.

        Raises:
            sqlite3.Error: If the insert fails; the transaction is rolled
                back before the error propagates.
        """
        row = (
            r.exp_name,
            r.platform,
            r.design,
            r.variant,
            r.fidelity,
            json.dumps(r.knobs, sort_keys=True),
            r.make_cmd,
            int(r.return_code),
            float(r.runtime_sec),
            None if r.reward is None else float(r.reward),
            None if r.metrics is None else json.dumps(r.metrics),
            r.metadata_path,
            now_ts(),
        )
        # The connection's context manager commits on success and rolls back
        # if the statement raises; it does not close the connection.
        with self.conn:
            self.conn.execute(
                """
                INSERT INTO trials(
                    exp_name, platform, design, variant, fidelity,
                    knobs_json, make_cmd, return_code, runtime_sec,
                    reward, metrics_json, metadata_path, created_ts
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                row,
            )

    def fetch_all(self, exp_name: str) -> list[tuple[Any, ...]]:
        """Return every stored row of experiment ``exp_name``, oldest id first.

        Each row is a tuple ``(exp_name, platform, design, variant, fidelity,
        knobs_json, make_cmd, return_code, runtime_sec, reward, metrics_json,
        metadata_path, created_ts)``.  The JSON columns are returned as the
        stored text, not decoded.  An unknown experiment yields ``[]``.
        """
        cur = self.conn.execute(
            "SELECT exp_name, platform, design, variant, fidelity, "
            "knobs_json, make_cmd, return_code, runtime_sec, "
            "reward, metrics_json, metadata_path, created_ts "
            "FROM trials WHERE exp_name=? ORDER BY id ASC",
            (exp_name,),
        )
        return cur.fetchall()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()
|