"""Utility functions for the experiment pipeline."""

import os
import json
import time
import uuid
import hashlib
import numpy as np
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any, Tuple


@dataclass
class FitResult:
    """Container for model fitting results.

    ``params`` holds the fitted parameter blocks keyed by name. It is not
    included in :meth:`to_dict`, which emits only scalar summary fields.
    """

    params: Dict[str, np.ndarray]
    # Sequence of objective values; the last entry is reported by to_dict()
    # as ``final_objective``.
    objective_trace: List[float] = field(default_factory=list)
    n_iterations: int = 0
    converged: bool = False
    runtime_sec: float = 0.0
    config_id: str = ""
    model_family: str = "poisson_gamma"
    inference_type: str = "vi"
    likelihood: str = "poisson"
    prior: str = "gamma"
    # Free-form extra fields, merged into the flat record by to_dict().
    diagnostics: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self):
        """Return a flat summary record suitable for JSON/JSONL logging.

        ``params`` and the full objective trace are omitted. Only the trace
        length and its final value are reported (``None`` when the trace is
        empty). Entries from ``diagnostics`` are merged last, so a diagnostic
        key that collides with a core field overrides it.
        """
        d = {
            'n_iterations': self.n_iterations,
            'converged': self.converged,
            'runtime_sec': self.runtime_sec,
            'config_id': self.config_id,
            'model_family': self.model_family,
            'inference_type': self.inference_type,
            'likelihood': self.likelihood,
            'prior': self.prior,
            'objective_trace_len': len(self.objective_trace),
            'final_objective': self.objective_trace[-1] if self.objective_trace else None,
        }
        d.update(self.diagnostics)
        return d


def generate_run_id():
    """Return a run identifier: local ``YYYYmmdd_HHMMSS`` timestamp, ``_``, 8 random hex chars."""
    return datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]


def generate_config_id(config: dict) -> str:
    """Return a deterministic 12-hex-char fingerprint of ``config``.

    Keys are sorted before hashing, so dict insertion order does not matter.
    Values that are not JSON-serializable are hashed via ``str()``. MD5 is
    used purely as a content fingerprint, not for security.
    """
    s = json.dumps(config, sort_keys=True, default=str)
    return hashlib.md5(s.encode()).hexdigest()[:12]


def save_jsonl(records: list, filepath: str):
    """Append ``records`` to ``filepath`` as JSON Lines (one object per line).

    Missing parent directories are created. A bare filename with no directory
    component is written relative to the current working directory. NumPy
    scalars and arrays are converted via ``_json_default``.

    All records are serialized before the file is opened. A serialization
    error therefore leaves the file untouched instead of half-written.
    """
    lines = [json.dumps(rec, default=_json_default) + '\n' for rec in records]
    parent = os.path.dirname(filepath)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, 'a', encoding='utf-8') as f:
        f.writelines(lines)


def load_jsonl(filepath: str) -> list:
    """Read a JSON Lines file and return its records in order.

    Blank (or whitespace-only) lines are skipped. The file is decoded as
    UTF-8, the encoding JSON text is required to use.
    """
    records = []
    with open(filepath, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def _json_default(obj):
    """``json`` fallback that converts NumPy values to native Python types.

    NumPy integers, floats and bools become ``int``/``float``/``bool``, and
    arrays become nested lists. Any other unsupported object is stringified
    with ``str()`` rather than raising.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    return str(obj)
def save_sidecar_json(filepath: str, metadata: dict):
    """Write ``metadata`` as pretty-printed JSON next to ``filepath``.

    The sidecar path is ``filepath`` with its extension (if any) replaced by
    ``_meta.json``; for example, ``runs/out.npz`` becomes
    ``runs/out_meta.json``. Only the final path component's extension is
    stripped, so dots in directory names (``runs/v1.2/out``) are preserved.
    NumPy values are converted via ``_json_default``.

    Returns:
        The path of the sidecar file that was written.
    """
    root, _ext = os.path.splitext(filepath)
    sidecar_path = root + '_meta.json'
    with open(sidecar_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    return sidecar_path


def check_positive(arr, name="array"):
    """Raise ``ValueError`` unless every element of ``arr`` is finite and > 0.

    ``arr`` may be any array-like (ndarray, list, scalar). ``name`` is used
    in the error message. NaN is checked first: NaN compares False against 0,
    and ``np.min`` of an array containing NaN is NaN. Checking the sign first
    would therefore report a misleading ``min=nan`` for mixed NaN/negative
    input.
    """
    arr = np.asarray(arr)
    if np.any(np.isnan(arr)):
        raise ValueError(f"{name} contains NaN values")
    if np.any(arr <= 0):
        raise ValueError(f"{name} contains non-positive values: min={np.min(arr)}")
    if np.any(np.isinf(arr)):
        raise ValueError(f"{name} contains infinite values")


def relative_param_change(old_params, new_params):
    """Compute max relative parameter change across all blocks.

    For each key present in both dicts, the element-wise change is scaled as
    ``|new - old| / (1 + |old|)``. The denominator is at least 1, so the
    ratio stays bounded when ``old`` is near zero. The maximum over all
    compared elements is returned.

    Keys missing from ``new_params`` and empty blocks are skipped. ``0.0`` is
    returned when nothing is compared. A NaN in any compared block makes the
    result NaN, so a diverged fit is not reported as "no change": a
    ``change < tol`` test is False for NaN.
    """
    max_change = 0.0
    for key, old in old_params.items():
        if key not in new_params:
            continue
        old = np.asarray(old)
        new = np.asarray(new_params[key])
        rel = np.abs(new - old) / (1.0 + np.abs(old))
        if rel.size == 0:
            # np.max of an empty array raises; an empty block has no change.
            continue
        change = np.max(rel)
        # The builtin max() silently discards NaN (nan > x is False), so
        # propagate it explicitly.
        if np.isnan(change):
            return float('nan')
        max_change = max(max_change, float(change))
    return max_change


def stable_softmax(logits, axis=-1):
    """Numerically stable softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiating, which avoids
    overflow for large logits without changing the result. A slice that is
    entirely ``-inf`` produces NaN (``-inf - -inf`` is NaN).
    """
    logits = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(logits)
    return e / np.sum(e, axis=axis, keepdims=True)


def ensure_dir(path):
    """Create ``path`` (and any missing parents) if needed and return it unchanged."""
    os.makedirs(path, exist_ok=True)
    return path