Muqeeth's picture
Add files using upload-large-folder tool
1c8c60e verified
import csv
import json
import os
import pickle
from collections import Counter
from copy import deepcopy
from locale import strcoll
from statistics import mean
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
# Module-level side effect: runs once when this module is imported and selects
# the matplotlib style sheet used by the plots produced below.
# NOTE(review): the style is referenced by URL, so importing this module
# presumably needs network access to fetch it — confirm this is acceptable
# for offline / cluster runs.
plt.style.use(
"https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle"
)
import wandb
from . import wandb_utils
class StatPack:
def __init__(self):
self.data = {}
def add_stat(self, key: str, value: float | int | None):
assert (
isinstance(value, float) or isinstance(value, int) or value is None
), f"Value {value} is not a valid type"
if key not in self.data:
self.data[key] = []
self.data[key].append(value)
def add_stats(self, other: "StatPack"):
for key in other.keys():
self.add_stat(key, other[key])
def __getitem__(self, key: str):
return self.data[key]
def __setitem__(self, key: str, value: Any):
self.data[key] = value
def __contains__(self, key: str):
return key in self.data
def __len__(self):
return len(self.data)
def __iter__(self):
return iter(self.data)
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
def mean(self):
mean_st = StatPack()
for key in self.keys():
if isinstance(self[key], list):
# TODO: exclude None values
non_none_values = [v for v in self[key] if v is not None]
if non_none_values:
mean_st[key] = np.mean(np.array(non_none_values))
else:
mean_st[key] = None
return mean_st
def store_plots(self, folder: str):
os.makedirs(folder, exist_ok=True)
for key in self.keys():
plt.figure(figsize=(10, 5))
plt.plot(self[key])
plt.title(key)
plt.savefig(os.path.join(folder, f"{key}.pdf"))
plt.close()
def store_numpy(self, folder: str):
os.makedirs(folder, exist_ok=True)
for key in self.keys():
# Sanitize filename components (avoid slashes, spaces, etc.)
safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_")
values = self[key]
# Convert None to NaN for numpy compatibility
arr = np.array(
[(np.nan if (v is None) else v) for v in values], dtype=float
)
np.save(os.path.join(folder, f"{safe_key}.npy"), arr)
def store_json(self, folder: str, filename: str = "stats.json"):
os.makedirs(folder, exist_ok=True)
with open(os.path.join(folder, filename), "w") as f:
json.dump(self.data, f, indent=4)
def store_csv(self, folder: str):
os.makedirs(folder, exist_ok=True)
for key in self.keys():
with open(os.path.join(folder, f"stats.csv"), "w") as f:
writer = csv.writer(f)
writer.writerow([key] + self[key])
def store_pickle(self, folder: str):
os.makedirs(folder, exist_ok=True)
for key in self.keys():
with open(os.path.join(folder, f"stats.pkl"), "wb") as f:
pickle.dump(self[key], f)