| """Utility functions for the experiment pipeline.""" |
| import os |
| import json |
| import time |
| import uuid |
| import hashlib |
| import numpy as np |
| from datetime import datetime |
| from dataclasses import dataclass, field, asdict |
| from typing import Optional, List, Dict, Any, Tuple |
|
|
|
|
@dataclass
class FitResult:
    """Container for the outcome of a single model-fitting run."""
    params: Dict[str, np.ndarray]
    objective_trace: List[float] = field(default_factory=list)
    n_iterations: int = 0
    converged: bool = False
    runtime_sec: float = 0.0
    config_id: str = ""
    model_family: str = "poisson_gamma"
    inference_type: str = "vi"
    likelihood: str = "poisson"
    prior: str = "gamma"
    diagnostics: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self):
        """Flatten scalar metadata plus diagnostics into a plain dict.

        Heavy payloads -- ``params`` and the full ``objective_trace`` --
        are deliberately omitted; only the trace length and its final
        value are kept. Diagnostics entries are merged last and
        therefore win on key collisions.
        """
        trace = self.objective_trace
        summary = {
            'n_iterations': self.n_iterations,
            'converged': self.converged,
            'runtime_sec': self.runtime_sec,
            'config_id': self.config_id,
            'model_family': self.model_family,
            'inference_type': self.inference_type,
            'likelihood': self.likelihood,
            'prior': self.prior,
            'objective_trace_len': len(trace),
            'final_objective': trace[-1] if trace else None,
        }
        return {**summary, **self.diagnostics}
|
|
|
|
def generate_run_id():
    """Return a unique run identifier of the form ``YYYYmmdd_HHMMSS_xxxxxxxx``.

    Combines a wall-clock timestamp (sortable prefix) with 8 random hex
    characters to avoid collisions between runs started in the same second.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = uuid.uuid4().hex[:8]
    return f"{stamp}_{suffix}"
|
|
|
|
def generate_config_id(config: dict) -> str:
    """Return a short, deterministic hex fingerprint for *config*.

    The config is canonicalised via sorted-key JSON (non-serialisable
    values fall back to ``str``), hashed with MD5 (used as a content
    fingerprint, not for security), and truncated to 12 hex characters.
    """
    canonical = json.dumps(config, sort_keys=True, default=str)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()[:12]
|
|
|
|
def save_jsonl(records: list, filepath: str):
    """Append *records* to *filepath*, one JSON object per line.

    Parent directories are created on demand. NumPy scalars/arrays in
    the records are converted via :func:`_json_default`.

    Note: the file is opened in append mode, so repeated calls with the
    same path accumulate lines rather than overwriting.
    """
    parent = os.path.dirname(filepath)
    # os.makedirs("") raises FileNotFoundError, so guard against bare
    # filenames with no directory component.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, 'a', encoding='utf-8') as f:
        for rec in records:
            line = json.dumps(rec, default=_json_default)
            f.write(line + '\n')
|
|
|
|
def load_jsonl(filepath: str) -> list:
    """Read a JSONL file and return its records as a list.

    Blank lines (including a trailing newline) are skipped; every other
    line must be a standalone JSON document.
    """
    with open(filepath) as f:
        return [json.loads(raw) for raw in (line.strip() for line in f) if raw]
|
|
|
|
| def _json_default(obj): |
| if isinstance(obj, np.integer): |
| return int(obj) |
| if isinstance(obj, np.floating): |
| return float(obj) |
| if isinstance(obj, np.ndarray): |
| return obj.tolist() |
| if isinstance(obj, np.bool_): |
| return bool(obj) |
| return str(obj) |
|
|
|
|
def save_sidecar_json(filepath: str, metadata: dict):
    """Write *metadata* next to *filepath* as ``<stem>_meta.json``.

    Uses :func:`os.path.splitext` rather than ``rsplit('.', 1)`` so that
    a dot in a parent directory (e.g. ``runs.v2/out``) is not mistaken
    for an extension separator, which previously produced a sidecar path
    in the wrong directory.
    """
    stem, _ = os.path.splitext(filepath)
    sidecar_path = stem + '_meta.json'
    with open(sidecar_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, default=_json_default)
|
|
|
|
def check_positive(arr, name="array"):
    """Validate that every entry of *arr* is strictly positive and finite.

    Raises ``ValueError`` on the first failed check. NaN compares False
    against ``<= 0``, so NaNs pass the positivity check and are reported
    by the dedicated NaN check instead.
    """
    if np.any(np.less_equal(arr, 0)):
        raise ValueError(f"{name} contains non-positive values: min={np.min(arr)}")
    for predicate, label in ((np.isnan, "NaN"), (np.isinf, "infinite")):
        if np.any(predicate(arr)):
            raise ValueError(f"{name} contains {label} values")
|
|
|
|
def relative_param_change(old_params, new_params):
    """Return the max relative change over parameter blocks in both dicts.

    For each key present in both dicts the per-entry change is
    ``|new - old| / (1 + |old|)``; the result is the maximum over all
    shared keys and entries (0.0 when no keys are shared).
    """
    max_change = 0.0
    shared_keys = (k for k in old_params if k in new_params)
    for k in shared_keys:
        delta = np.abs(new_params[k] - old_params[k])
        scale = 1.0 + np.abs(old_params[k])
        max_change = max(max_change, np.max(delta / scale))
    return max_change
|
|
|
|
def stable_softmax(logits, axis=-1):
    """Softmax along *axis*, shifted by the per-slice max so that the
    exponentials never overflow for large logits."""
    shifted = np.exp(logits - np.max(logits, axis=axis, keepdims=True))
    return shifted / shifted.sum(axis=axis, keepdims=True)
|
|
|
|
def ensure_dir(path):
    """Idempotently create directory *path* (with parents); return *path*
    unchanged so calls can be chained inline."""
    if not os.path.isdir(path):
        # exist_ok guards against a concurrent creator racing us.
        os.makedirs(path, exist_ok=True)
    return path
|
|