File size: 3,490 Bytes
1c8c60e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import csv
import json
import os
import pickle
from collections import Counter
from copy import deepcopy
from locale import strcoll
from statistics import mean
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

# Apply a matplotlib style sheet at import time; it affects every figure
# created after this module has been imported.
# NOTE(review): the style sheet is referenced by URL, so importing this module
# presumably needs network access — confirm, or vendor the file locally.
plt.style.use(
    "https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle"
)

import wandb

from . import wandb_utils


class StatPack:
    def __init__(self):
        self.data = {}

    def add_stat(self, key: str, value: float | int | None):
        assert (
            isinstance(value, float) or isinstance(value, int) or value is None
        ), f"Value {value} is not a valid type"
        if key not in self.data:
            self.data[key] = []
        self.data[key].append(value)

    def add_stats(self, other: "StatPack"):
        for key in other.keys():
            self.add_stat(key, other[key])

    def __getitem__(self, key: str):
        return self.data[key]

    def __setitem__(self, key: str, value: Any):
        self.data[key] = value

    def __contains__(self, key: str):
        return key in self.data

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def mean(self):
        mean_st = StatPack()
        for key in self.keys():
            if isinstance(self[key], list):
                # TODO: exclude None values
                non_none_values = [v for v in self[key] if v is not None]
                if non_none_values:
                    mean_st[key] = np.mean(np.array(non_none_values))
                else:
                    mean_st[key] = None
        return mean_st

    def store_plots(self, folder: str):
        os.makedirs(folder, exist_ok=True)
        for key in self.keys():
            plt.figure(figsize=(10, 5))
            plt.plot(self[key])
            plt.title(key)
            plt.savefig(os.path.join(folder, f"{key}.pdf"))
            plt.close()

    def store_numpy(self, folder: str):
        os.makedirs(folder, exist_ok=True)
        for key in self.keys():
            # Sanitize filename components (avoid slashes, spaces, etc.)
            safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_")
            values = self[key]
            # Convert None to NaN for numpy compatibility
            arr = np.array(
                [(np.nan if (v is None) else v) for v in values], dtype=float
            )
            np.save(os.path.join(folder, f"{safe_key}.npy"), arr)

    def store_json(self, folder: str, filename: str = "stats.json"):
        os.makedirs(folder, exist_ok=True)
        with open(os.path.join(folder, filename), "w") as f:
            json.dump(self.data, f, indent=4)

    def store_csv(self, folder: str):
        os.makedirs(folder, exist_ok=True)
        for key in self.keys():
            with open(os.path.join(folder, f"stats.csv"), "w") as f:
                writer = csv.writer(f)
                writer.writerow([key] + self[key])

    def store_pickle(self, folder: str):
        os.makedirs(folder, exist_ok=True)
        for key in self.keys():
            with open(os.path.join(folder, f"stats.pkl"), "wb") as f:
                pickle.dump(self[key], f)