File size: 2,706 Bytes
aa677e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

import json
import os
import sqlite3
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple

from edgeeda.utils import ensure_dir, now_ts


# DDL for the trial log, applied by TrialStore.__init__ via executescript().
# Every statement uses IF NOT EXISTS, so re-running it against an existing
# database leaves the table and indexes unchanged.
#
# Column notes:
#   knobs_json   - JSON text written by TrialStore.add (keys sorted).
#   metrics_json - JSON text, or NULL when the trial produced no metrics.
#   reward, metrics_json, metadata_path - nullable; all other columns are
#                  NOT NULL.
#   created_ts   - insertion timestamp from now_ts(); units are not visible
#                  here (presumably seconds since epoch - confirm in
#                  edgeeda.utils).
# The indexes serve lookups by experiment name and by the
# (platform, design, variant) triple.
SCHEMA = """
CREATE TABLE IF NOT EXISTS trials (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  exp_name TEXT NOT NULL,
  platform TEXT NOT NULL,
  design TEXT NOT NULL,
  variant TEXT NOT NULL,
  fidelity TEXT NOT NULL,

  knobs_json TEXT NOT NULL,
  make_cmd TEXT NOT NULL,
  return_code INTEGER NOT NULL,
  runtime_sec REAL NOT NULL,

  reward REAL,
  metrics_json TEXT,
  metadata_path TEXT,

  created_ts REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_trials_exp ON trials(exp_name);
CREATE INDEX IF NOT EXISTS idx_trials_variant ON trials(platform, design, variant);
"""


@dataclass
class TrialRecord:
    """A single evaluated trial, in the shape accepted by ``TrialStore.add``.

    Field order defines the positional constructor signature, so keep it
    stable when editing.
    """

    # Which run this was.
    exp_name: str
    platform: str
    design: str
    variant: str
    fidelity: str

    # What was executed and how it went at the process level.
    knobs: Dict[str, Any]  # stored as JSON text with sorted keys
    make_cmd: str
    return_code: int
    runtime_sec: float

    # Outcome; each of these may be absent (stored as NULL).
    reward: Optional[float]
    metrics: Optional[Dict[str, Any]]
    metadata_path: Optional[str]


class TrialStore:
    """SQLite-backed log of ``TrialRecord`` rows.

    Opens (creating if necessary) the database at ``db_path``, switches it to
    WAL journaling and applies ``SCHEMA``. Instances can be used as context
    managers; the connection is closed on exit.
    """

    def __init__(self, db_path: str):
        """Open or create the trial database at *db_path*.

        Args:
            db_path: Path of the SQLite file. A bare filename (no directory
                component) places the database in the current directory.
        """
        # os.path.dirname understands the platform's separators and maps a
        # root-level path such as "/trials.db" to "/" rather than "".
        ensure_dir(os.path.dirname(db_path) or ".")
        self.conn = sqlite3.connect(db_path)
        try:
            self.conn.execute("PRAGMA journal_mode=WAL;")
            self.conn.executescript(SCHEMA)
            self.conn.commit()
        except BaseException:
            # Don't leak the connection if schema setup fails.
            self.conn.close()
            raise

    def add(self, r: TrialRecord) -> None:
        """Insert *r* as a new row stamped with ``now_ts()``.

        ``knobs`` is serialized to JSON with sorted keys; ``metrics`` is
        serialized to JSON when present and stored as NULL otherwise.

        Raises:
            TypeError: if ``knobs`` or ``metrics`` is not JSON-serializable
                (raised by ``json.dumps`` before anything is written).
            sqlite3.Error: on a database failure; the transaction is
                rolled back.
        """
        row = (
            r.exp_name,
            r.platform,
            r.design,
            r.variant,
            r.fidelity,
            json.dumps(r.knobs, sort_keys=True),
            r.make_cmd,
            int(r.return_code),
            float(r.runtime_sec),
            None if r.reward is None else float(r.reward),
            None if r.metrics is None else json.dumps(r.metrics),
            r.metadata_path,
            now_ts(),
        )
        # The connection's context manager commits on success and rolls back
        # on error, so a failed insert never leaves a transaction open.
        with self.conn:
            self.conn.execute(
                """
                INSERT INTO trials(
                  exp_name, platform, design, variant, fidelity,
                  knobs_json, make_cmd, return_code, runtime_sec,
                  reward, metrics_json, metadata_path, created_ts
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                row,
            )

    def fetch_all(self, exp_name: str) -> List[Tuple[Any, ...]]:
        """Return every row recorded for *exp_name*, oldest first.

        Rows are ordered by the autoincrement ``id``, i.e. insertion order.
        Each row is a tuple in this column order: exp_name, platform, design,
        variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec,
        reward, metrics_json, metadata_path, created_ts. The JSON columns are
        returned as raw text.
        """
        cur = self.conn.execute(
            "SELECT exp_name, platform, design, variant, fidelity, knobs_json, make_cmd, return_code, runtime_sec, reward, metrics_json, metadata_path, created_ts FROM trials WHERE exp_name=? ORDER BY id ASC",
            (exp_name,),
        )
        return cur.fetchall()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self) -> "TrialStore":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()