"""Utility functions for the experiment pipeline."""
import os
import json
import time
import uuid
import hashlib
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any, Tuple

import numpy as np


@dataclass
class FitResult:
    """Container for model fitting results."""
    params: Dict[str, np.ndarray]
    objective_trace: List[float] = field(default_factory=list)
    n_iterations: int = 0
    converged: bool = False
    runtime_sec: float = 0.0
    config_id: str = ""
    model_family: str = "poisson_gamma"
    inference_type: str = "vi"
    likelihood: str = "poisson"
    prior: str = "gamma"
    diagnostics: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self):
        """Flatten to a JSON-serializable summary dict.

        The params arrays are omitted; only the trace length and final
        objective are recorded.
        """
        d = {
            'n_iterations': self.n_iterations,
            'converged': self.converged,
            'runtime_sec': self.runtime_sec,
            'config_id': self.config_id,
            'model_family': self.model_family,
            'inference_type': self.inference_type,
            'likelihood': self.likelihood,
            'prior': self.prior,
            'objective_trace_len': len(self.objective_trace),
            'final_objective': self.objective_trace[-1] if self.objective_trace else None,
        }
        # Diagnostics are merged in last, so a diagnostics key can shadow a base key.
        d.update(self.diagnostics)
        return d
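
# Usage sketch for FitResult (illustrative values, not part of the pipeline API):
#     result = FitResult(params={"w": np.ones(3)},
#                        objective_trace=[-120.0, -80.0, -79.5],
#                        n_iterations=3, converged=True, runtime_sec=0.4)
#     row = result.to_dict()   # flat dict, suitable for save_jsonl([row], path)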


def generate_run_id():
    """Timestamped, collision-resistant identifier for a single run."""
    return datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]


def generate_config_id(config: dict) -> str:
    """Deterministic 12-char fingerprint of a config (MD5 used non-cryptographically)."""
    s = json.dumps(config, sort_keys=True, default=str)
    return hashlib.md5(s.encode()).hexdigest()[:12]
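
# Because of sort_keys, configs differing only in key order hash identically:
#     generate_config_id({"lr": 0.1, "K": 5}) == generate_config_id({"K": 5, "lr": 0.1})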


def save_jsonl(records: list, filepath: str):
    """Append records to a JSON Lines file, one JSON object per line."""
    dirname = os.path.dirname(filepath)
    if dirname:  # dirname is "" for bare filenames; os.makedirs("") would raise
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, 'a') as f:
        for rec in records:
            f.write(json.dumps(rec, default=_json_default) + '\n')


def load_jsonl(filepath: str) -> list:
    """Read a JSON Lines file back into a list of dicts, skipping blank lines."""
    records = []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
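
# Round-trip sketch (the path is an assumption). save_jsonl appends, so
# repeated calls accumulate rows in the same file:
#     save_jsonl([{"a": np.float64(1.5)}], "results/run.jsonl")
#     rows = load_jsonl("results/run.jsonl")   # [{"a": 1.5}]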


def _json_default(obj):
    """json.dumps fallback: convert NumPy scalars/arrays, stringify anything else."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    return str(obj)


def save_sidecar_json(filepath: str, metadata: dict):
    """Write metadata next to a data file, e.g. results.jsonl -> results_meta.json."""
    sidecar_path = filepath.rsplit('.', 1)[0] + '_meta.json'
    with open(sidecar_path, 'w') as f:
        json.dump(metadata, f, indent=2, default=_json_default)


def check_positive(arr, name="array"):
    """Raise ValueError if arr contains NaN, inf, or non-positive entries."""
    # NaN/inf are checked first so the min reported below is meaningful.
    if np.any(np.isnan(arr)):
        raise ValueError(f"{name} contains NaN values")
    if np.any(np.isinf(arr)):
        raise ValueError(f"{name} contains infinite values")
    if np.any(arr <= 0):
        raise ValueError(f"{name} contains non-positive values: min={np.min(arr)}")


def relative_param_change(old_params, new_params):
    """Compute max relative parameter change across all shared blocks."""
    max_change = 0.0
    for key in old_params:
        if key in new_params:  # keys present in only one dict are ignored
            old = old_params[key]
            new = new_params[key]
            # Elementwise |new - old| / (1 + |old|): relative for large values,
            # absolute for values near zero.
            change = np.max(np.abs(new - old) / (1.0 + np.abs(old)))
            max_change = max(max_change, float(change))
    return max_change
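
# Convergence-check sketch (the 1e-6 tolerance is an assumed value, not pipeline policy):
#     if relative_param_change(prev_params, curr_params) < 1e-6:
#         converged = True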


def stable_softmax(logits, axis=-1):
    """Numerically stable softmax."""
    logits = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(logits)
    return e / np.sum(e, axis=axis, keepdims=True)
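
# Why the max-shift matters: np.exp(1000.0) overflows to inf, but
# stable_softmax(np.array([1000.0, 1001.0])) is the softmax of [0., 1.],
# i.e. approximately [0.2689, 0.7311], computed without overflow.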


def ensure_dir(path):
    """Create path (and parents) if missing; return it for chaining."""
    os.makedirs(path, exist_ok=True)
    return path
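

if __name__ == "__main__":
    # Minimal smoke test exercising the save/load round trip. Illustrative only;
    # the "demo_results" path and config values are assumptions, not pipeline defaults.
    rid = generate_run_id()
    result = FitResult(params={"w": np.ones(2)},
                       objective_trace=[-3.0, -1.0],
                       n_iterations=2, converged=True,
                       config_id=generate_config_id({"K": 2}))
    path = ensure_dir("demo_results") + "/rows.jsonl"
    save_jsonl([result.to_dict()], path)
    save_sidecar_json(path, {"run_id": rid})
    print(load_jsonl(path))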