serliezer commited on
Commit
fbfd974
·
verified ·
1 Parent(s): f6c7dd3

Add src/utils.py

Browse files
Files changed (1) hide show
  1. src/utils.py +120 -0
src/utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for the experiment pipeline."""
2
+ import os
3
+ import json
4
+ import time
5
+ import uuid
6
+ import hashlib
7
+ import numpy as np
8
+ from datetime import datetime
9
+ from dataclasses import dataclass, field, asdict
10
+ from typing import Optional, List, Dict, Any, Tuple
11
+
12
+
13
+ @dataclass
14
+ class FitResult:
15
+ """Container for model fitting results."""
16
+ params: Dict[str, np.ndarray]
17
+ objective_trace: List[float] = field(default_factory=list)
18
+ n_iterations: int = 0
19
+ converged: bool = False
20
+ runtime_sec: float = 0.0
21
+ config_id: str = ""
22
+ model_family: str = "poisson_gamma"
23
+ inference_type: str = "vi"
24
+ likelihood: str = "poisson"
25
+ prior: str = "gamma"
26
+ diagnostics: Dict[str, Any] = field(default_factory=dict)
27
+
28
+ def to_dict(self):
29
+ d = {
30
+ 'n_iterations': self.n_iterations,
31
+ 'converged': self.converged,
32
+ 'runtime_sec': self.runtime_sec,
33
+ 'config_id': self.config_id,
34
+ 'model_family': self.model_family,
35
+ 'inference_type': self.inference_type,
36
+ 'likelihood': self.likelihood,
37
+ 'prior': self.prior,
38
+ 'objective_trace_len': len(self.objective_trace),
39
+ 'final_objective': self.objective_trace[-1] if self.objective_trace else None,
40
+ }
41
+ d.update(self.diagnostics)
42
+ return d
43
+
44
+
45
+ def generate_run_id():
46
+ return datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
47
+
48
+
49
+ def generate_config_id(config: dict) -> str:
50
+ s = json.dumps(config, sort_keys=True, default=str)
51
+ return hashlib.md5(s.encode()).hexdigest()[:12]
52
+
53
+
54
+ def save_jsonl(records: list, filepath: str):
55
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
56
+ with open(filepath, 'a') as f:
57
+ for rec in records:
58
+ line = json.dumps(rec, default=_json_default)
59
+ f.write(line + '\n')
60
+
61
+
62
+ def load_jsonl(filepath: str) -> list:
63
+ records = []
64
+ with open(filepath) as f:
65
+ for line in f:
66
+ line = line.strip()
67
+ if line:
68
+ records.append(json.loads(line))
69
+ return records
70
+
71
+
72
+ def _json_default(obj):
73
+ if isinstance(obj, np.integer):
74
+ return int(obj)
75
+ if isinstance(obj, np.floating):
76
+ return float(obj)
77
+ if isinstance(obj, np.ndarray):
78
+ return obj.tolist()
79
+ if isinstance(obj, np.bool_):
80
+ return bool(obj)
81
+ return str(obj)
82
+
83
+
84
+ def save_sidecar_json(filepath: str, metadata: dict):
85
+ sidecar_path = filepath.rsplit('.', 1)[0] + '_meta.json'
86
+ with open(sidecar_path, 'w') as f:
87
+ json.dump(metadata, f, indent=2, default=_json_default)
88
+
89
+
90
+ def check_positive(arr, name="array"):
91
+ if np.any(arr <= 0):
92
+ raise ValueError(f"{name} contains non-positive values: min={np.min(arr)}")
93
+ if np.any(np.isnan(arr)):
94
+ raise ValueError(f"{name} contains NaN values")
95
+ if np.any(np.isinf(arr)):
96
+ raise ValueError(f"{name} contains infinite values")
97
+
98
+
99
+ def relative_param_change(old_params, new_params):
100
+ """Compute max relative parameter change across all blocks."""
101
+ max_change = 0.0
102
+ for key in old_params:
103
+ if key in new_params:
104
+ old = old_params[key]
105
+ new = new_params[key]
106
+ change = np.max(np.abs(new - old) / (1.0 + np.abs(old)))
107
+ max_change = max(max_change, change)
108
+ return max_change
109
+
110
+
111
+ def stable_softmax(logits, axis=-1):
112
+ """Numerically stable softmax."""
113
+ logits = logits - np.max(logits, axis=axis, keepdims=True)
114
+ e = np.exp(logits)
115
+ return e / np.sum(e, axis=axis, keepdims=True)
116
+
117
+
118
+ def ensure_dir(path):
119
+ os.makedirs(path, exist_ok=True)
120
+ return path