| | import csv |
| | import json |
| | import os |
| | import pickle |
| | from collections import Counter |
| | from copy import deepcopy |
| | from locale import strcoll |
| | from statistics import mean |
| | from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
import warnings

# The style sheet is downloaded when this module is imported. Without the
# guard, an offline machine or an unreachable URL makes the import itself fail;
# a cosmetic style is not worth that, so fall back to matplotlib's defaults.
try:
    plt.style.use(
        "https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle"
    )
except OSError as err:
    warnings.warn(
        f"Could not load the remote matplotlib style sheet ({err}); "
        "using the default matplotlib style instead.",
        stacklevel=2,
    )
| |
|
| | import wandb |
| |
|
| | from . import wandb_utils |
| |
|
| |
|
| | class StatPack: |
| | def __init__(self): |
| | self.data = {} |
| |
|
| | def add_stat(self, key: str, value: float | int | None): |
| | assert ( |
| | isinstance(value, float) or isinstance(value, int) or value is None |
| | ), f"Value {value} is not a valid type" |
| | if key not in self.data: |
| | self.data[key] = [] |
| | self.data[key].append(value) |
| |
|
| | def add_stats(self, other: "StatPack"): |
| | for key in other.keys(): |
| | self.add_stat(key, other[key]) |
| |
|
| | def __getitem__(self, key: str): |
| | return self.data[key] |
| |
|
| | def __setitem__(self, key: str, value: Any): |
| | self.data[key] = value |
| |
|
| | def __contains__(self, key: str): |
| | return key in self.data |
| |
|
| | def __len__(self): |
| | return len(self.data) |
| |
|
| | def __iter__(self): |
| | return iter(self.data) |
| |
|
| | def keys(self): |
| | return self.data.keys() |
| |
|
| | def values(self): |
| | return self.data.values() |
| |
|
| | def items(self): |
| | return self.data.items() |
| |
|
| | def mean(self): |
| | mean_st = StatPack() |
| | for key in self.keys(): |
| | if isinstance(self[key], list): |
| | |
| | non_none_values = [v for v in self[key] if v is not None] |
| | if non_none_values: |
| | mean_st[key] = np.mean(np.array(non_none_values)) |
| | else: |
| | mean_st[key] = None |
| | return mean_st |
| |
|
| | def store_plots(self, folder: str): |
| | os.makedirs(folder, exist_ok=True) |
| | for key in self.keys(): |
| | plt.figure(figsize=(10, 5)) |
| | plt.plot(self[key]) |
| | plt.title(key) |
| | plt.savefig(os.path.join(folder, f"{key}.pdf")) |
| | plt.close() |
| |
|
| | def store_numpy(self, folder: str): |
| | os.makedirs(folder, exist_ok=True) |
| | for key in self.keys(): |
| | |
| | safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_") |
| | values = self[key] |
| | |
| | arr = np.array( |
| | [(np.nan if (v is None) else v) for v in values], dtype=float |
| | ) |
| | np.save(os.path.join(folder, f"{safe_key}.npy"), arr) |
| |
|
| | def store_json(self, folder: str, filename: str = "stats.json"): |
| | os.makedirs(folder, exist_ok=True) |
| | with open(os.path.join(folder, filename), "w") as f: |
| | json.dump(self.data, f, indent=4) |
| |
|
| | def store_csv(self, folder: str): |
| | os.makedirs(folder, exist_ok=True) |
| | for key in self.keys(): |
| | with open(os.path.join(folder, f"stats.csv"), "w") as f: |
| | writer = csv.writer(f) |
| | writer.writerow([key] + self[key]) |
| |
|
| | def store_pickle(self, folder: str): |
| | os.makedirs(folder, exist_ok=True) |
| | for key in self.keys(): |
| | with open(os.path.join(folder, f"stats.pkl"), "wb") as f: |
| | pickle.dump(self[key], f) |
| |
|