File size: 1,910 Bytes
aa677e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations

import os
import json
import sqlite3
from typing import Any, Dict

import pandas as pd
import matplotlib.pyplot as plt

from edgeeda.utils import ensure_dir


def export_trials(db_path: str) -> pd.DataFrame:
    """Load every row of the ``trials`` table into a DataFrame.

    Args:
        db_path: Path to the SQLite database file to read.

    Returns:
        A DataFrame holding the result of ``SELECT * FROM trials``.

    Raises:
        Whatever ``sqlite3.connect`` or ``pandas.read_sql_query`` raise, e.g.
        when the database has no ``trials`` table. The connection is closed
        in that case as well.
    """
    con = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query("SELECT * FROM trials", con)
    finally:
        # Close even when the query fails so the database handle is not
        # leaked (sqlite3's own context manager only manages transactions).
        con.close()


# Metric keys to try, in priority order, for each axis of the scatter plot.
_AREA_KEYS = ("design__die__area", "finish__design__die__area")
_WNS_KEYS = ("timing__setup__wns", "finish__timing__setup__wns")


def _first_number(metrics: Dict[str, Any], keys: tuple[str, ...]) -> float | None:
    """Return the first value stored under *keys* that converts to float."""
    for key in keys:
        value = metrics.get(key)
        # Test against None explicitly: 0 / 0.0 is a valid metric value
        # (e.g. a WNS of exactly zero) and must not be treated as missing.
        if value is None:
            continue
        try:
            return float(value)
        except (TypeError, ValueError):
            continue
    return None


def _collect_area_wns(df: pd.DataFrame) -> tuple[list[float], list[float]]:
    """Extract paired (area, wns) values from the ``metrics_json`` column."""
    areas: list[float] = []
    wnss: list[float] = []
    column = df.get("metrics_json")
    if column is None:
        return areas, wnss

    for raw in column:
        if not isinstance(raw, str) or not raw:
            continue
        try:
            metrics = json.loads(raw)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if not isinstance(metrics, dict):
            # Valid JSON that is not an object carries no named metrics.
            continue
        area = _first_number(metrics, _AREA_KEYS)
        wns = _first_number(metrics, _WNS_KEYS)
        if area is None or wns is None:
            continue
        # Append only once both values are known so the two lists always
        # have the same length.
        areas.append(area)
        wnss.append(wns)
    return areas, wnss


def make_plots(df: pd.DataFrame, out_dir: str) -> None:
    """Write summary plots for a trials DataFrame into *out_dir*.

    Two PNG files may be produced:

    * ``learning_curve.png`` - running maximum of ``reward`` against ``id``,
      written when at least one row has a numeric reward.
    * ``area_vs_wns.png`` - scatter of die area against WNS, written when at
      least one row's ``metrics_json`` yields both values.

    Args:
        df: Trials data. Must contain ``id`` and ``reward`` columns;
            ``metrics_json`` is optional and rows with missing, empty,
            unparsable or non-object JSON are skipped.
        out_dir: Directory the images are written to; passed to
            ``ensure_dir`` before plotting.

    Raises:
        KeyError: If *df* lacks the ``id`` or ``reward`` column.
    """
    ensure_dir(out_dir)

    # Learning curve: best reward over time. Only the two columns that are
    # needed get copied, so the caller's frame is never modified.
    curve = df[["id", "reward"]].copy()
    curve["reward"] = pd.to_numeric(curve["reward"], errors="coerce")
    curve = curve.dropna(subset=["reward"]).sort_values("id")
    if not curve.empty:
        fig = plt.figure()
        try:
            plt.plot(curve["id"].values, curve["reward"].cummax().values)
            plt.xlabel("trial id")
            plt.ylabel("best reward so far")
            plt.tight_layout()
            plt.savefig(os.path.join(out_dir, "learning_curve.png"), dpi=200)
        finally:
            # Release the figure even if rendering or saving fails.
            plt.close(fig)

    # Pareto-ish scatter: area vs wns from metrics_json if available.
    areas, wnss = _collect_area_wns(df)
    if areas:
        fig = plt.figure()
        try:
            plt.scatter(areas, wnss)
            plt.xlabel("die area")
            plt.ylabel("WNS")
            plt.tight_layout()
            plt.savefig(os.path.join(out_dir, "area_vs_wns.png"), dpi=200)
        finally:
            plt.close(fig)