HiMind committed on
Commit
81fa0ce
·
verified ·
1 Parent(s): 34aca5b

Upload 3 files

Browse files
Files changed (3) hide show
  1. ASI.py +303 -0
  2. MMM.py +1077 -0
  3. mmm.pt +3 -0
ASI.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Speaker_ID.py
2
+ # By Chance Brownfield
3
+ import torch
4
+ import numpy as np
5
+ import librosa
6
+ import asyncio
7
+ import tempfile
8
+ import os
9
+ import time
10
+ import traceback
11
+ from typing import AsyncGenerator, Dict, Any, Optional, Union, Iterable
12
+
13
+ import speech_recognition as sr
14
+ from MMM import MMM
15
+
16
+
17
class Speaker_ID:
    """Speaker identification on top of an MMM model manager.

    Uses a shared "base" model from the manager to embed audio, enrolls
    per-speaker models back into the manager, and identifies speakers by
    scoring an embedding against every enrolled model.
    """

    def __init__(
        self,
        mmm_manager,
        base_model_id: str = "unknown",
        device: Union[str, torch.device, None] = None,
        seq_len: int = 1200,
        sr: int = 1200,
    ):
        """
        Args:
            mmm_manager: manager exposing ``.models`` (id -> torch module),
                ``.fit_and_add`` and ``.score``.
            base_model_id: key of the shared embedding model; doubles as the
                "unknown speaker" label returned by :meth:`identify`.
            device: torch device; auto-selects CUDA when available if None.
            seq_len: fixed sample count each clip is padded/truncated to.
            sr: sample rate passed to librosa when loading audio files.
                NOTE(review): 1200 Hz is unusually low for speech — confirm it
                matches how the base model was trained.

        Raises:
            ValueError: if the manager has no ``.models`` attribute.
            KeyError: if ``base_model_id`` is not in the manager.
        """
        self.mmm = mmm_manager
        self.base_model_id = base_model_id
        self.seq_len = int(seq_len)
        self.sr = int(sr)

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        if not hasattr(self.mmm, "models"):
            raise ValueError("Provided mmm_manager does not look like an MMM manager (missing .models).")

        if self.base_model_id not in self.mmm.models:
            available = list(self.mmm.models.keys())
            raise KeyError(f"Base model id '{self.base_model_id}' not found. Available keys: {available}")

        self.base_model = self.mmm.models[self.base_model_id].to(self.device)
        self.base_model.eval()

    def _audio_to_tensor(self, wav_path: str) -> torch.Tensor:
        """Load a WAV file as a normalized, fixed-length (seq_len, 1) tensor.

        Raises:
            RuntimeError: if the file decodes to zero samples.
        """
        y, _ = librosa.load(str(wav_path), sr=self.sr, mono=True)
        y = y.astype(np.float32)
        if y.size == 0:
            raise RuntimeError(f"Empty audio file: {wav_path}")
        # Peak-normalize so amplitude differences between clips do not
        # dominate the embedding.
        maxv = float(np.max(np.abs(y)))
        if maxv > 0:
            y = y / maxv
        # Pad with zeros or truncate to exactly seq_len samples.
        if y.shape[0] < self.seq_len:
            y = np.pad(y, (0, self.seq_len - y.shape[0]))
        else:
            y = y[: self.seq_len]
        return torch.from_numpy(y).unsqueeze(-1)

    def _ensure_tensor(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Coerce raw features to a float tensor of shape (T, F).

        1-D input is treated as (T,) and given a trailing feature axis.

        Raises:
            TypeError: for unsupported input types.
            ValueError: for tensors with more than two dimensions.
        """
        if isinstance(features, np.ndarray):
            t = torch.from_numpy(features)
        elif torch.is_tensor(features):
            t = features.clone()
        else:
            raise TypeError("audio_features must be numpy array or torch tensor or audio file path")

        if t.dim() == 1:
            t = t.unsqueeze(-1)
        if t.dim() == 2:
            return t.float()
        raise ValueError(f"Unexpected features tensor shape: {t.shape}")

    def generate_embedding(self, audio_input: Union[str, np.ndarray, torch.Tensor]) -> np.ndarray:
        """Embed audio (a file path or raw features) with the base model.

        The model output is accepted either as a dict (preferring 'mu', then
        'z', then 'reconstruction') or as a bare tensor; in every case the
        result is averaged over the leading (time) dimension.

        Raises:
            KeyError: if the model output contains nothing usable.
        """
        if isinstance(audio_input, str):
            x = self._audio_to_tensor(audio_input)
        else:
            x = self._ensure_tensor(audio_input)
        x = x.to(self.device)
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (T, F) -> (T, 1, F): add a batch axis of 1

        with torch.no_grad():
            out = self.base_model(x)

        if isinstance(out, dict):
            if "mu" in out:
                mu = out["mu"]
                emb_bz = mu.mean(dim=0)
                emb = emb_bz.squeeze(0).cpu().numpy()
                return emb
            if "z" in out:
                z = out["z"].mean(dim=0).squeeze(0).cpu().numpy()
                return z
            if "reconstruction" in out:
                recon = out["reconstruction"].mean(dim=0).squeeze(0).cpu().numpy()
                return recon

        if torch.is_tensor(out):
            arr = out.mean(dim=0).squeeze(0).cpu().numpy()
            return arr

        raise KeyError("Base model forward did not return 'mu', 'z', 'reconstruction' or a tensor to use as embedding.")

    def enroll_speaker(
        self,
        speaker_id: str,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        model_type: str = "mmm",
        n_components: int = 4,
        epochs: int = 50,
        lr: float = 1e-3,
        seq_len_for_mmm: Optional[int] = None,
        **fit_kwargs,
    ) -> str:
        """Embed the audio and fit a per-speaker model into the manager.

        Args:
            speaker_id: manager key the new model is stored under.
            audio_input: path or raw features for the enrollment clip.
            model_type: 'gmm', 'hmm' or 'mmm'.
            n_components: mixture components for the gmm path.
            epochs, lr: optimization settings forwarded to the manager.
            seq_len_for_mmm: sequence length for the hmm/mmm path
                (defaults to self.seq_len).
            **fit_kwargs: extra options forwarded to ``fit_and_add``; may also
                override num_states / n_mix / trans_d_model / trans_nhead /
                trans_layers for the hmm/mmm path.

        Returns:
            The speaker_id the model was stored under.
        """
        model_type = model_type.lower()
        if model_type not in ("gmm", "hmm", "mmm"):
            raise ValueError("model_type must be 'gmm', 'hmm', or 'mmm'")

        emb = self.generate_embedding(audio_input)  # (Z,)
        if model_type == "gmm":
            X = np.asarray(emb, dtype=np.float32)[None, :]  # (1, Z)
            self.mmm.fit_and_add(
                data=X,
                model_type="gmm",
                model_id=speaker_id,
                n_components=n_components,
                lr=lr,
                epochs=epochs,
                **fit_kwargs,
            )
        else:
            T = int(seq_len_for_mmm or self.seq_len)
            z = torch.tensor(emb, dtype=torch.float32, device=self.device)
            seq = z.unsqueeze(0).repeat(T, 1)  # (T, Z): embedding repeated over time
            seq = seq.unsqueeze(1)  # (T, 1, Z)
            # BUGFIX: pop architecture overrides out of fit_kwargs so each is
            # passed exactly once. Previously they were read with
            # fit_kwargs.get(...) AND forwarded again via **fit_kwargs, which
            # raised "got multiple values for keyword argument" whenever the
            # caller supplied any of them.
            num_states = fit_kwargs.pop("num_states", 8)
            n_mix = fit_kwargs.pop("n_mix", 2)
            trans_d_model = fit_kwargs.pop("trans_d_model", 64)
            trans_nhead = fit_kwargs.pop("trans_nhead", 4)
            trans_layers = fit_kwargs.pop("trans_layers", 2)
            self.mmm.fit_and_add(
                data=seq,
                model_type="mmm" if model_type == "mmm" else "hmm",
                model_id=speaker_id,
                input_dim=emb.shape[-1],
                output_dim=emb.shape[-1],
                hidden_dim=emb.shape[-1] * 2,
                z_dim=min(256, emb.shape[-1]),
                rnn_hidden=emb.shape[-1],
                num_states=num_states,
                n_mix=n_mix,
                trans_d_model=trans_d_model,
                trans_nhead=trans_nhead,
                trans_layers=trans_layers,
                lr=lr,
                epochs=epochs,
                **fit_kwargs,
            )

        return speaker_id

    def identify(
        self,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        unknown_label_confidence_margin: float = 0.0,
    ):
        """Score the clip against every model in the manager.

        Returns:
            (best_model_id, best_score, all_scores). Falls back to the base
            ("unknown") label when no model scores at all, or when the best
            score does not beat the unknown score by at least
            ``unknown_label_confidence_margin``.
        """
        emb = self.generate_embedding(audio_input)
        emb_np = np.asarray(emb, dtype=np.float32)
        X_try = emb_np[None, :]

        scores: Dict[str, float] = {}
        for model_id in list(self.mmm.models.keys()):
            try:
                sc = self.mmm.score(model_id, X_try)
                if isinstance(sc, dict):
                    # Some models return per-part scores; average what is numeric.
                    vals = []
                    for v in sc.values():
                        try:
                            vals.append(float(np.asarray(v).mean()))
                        except Exception:
                            pass
                    score_val = float(np.mean(vals)) if vals else float("nan")
                else:
                    try:
                        score_val = float(np.asarray(sc).mean())
                    except Exception:
                        score_val = float(sc)
                scores[model_id] = score_val
            except Exception:
                # Sequence models may reject a single frame; retry with the
                # embedding tiled out to a (T, 1, Z) sequence.
                try:
                    T = self.seq_len
                    seq = np.tile(emb_np[None, :], (T, 1, 1))
                    sc = self.mmm.score(model_id, seq)
                    try:
                        scores[model_id] = float(np.asarray(sc).mean())
                    except Exception:
                        scores[model_id] = float(sc)
                except Exception:
                    continue

        if not scores:
            return self.base_model_id, float("nan"), {}

        best_model, best_score = max(scores.items(), key=lambda kv: kv[1])

        if best_model != self.base_model_id and unknown_label_confidence_margin > 0.0:
            unknown_score = scores.get(self.base_model_id, float("-inf"))
            if best_score <= unknown_score + unknown_label_confidence_margin:
                return self.base_model_id, unknown_score, scores

        return best_model, best_score, scores
208
+
209
+
210
# -------- Automatic Speaker Identification --------

async def ASI(
    phrase_time_limit: Optional[float] = 3.0,
    queue_maxsize: int = 8,
    mmm_pt_path: str = "models/MMM/mmm.pt",
) -> AsyncGenerator[Dict[str, Any], None]:
    """Listen on the default microphone and yield speaker-ID results forever.

    Each yielded dict contains either
    ``speaker`` / ``score`` / ``scores`` / ``path`` / ``timestamp`` on
    success, or ``error`` / ``traceback`` / ``path`` / ``timestamp`` when
    identification fails for a clip.

    Args:
        phrase_time_limit: max seconds per captured phrase (forwarded to
            speech_recognition's background listener).
        queue_maxsize: bound on buffered, not-yet-identified clips; when the
            queue is full, new clips are dropped silently.
        mmm_pt_path: path of the serialized MMM manager to load.

    Raises:
        RuntimeError: when the microphone cannot be opened.

    NOTE(review): the temp WAV named in "path" is removed as soon as the
    consumer resumes this generator — read it before advancing.
    """
    mgr = MMM.load(mmm_pt_path)
    speaker_system = Speaker_ID(mmm_manager=mgr, base_model_id="unknown", seq_len=1200, sr=1200)

    loop = asyncio.get_running_loop()
    audio_q: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)

    recognizer = sr.Recognizer()
    try:
        mic = sr.Microphone()
    except Exception as e:
        raise RuntimeError("Could not open microphone. Check drivers / permissions.") from e

    def _bg_callback(recognizer_obj: sr.Recognizer, audio: sr.AudioData) -> None:
        # Runs on the listener's worker thread; hand the raw WAV bytes to the
        # event loop thread-safely. A full queue silently drops the clip.
        try:
            wav_bytes = audio.get_wav_data()
            try:
                loop.call_soon_threadsafe(audio_q.put_nowait, wav_bytes)
            except Exception:
                pass
        except Exception:
            traceback.print_exc()

    def _write_temp_wav(b: bytes) -> str:
        # Persist a clip to a temp file for the identifier; the caller is
        # responsible for deleting the returned path.
        # (Hoisted out of the receive loop — it was being redefined on
        # every iteration.)
        tf = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            tf.write(b)
            tf.flush()
            return tf.name
        finally:
            tf.close()

    stop_listening = recognizer.listen_in_background(mic, _bg_callback, phrase_time_limit=phrase_time_limit)

    try:
        while True:
            try:
                wav_bytes = await audio_q.get()
            except asyncio.CancelledError:
                break

            if wav_bytes is None:
                continue

            tmp_path = await loop.run_in_executor(None, _write_temp_wav, wav_bytes)

            try:
                # identify() is CPU-bound; keep the event loop responsive.
                result = await loop.run_in_executor(None, speaker_system.identify, tmp_path)
                best_speaker, best_score, scores = result
                yield {
                    "speaker": best_speaker,
                    "score": best_score,
                    "scores": scores,
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            except Exception as e:
                yield {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            finally:
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass

    finally:
        # Best-effort shutdown of the background listener thread.
        try:
            stop_listening(wait_for_stop=False)
        except Exception:
            pass
290
+
291
+
292
async def _main_cli():
    """Minimal CLI driver: print every identification result from ASI()."""
    async for result in ASI(phrase_time_limit=3.0):
        if "error" in result:
            print("ID error:", result["error"])
        else:
            ts = time.ctime(result["timestamp"])
            print(f"[{ts}] Predicted: {result['speaker']} (score={result['score']})")
            print("All scores:", result["scores"])


if __name__ == "__main__":
    asyncio.run(_main_cli())
MMM.py ADDED
@@ -0,0 +1,1077 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #MMM.py (Multi-Mixture Model)
2
+ #By, Chance Brownfield
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import random
7
+ import string
8
+ import math
9
+ import numpy as np
10
+
11
+ # --- Building Blocks ---
12
+
13
class Encoder(nn.Module):
    """Feed-forward VAE encoder: x -> (mu, logvar) of a diagonal Gaussian."""

    def __init__(self, input_dim, hidden_dim, z_dim):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        """Return (mu, logvar) for the input batch."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar
23
+
24
+
25
class Decoder(nn.Module):
    """Feed-forward VAE decoder: z -> sigmoid-squashed reconstruction in [0, 1]."""

    def __init__(self, z_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """Map latent z back to input space."""
        hidden = F.relu(self.fc1(z))
        return torch.sigmoid(self.fc_out(hidden))
34
+
35
+
36
class RecurrentNetwork(nn.Module):
    """LSTM producing per-step state emission log-probs plus a learned,
    input-independent log transition matrix."""

    def __init__(self, input_dim, hidden_dim, num_states):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.state_emissions = nn.Linear(hidden_dim, num_states)
        self.transition_matrix = nn.Parameter(torch.randn(num_states, num_states))

    def forward(self, x):
        """Return (emission log-probs per step, row-normalized log transitions)."""
        features, _ = self.rnn(x)
        emissions = F.log_softmax(self.state_emissions(features), dim=-1)
        transitions = F.log_softmax(self.transition_matrix, dim=-1)
        return emissions, transitions
48
+
49
+
50
class GaussianMixture(nn.Module):
    """Diagonal-covariance Gaussian mixture with learnable logits, means and
    log-variances, trainable by gradient descent on the log-likelihood."""

    def __init__(self, n_components, n_features):
        super().__init__()
        self.n_components = n_components
        self.n_features = n_features
        # Parameter names are part of the checkpoint format (state_dict keys).
        self.logits = nn.Parameter(torch.zeros(n_components))
        self.means = nn.Parameter(torch.randn(n_components, n_features))
        self.log_vars = nn.Parameter(torch.zeros(n_components, n_features))

    def _coerce(self, X):
        # Move/cast input to the parameters' device and dtype.
        if isinstance(X, torch.Tensor):
            return X.to(self.means.device).type(self.means.dtype)
        return torch.tensor(X, dtype=self.means.dtype, device=self.means.device)

    def get_weights(self):
        """Mixture weights (softmax over the component logits)."""
        return F.softmax(self.logits, dim=0)

    def get_means(self):
        """Component means, shape (K, D)."""
        return self.means

    def get_variances(self):
        """Component diagonal variances, shape (K, D)."""
        return torch.exp(self.log_vars)

    def log_prob(self, X):
        """Per-sample log-likelihood log p(x) for X of shape (N, D)."""
        X = self._coerce(X)
        _, D = X.shape
        diff = X.unsqueeze(1) - self.means.unsqueeze(0)           # (N, K, D)
        inv_vars = torch.exp(-self.log_vars)
        exp_term = -0.5 * torch.sum(diff * diff * inv_vars.unsqueeze(0), dim=2)
        log_norm = -0.5 * (torch.sum(self.log_vars, dim=1) + D * math.log(2 * math.pi))
        comp_log_prob = exp_term + log_norm.unsqueeze(0)          # (N, K)
        weighted = comp_log_prob + F.log_softmax(self.logits, dim=0).unsqueeze(0)
        return torch.logsumexp(weighted, dim=1)

    def get_log_likelihoods(self, X):
        """log_prob computed without grad, returned as a numpy array."""
        X = self._coerce(X)
        with torch.no_grad():
            return self.log_prob(X).cpu().numpy()

    def score(self, X):
        """Mean log-likelihood across samples, as a float."""
        return float(self.get_log_likelihoods(X).mean())
95
+
96
+
97
class HiddenMarkov(nn.Module):
    """HMM with per-state diagonal Gaussian-mixture emissions, trained by
    gradient descent on the forward-algorithm log-likelihood."""

    def __init__(self, n_states, n_mix, n_features):
        super().__init__()
        self.n_states = n_states
        self.n_mix = n_mix
        self.n_features = n_features
        # Parameter names are part of the checkpoint format (state_dict keys).
        self.pi_logits = nn.Parameter(torch.zeros(n_states))
        self.trans_logits = nn.Parameter(torch.zeros(n_states, n_states))
        self.weight_logits = nn.Parameter(torch.zeros(n_states, n_mix))
        self.means = nn.Parameter(torch.randn(n_states, n_mix, n_features))
        self.log_vars = nn.Parameter(torch.zeros(n_states, n_mix, n_features))

    def get_initial_prob(self):
        """Initial state distribution pi (softmax over pi_logits)."""
        return F.softmax(self.pi_logits, dim=0)

    def get_transition_matrix(self):
        """Row-stochastic transition matrix A[i, j] = P(next=j | cur=i)."""
        return F.softmax(self.trans_logits, dim=1)

    def get_weights(self):
        """Per-state mixture weights, shape (S, M)."""
        return F.softmax(self.weight_logits, dim=1)

    def get_means(self):
        """Emission means, shape (S, M, D)."""
        return self.means

    def get_variances(self):
        """Emission diagonal variances, shape (S, M, D)."""
        return torch.exp(self.log_vars)

    def log_prob(self, X):
        """Forward-algorithm log p(X) for one sequence X of shape (T, D)."""
        if not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=self.means.dtype, device=self.means.device)
        else:
            X = X.to(self.means.device).type(self.means.dtype)
        T, D = X.shape
        # Emission log-probs: logsumexp over mixture components per state.
        diff = X.unsqueeze(1).unsqueeze(2) - self.means.unsqueeze(0)     # (T, S, M, D)
        inv_vars = torch.exp(-self.log_vars)
        exp_term = -0.5 * torch.sum(diff * diff * inv_vars.unsqueeze(0), dim=3)
        log_norm = -0.5 * (torch.sum(self.log_vars, dim=2) + D * math.log(2 * math.pi))
        comp_log_prob = exp_term + log_norm.unsqueeze(0)                 # (T, S, M)
        log_mix_weights = F.log_softmax(self.weight_logits, dim=1)
        weighted = comp_log_prob + log_mix_weights.unsqueeze(0)
        emission_log_prob = torch.logsumexp(weighted, dim=2)             # (T, S)
        # Forward recursion in log space.
        log_pi = F.log_softmax(self.pi_logits, dim=0)
        log_A = F.log_softmax(self.trans_logits, dim=1)
        alpha = log_pi + emission_log_prob[0]
        for t in range(1, T):
            # alpha_t[j] = em_t[j] + logsumexp_i(alpha_{t-1}[i] + log_A[i, j]).
            # BUGFIX: the reduction must be over the previous-state axis
            # (dim=0). The old code reduced over dim=1, and since log_A rows
            # are log-softmax-normalized that collapsed to alpha_{t-1} + em_t,
            # silently ignoring the transition matrix entirely.
            alpha = emission_log_prob[t] + torch.logsumexp(alpha.unsqueeze(1) + log_A, dim=0)
        return torch.logsumexp(alpha, dim=0)

    def get_log_likelihoods(self, X):
        """Log-likelihood per sequence; X is (T, D) or a batch (B, T, D)."""
        if not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=self.means.dtype, device=self.means.device)
        else:
            X = X.to(self.means.device).type(self.means.dtype)
        with torch.no_grad():
            if X.dim() == 3:
                return [self.log_prob(seq).item() for seq in X]
            return [self.log_prob(X).item()]

    def score(self, X):
        """Mean log-likelihood across sequences, as a float."""
        lls = self.get_log_likelihoods(X)
        return float(sum(lls) / len(lls))
161
+
162
+
163
class TimeSeriesTransformer(nn.Module):
    """Thin wrapper around nn.Transformer with linear input/output projections."""

    def __init__(self, input_dim, d_model, nhead, num_layers, output_dim, batch_first=True):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_layers
        self.output_dim = output_dim
        self.batch_first = batch_first

        # Attribute names are part of the checkpoint format (state_dict keys).
        self.input_proj = nn.Linear(input_dim, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=batch_first,
        )
        self.output_proj = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt):
        """Project src/tgt to d_model, run the transformer, project back.

        Shapes follow ``batch_first``: (B, S, input_dim) when True, otherwise
        (S, B, input_dim). When tgt is None, src doubles as the decoder input.
        """
        src_emb = self.input_proj(src)
        if tgt is None:
            out = self.transformer(src_emb, src_emb)
        else:
            out = self.transformer(src_emb, self.input_proj(tgt))
        return self.output_proj(out)
195
+
196
+
197
+
198
class VariationalRecurrentMarkovGaussianTransformer(nn.Module):
    """
    Variational Encoder + RNN-HMM + Hidden GMM + Transformer hybrid.

    Encodes inputs to a latent sequence z, decodes a reconstruction, runs an
    RNN state-emission model and an HMM likelihood over z, and optionally a
    transformer mapping z toward a target sequence. forward() returns all of
    these in one dict; loss() combines reconstruction, KL and HMM NLL terms.
    """
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 z_dim,
                 rnn_hidden,
                 num_states,
                 n_mix,
                 trans_d_model,
                 trans_nhead,
                 trans_layers,
                 output_dim):
        super().__init__()
        self.output_dim = output_dim
        # Submodule attribute names are part of the checkpoint state_dict keys.
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, output_dim)
        self.rn = RecurrentNetwork(z_dim, rnn_hidden, num_states)
        self.hm = HiddenMarkov(num_states, n_mix, z_dim)
        self.transformer = TimeSeriesTransformer(
            input_dim=z_dim,
            d_model=trans_d_model,
            nhead=trans_nhead,
            num_layers=trans_layers,
            output_dim=output_dim
        )
        # Learnable per-latent-dim gates (passed through sigmoid) for the
        # predict / recognize / generate heads respectively.
        self.pred_weights = nn.Parameter(torch.ones(z_dim))
        self.recog_weights = nn.Parameter(torch.ones(z_dim))
        self.gen_weights = nn.Parameter(torch.ones(z_dim))

    def reparameterize(self, mu, logvar):
        """Standard VAE reparameterization: z = mu + eps * exp(0.5 * logvar)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, tgt=None):
        """Run the full hybrid pipeline.

        Args:
            x: input of dim 3 (treated as (T, B, input_dim), encoded per
               timestep) or lower-dim input that is lifted to (T, B, Z) below.
            tgt: optional transformer decoder target; when None the
               transformer branch is skipped (transformer_out is None).

        Returns:
            dict with 'reconstruction', 'mu', 'logvar', 'emissions',
            'transitions', 'hgmm_log_likelihood' (shape (B,)),
            'transformer_out'.
        """
        if x.dim() == 3:
            # Encode each timestep separately, then restack to (T, B, Z).
            T, B, _ = x.size()
            zs, mus, logvars = [], [], []
            for t in range(T):
                mu_t, logvar_t = self.encoder(x[t])
                z_t = self.reparameterize(mu_t, logvar_t)
                zs.append(z_t)
                mus.append(mu_t)
                logvars.append(logvar_t)
            zs = torch.stack(zs)        # (T, B, Z)
            mus = torch.stack(mus)      # (T, B, Z)
            logvars = torch.stack(logvars)  # (T, B, Z)
        else:
            # Non-sequence input: encode once and lift to a (T, B, Z) layout.
            mu, logvar = self.encoder(x)
            zs = self.reparameterize(mu, logvar)
            if zs.dim() == 1:
                zs = zs.unsqueeze(0).unsqueeze(1)  # (1, 1, Z)
                mus = mu.unsqueeze(0).unsqueeze(1)
                logvars = logvar.unsqueeze(0).unsqueeze(1)
            elif zs.dim() == 2:
                # (N, Z) -> (N, 1, Z): first axis treated as time.
                zs = zs.unsqueeze(1)
                mus = mu.unsqueeze(1)
                logvars = logvar.unsqueeze(1)
            else:
                # already (T,B,Z)
                mus, logvars = mu, logvar

        T, B, _ = zs.size()
        recon = self.decoder(zs.view(-1, zs.size(-1))).view(T, B, self.output_dim)
        # Best-effort reshape back to the caller's layout; silently kept as
        # (T, B, output_dim) when element counts do not match.
        # NOTE(review): both branches are identical — the x.dim() check here
        # is redundant as written.
        try:
            if x.dim() == 3:
                recon = recon.view_as(x)
            else:
                recon = recon.view_as(x)
        except Exception:
            pass

        emissions, transitions = self.rn(zs.permute(1, 0, 2))  # emissions shape (B, T, num_states)

        # HMM log-likelihood of each batch element's latent trajectory.
        Tz, Bz, Z = zs.shape
        seq_lls = []
        for b in range(Bz):
            ll_b = self.hm.log_prob(zs[:, b, :])  # should be a scalar tensor (dtype/device consistent)
            if not torch.is_tensor(ll_b):
                ll_b = torch.tensor(ll_b, dtype=zs.dtype, device=zs.device)
            seq_lls.append(ll_b)
        hgmm_ll = torch.stack(seq_lls, dim=0)  # (B,)

        trans_out = self.transformer(zs, tgt) if tgt is not None else None

        return {
            'reconstruction': recon,
            'mu': mus,
            'logvar': logvars,
            'emissions': emissions,
            'transitions': transitions,
            'hgmm_log_likelihood': hgmm_ll,  # shape (B,)
            'transformer_out': trans_out
        }

    def loss(self, x, outputs):
        """ELBO-style training loss: reconstruction MSE + KL + HMM NLL."""
        recon, mu, logvar = outputs['reconstruction'], outputs['mu'], outputs['logvar']
        recon_loss = F.mse_loss(recon, x, reduction='sum')
        # KL divergence of N(mu, exp(logvar)) from the standard normal prior.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        hgmm_nll = -torch.sum(outputs['hgmm_log_likelihood'])
        return recon_loss + kld + hgmm_nll

    def predict(self, x):
        """
        Given x, predict next‐step (or next‐sequence) by:
          1) encoding to z,
          2) reweighting latent dims by pred_weights,
          3) decoding back to input space.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        z_pred = z * torch.sigmoid(self.pred_weights)
        return self.decoder(z_pred)

    def predict_loss(self, x, target, reward):
        """
        MSE between predict(x) and target,
        weighted by a scalar reward (+/-).
        """
        pred = self.predict(x)
        loss = F.mse_loss(pred, target, reduction='mean')
        return reward * loss

    def recognize(self, x, tgt_z=None):
        """
        Recognize: map x→z, then transform to tgt_z space via transformer,
        then decode to reconstruct in original space.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        if tgt_z is not None:
            z_in = z.unsqueeze(0)
            tgt = tgt_z.unsqueeze(0)
            z_out = self.transformer(z_in, tgt).squeeze(0)
        else:
            z_out = z
        z_rec = z_out * torch.sigmoid(self.recog_weights)
        return self.decoder(z_rec)

    def recognition_loss(self, x, target, reward):
        """
        Recon loss between recognize(x) and target,
        weighted by reward.
        """
        rec = self.recognize(x)
        loss = F.mse_loss(rec, target, reduction='mean')
        return reward * loss


    def generate(self, num_steps, batch_size=1, z0=None):
        """
        Generate a sequence of length num_steps by:
          1) sampling initial z from prior (HMM's mixture),
          2) rolling it through the RNN-HMM to get a latent trajectory,
          3) reweight by gen_weights and decode each step.

        NOTE(review): the z0 parameter is accepted but never used.
        """
        pi = self.hm.get_initial_prob().detach()
        state = torch.multinomial(pi, num_samples=batch_size, replacement=True)
        z = []
        for t in range(num_steps):
            # Sample a mixture component per batch element, take its mean.
            w = self.hm.get_weights()[state]  # (B, n_mix)
            mix_idx = torch.multinomial(w, 1).squeeze(-1)
            mu_t = self.hm.get_means()[state, mix_idx]
            z_t = mu_t * torch.sigmoid(self.gen_weights)
            z.append(z_t)
            # Sample the next hidden state from the transition row.
            A = self.hm.get_transition_matrix()[state]
            state = torch.multinomial(A, 1).squeeze(-1)
        Z = torch.stack(z, dim=0)  # (T, B, Z)
        recon = self.decoder(Z.view(-1, Z.size(-1))).view(num_steps, batch_size, -1)
        return recon

    def generation_loss(self, generated, target_seq, reward):
        """
        Sequence‐level loss between generated and target,
        weighted by reward (+/-).
        """
        loss = F.mse_loss(generated, target_seq, reduction='mean')
        return reward * loss
379
+
380
+
381
class MMTransformer(nn.Module):
    """Multi-Mixture Transformer.

    Wraps a single GaussianMixture ('gmm') or HiddenMarkov ('hgmm'/'hmm')
    and trains it by gradient descent; fitted models are retained in
    ``self.gmms`` / ``self.hgmm_models`` keyed by data_id.
    """

    def __init__(self, n_components, n_features, model_type='gmm', n_mix=1):
        """
        Args:
            n_components: GMM components, or HMM states.
            n_features: feature dimensionality.
            model_type: 'gmm', or 'hgmm'/'hmm' for the hidden-Markov variant.
            n_mix: mixture components per HMM state.
        """
        super().__init__()
        self.model_type = model_type.lower()
        self.n_features = n_features
        self.gmms = []
        self.hgmm_models = {}
        self.active_hmm = None
        if self.model_type == 'gmm':
            self.gmm = GaussianMixture(n_components, n_features)
        elif self.model_type in ('hgmm', 'hmm'):
            # BUGFIX: accept 'hmm' as an alias — MMModel.fit constructs this
            # class with model_type='hmm', which previously raised ValueError.
            self.hm = HiddenMarkov(n_components, n_mix, n_features)
        else:
            raise ValueError("model_type must be 'gmm', 'hgmm' or 'hmm'")

    def _prepare_tensor(self, X):
        """Coerce input to a float32 tensor."""
        return torch.tensor(X, dtype=torch.float32) if not isinstance(X, torch.Tensor) else X.float()

    def fit(self, X, init_params=None, lr=1e-2, epochs=100, verbose=False, data_id=None):
        """Fit the wrapped model on X by maximizing log-likelihood with Adam.

        Args:
            X: (N, D) samples for a GMM; (T, D) or (B, T, D) for an HMM.
            init_params: optional state_dict (or path) to warm-start from.
            lr, epochs: Adam settings.
            verbose: print the loss every 10 epochs.
            data_id: storage id; int for GMMs (auto-assigned when None),
                str for HMMs (random 6-letter id when None).

        Returns:
            The data_id the fitted model was stored under.
        """
        if init_params is not None:
            self.import_model(init_params)

        X_tensor = self._prepare_tensor(X).to(next(self.parameters()).device)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        for epoch in range(epochs):
            optimizer.zero_grad()
            if self.model_type == 'gmm':
                loss = -torch.mean(self.gmm.log_prob(X_tensor))
            else:
                # Average the per-sequence NLL over a batch of sequences.
                if X_tensor.dim() == 3:
                    loss = -sum(self.hm.log_prob(seq) for seq in X_tensor) / X_tensor.size(0)
                else:
                    loss = -self.hm.log_prob(X_tensor)
            loss.backward()
            optimizer.step()
            if verbose and epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        if self.model_type == 'gmm':
            if data_id is None:
                data_id = len(self.gmms)
            # Skip forward past occupied slots.
            while isinstance(data_id, int) and data_id < len(self.gmms) and self.gmms[data_id] is not None:
                data_id += 1
            if isinstance(data_id, int) and data_id >= len(self.gmms):
                # Pad with None so an explicit out-of-range id cannot IndexError.
                while len(self.gmms) < data_id:
                    self.gmms.append(None)
                self.gmms.append(self.gmm)
            else:
                self.gmms[data_id] = self.gmm
        else:
            if data_id is None:
                # Random 6-letter id, retried until unused.
                while True:
                    data_id = ''.join(random.choices(string.ascii_lowercase, k=6))
                    if data_id not in self.hgmm_models:
                        break
            self.hgmm_models[data_id] = self.hm
            self.active_hmm = data_id

        return data_id

    def unfit(self, data_id):
        """Delete a stored model: int ids index GMMs, str ids name HMMs."""
        if isinstance(data_id, int):
            if 0 <= data_id < len(self.gmms):
                del self.gmms[data_id]
            else:
                raise ValueError(f"GMM with id {data_id} does not exist.")
        elif isinstance(data_id, str):
            if data_id in self.hgmm_models:
                del self.hgmm_models[data_id]
                if self.active_hmm == data_id:
                    self.active_hmm = None
            else:
                raise ValueError(f"HMM model with name '{data_id}' does not exist.")
        else:
            raise TypeError("data_id must be an int (GMM) or str (HMM)")

    def check_data(self):
        """Map each stored id to its kind ('gmm' / 'hmm'); None slots skipped."""
        data = {i: 'gmm' for i in range(len(self.gmms)) if self.gmms[i] is not None}
        data.update({name: 'hmm' for name in self.hgmm_models.keys()})
        return data

    def score(self, X):
        """Mean log-likelihood of X under the wrapped model (no grad)."""
        with torch.no_grad():
            X_tensor = self._prepare_tensor(X).to(next(self.parameters()).device)
            if self.model_type == 'gmm':
                return float(self.gmm.log_prob(X_tensor).mean().cpu().item())
            else:
                if X_tensor.dim() == 3:
                    return float(sum(self.hm.log_prob(seq).item() for seq in X_tensor) / X_tensor.size(0))
                else:
                    return float(self.hm.log_prob(X_tensor).cpu().item())

    def get_log_likelihoods(self, X):
        """Per-sample (GMM) or per-sequence (HMM) log-likelihoods (no grad)."""
        with torch.no_grad():
            X_tensor = self._prepare_tensor(X).to(next(self.parameters()).device)
            if self.model_type == 'gmm':
                return self.gmm.log_prob(X_tensor).cpu().numpy()
            else:
                if X_tensor.dim() == 3:
                    return [self.hm.log_prob(seq).item() for seq in X_tensor]
                else:
                    return [self.hm.log_prob(X_tensor).item()]

    def get_means(self):
        """Means of the wrapped model as numpy."""
        # BUGFIX: the HMM attribute is self.hm — self.hgmm never existed, so
        # these accessors raised AttributeError for hgmm/hmm instances.
        return (self.gmm if self.model_type == 'gmm' else self.hm).get_means().cpu().detach().numpy()

    def get_variances(self):
        """Variances of the wrapped model as numpy."""
        return (self.gmm if self.model_type == 'gmm' else self.hm).get_variances().cpu().detach().numpy()

    def get_weights(self):
        """Mixture weights of the wrapped model as numpy."""
        return (self.gmm if self.model_type == 'gmm' else self.hm).get_weights().cpu().detach().numpy()

    def export_model(self, filepath=None):
        """Return the state_dict, optionally also saving it to filepath."""
        state = self.state_dict()
        if filepath:
            torch.save(state, filepath)
        return state

    def import_model(self, source):
        """Load parameters from a state_dict or a saved file path."""
        if isinstance(source, str):
            state = torch.load(source)
        elif isinstance(source, dict):
            state = source
        else:
            raise ValueError("Unsupported source for import_model")
        self.load_state_dict(state)
507
+
508
+
509
class MMModel(nn.Module):
    """Multi-Mixture Model: a container for independently fitted mixture components.

    GMM components are stored positionally in ``self.gmms`` (integer ids, with
    ``None`` marking empty/vacated slots so that ids stay stable over time),
    while HMM components live in ``self.hgmm_models`` keyed by random 6-letter
    string ids.

    NOTE(review): ``self.gmms`` is a plain list and ``self.hgmm_models`` a plain
    dict, so the stored sub-models are NOT registered as nn.Module submodules;
    their parameters will not appear in ``state_dict()``/``parameters()``.
    Confirm whether that is intended.
    """

    def __init__(self):
        super().__init__()
        self.gmms = []          # List of GaussianMixture models; None marks a vacated slot
        self.hgmm_models = {}   # Dict of HMM models keyed by string IDs
        self.active_hmm = None  # Optional: id of the active HMM for scoring/fitting

    def _generate_unique_id(self):
        """Return a random 6-letter lowercase id not already used by an HMM entry."""
        while True:
            candidate = ''.join(random.choices(string.ascii_lowercase, k=6))
            if candidate not in self.hgmm_models:
                return candidate

    def fit(self, data=None, model_type='gmm', n_components=1, n_features=1, n_mix=1,
            data_id=None, init_params=None, lr=1e-2, epochs=100):
        """Fit or absorb a model and return its storage id.

        - If ``data`` is a pre-trained ``GaussianMixture``/``HiddenMarkov``, it is
          absorbed (stored) directly.
        - Otherwise ``data`` is treated as training input and a new model is
          trained via ``MMTransformer``.
        - ``data_id`` selects the storage slot; if None, a fresh id is assigned
          (next list index for GMMs, random string for HMMs).

        Raises:
            ValueError: if ``model_type`` is not ``'gmm'`` or ``'hmm'``.
        """
        if model_type == 'gmm':
            if data_id is None:
                # Auto-assign the next positional id (appending at the end keeps
                # previously returned ids valid).
                data_id = len(self.gmms)
            if isinstance(data, GaussianMixture):
                # Absorb a pre-trained GMM directly into the requested slot,
                # padding with None placeholders if the slot is beyond the end.
                if data_id < len(self.gmms):
                    self.gmms[data_id] = data
                else:
                    while len(self.gmms) < data_id:
                        self.gmms.append(None)
                    self.gmms.append(data)
            else:
                # Train a fresh GMM on the raw data, then store it.
                model = MMTransformer(n_components, n_features, model_type='gmm')
                model.fit(data, init_params=init_params, lr=lr, epochs=epochs)
                if data_id < len(self.gmms):
                    self.gmms[data_id] = model.gmm
                else:
                    while len(self.gmms) < data_id:
                        self.gmms.append(None)
                    self.gmms.append(model.gmm)
        elif model_type == 'hmm':
            if data_id is None:
                data_id = self._generate_unique_id()
            if isinstance(data, HiddenMarkov):
                # Absorb a pre-trained HMM directly.
                self.hgmm_models[data_id] = data
            else:
                model = MMTransformer(n_components, n_features, model_type='hmm', n_mix=n_mix)
                model.fit(data, init_params=init_params, lr=lr, epochs=epochs)
                self.hgmm_models[data_id] = model.hm
        else:
            raise ValueError("model_type must be 'gmm' or 'hmm'")
        return data_id

    def export_model(self, data_id):
        """Return the stored model for ``data_id``.

        Integer ids address GMMs, string ids address HMMs. A vacated (``None``)
        GMM slot is treated as nonexistent, consistent with ``check_data``.

        Raises:
            ValueError: if no model is stored under ``data_id``.
            TypeError: if ``data_id`` is neither int nor str.
        """
        if isinstance(data_id, int):
            # Reject vacated slots: returning None here would silently propagate.
            if 0 <= data_id < len(self.gmms) and self.gmms[data_id] is not None:
                return self.gmms[data_id]
            raise ValueError(f"GMM with id {data_id} does not exist.")
        elif isinstance(data_id, str):
            if data_id in self.hgmm_models:
                return self.hgmm_models[data_id]
            else:
                raise ValueError(f"HMM model with name '{data_id}' does not exist.")
        else:
            raise TypeError("data_id must be an int (GMM) or str (HMM)")

    def unfit(self, data_id):
        """Remove a stored model (GMM or HMM) without disturbing other ids.

        For GMMs the slot is cleared to ``None`` instead of being deleted:
        ``list.__delitem__`` would shift every later GMM down by one index and
        silently re-map their ids, while ``fit``/``check_data`` already treat
        ``None`` as an empty slot.

        Raises:
            ValueError: if no model is stored under ``data_id``.
            TypeError: if ``data_id`` is neither int nor str.
        """
        if isinstance(data_id, int):
            if 0 <= data_id < len(self.gmms) and self.gmms[data_id] is not None:
                self.gmms[data_id] = None
            else:
                raise ValueError(f"GMM with id {data_id} does not exist.")
        elif isinstance(data_id, str):
            if data_id in self.hgmm_models:
                del self.hgmm_models[data_id]
                # Drop the active-HMM pointer if it referenced the removed model.
                if self.active_hmm == data_id:
                    self.active_hmm = None
            else:
                raise ValueError(f"HMM model with name '{data_id}' does not exist.")
        else:
            raise TypeError("data_id must be an int (GMM) or str (HMM)")

    def check_data(self):
        """Return a dict mapping each stored model's id to its type.

        - Integer keys -> 'gmm' (``None`` placeholder slots are excluded)
        - String keys  -> 'hmm'
        """
        data = {i: 'gmm' for i in range(len(self.gmms)) if self.gmms[i] is not None}
        data.update({name: 'hmm' for name in self.hgmm_models.keys()})
        return data

    def _all_ids(self):
        """All currently occupied ids (ints for GMMs, strs for HMMs)."""
        return list(self.check_data().keys())

    def _normalize_ids(self, data_ids):
        """Normalize ``data_ids`` (None / scalar / iterable) to a list of ids."""
        if data_ids is None:
            return self._all_ids()
        if isinstance(data_ids, (int, str)):
            return [data_ids]
        return list(data_ids)

    def _get_submodel(self, data_id):
        """Raw slot lookup by id; no existence checking (internal use)."""
        if isinstance(data_id, int):
            return self.gmms[data_id]
        return self.hgmm_models[data_id]

    def get_means(self, data_ids=None):
        """Component means.

        If ``data_ids`` is None, returns ``{id: means}`` for all components;
        if a single id, returns just that component's means;
        if a list/tuple, returns a dict keyed by id.
        """
        ids = self._normalize_ids(data_ids)
        out = {d: self._get_submodel(d).get_means() for d in ids}
        if isinstance(data_ids, (int, str)):
            return out[ids[0]]
        return out

    def get_variances(self, data_ids=None):
        """Component variances; same id-selection semantics as ``get_means``."""
        ids = self._normalize_ids(data_ids)
        out = {d: self._get_submodel(d).get_variances() for d in ids}
        if isinstance(data_ids, (int, str)):
            return out[ids[0]]
        return out

    def get_weights(self, data_ids=None):
        """Component mixture weights; same id-selection semantics as ``get_means``."""
        ids = self._normalize_ids(data_ids)
        out = {d: self._get_submodel(d).get_weights() for d in ids}
        if isinstance(data_ids, (int, str)):
            return out[ids[0]]
        return out

    def score(self, X, data_ids=None):
        """Average log-likelihood(s) of X under each specified component."""
        ids = self._normalize_ids(data_ids)
        out = {d: self._get_submodel(d).score(X) for d in ids}
        if isinstance(data_ids, (int, str)):
            return out[ids[0]]
        return out

    def get_log_likelihoods(self, X, data_ids=None):
        """Per-sample log-likelihood(s) of X under each specified component."""
        ids = self._normalize_ids(data_ids)
        out = {d: self._get_submodel(d).get_log_likelihoods(X) for d in ids}
        if isinstance(data_ids, (int, str)):
            return out[ids[0]]
        return out
class MMM(nn.Module):
    """
    Manager for multiple models: GMM, HMM, and VariationalRecurrentMarkovGaussianTransformer.

    Models are held in an ``nn.ModuleDict`` keyed by string ids. ``fit_and_add``
    trains (or wraps) a model and registers it; the accessor methods delegate to
    the stored ``MMModel`` wrappers. The 'mmm' training path uses MSE for
    reconstruction, gradient clipping, variance clamping, numerical safeguards,
    and optional KL annealing.
    """
    def __init__(self):
        super().__init__()
        # ModuleDict registers stored models as submodules (state_dict, .to(), etc.).
        self.models = nn.ModuleDict()

    def _generate_unique_id(self, prefix='model'):
        """Return '<prefix>_<6 random lowercase letters>' not already in self.models."""
        while True:
            candidate = f"{prefix}_{''.join(random.choices(string.ascii_lowercase, k=6))}"
            if candidate not in self.models:
                return candidate

    def add_model(self, model: nn.Module, model_id: str = None):
        """Register ``model`` under ``model_id`` (auto-generated from the class
        name when None) and return the assigned id.

        Raises:
            KeyError: if ``model_id`` is already in use.
        """
        if model_id is None:
            model_id = self._generate_unique_id(model.__class__.__name__)
        if model_id in self.models:
            raise KeyError(f"Model with id '{model_id}' already exists.")
        self.models[model_id] = model
        return model_id

    def fit_and_add(self,
                    data,
                    model_type: str = 'gmm',
                    model_id: str = None,
                    kl_anneal_epochs: int = 0,
                    clip_norm: float = 5.0,
                    weight_decay: float = 1e-5,
                    **kwargs):
        """Train a model of ``model_type`` on ``data``, register it, and return its id.

        - 'gmm'/'hmm': wraps the fit in a fresh ``MMModel`` (kwargs forwarded to
          ``MMModel.fit``).
        - 'mmm': builds a VariationalRecurrentMarkovGaussianTransformer from
          required kwargs (input_dim, hidden_dim, z_dim, rnn_hidden, num_states,
          n_mix, trans_d_model, trans_nhead, trans_layers, output_dim -- all
          popped, so missing ones raise KeyError) and trains it with
          recon(MSE) + annealed (KLD + HMM negative log-likelihood).

        Args:
            kl_anneal_epochs: epochs over which the KL/HMM weight ramps 0->1
                (0 disables annealing, i.e. full weight from epoch 0).
            clip_norm: max gradient norm for the 'mmm' training loop.
            weight_decay: Adam weight decay for the 'mmm' training loop.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        model_type = model_type.lower()
        if model_type in ('gmm','hmm'):
            mm = MMModel()
            mm.fit(data, model_type=model_type, **kwargs)
            model = mm

        elif model_type == 'mmm':
            # build hybrid model
            model = VariationalRecurrentMarkovGaussianTransformer(
                kwargs.pop('input_dim'),
                kwargs.pop('hidden_dim'),
                kwargs.pop('z_dim'),
                kwargs.pop('rnn_hidden'),
                kwargs.pop('num_states'),
                kwargs.pop('n_mix'),
                kwargs.pop('trans_d_model'),
                kwargs.pop('trans_nhead'),
                kwargs.pop('trans_layers'),
                kwargs.pop('output_dim')
            )
            optim = torch.optim.Adam(model.parameters(), lr=kwargs.get('lr',1e-4), weight_decay=weight_decay)
            epochs = kwargs.get('epochs',100)
            # NOTE(review): assumes `data` is already a tensor (has .float()) --
            # confirm callers never pass raw lists/arrays on this path.
            x = data.float().to(next(model.parameters()).device)

            for epoch in range(epochs):
                model.train()
                optim.zero_grad()
                out = model(x, kwargs.get('tgt', None))

                # Reconstruction term: summed (not averaged) MSE, so its scale
                # grows with batch/sequence size.
                recon = out['reconstruction']
                recon_loss = F.mse_loss(recon, x, reduction='sum')

                # KL divergence of the variational posterior vs. standard normal;
                # logvar is clamped to avoid exp() overflow/underflow.
                mu, logvar = out['mu'], out['logvar']
                logvar_clamped = torch.clamp(logvar, min=-10.0, max=10.0)
                kld = -0.5 * torch.sum(1 + logvar_clamped - mu.pow(2) - logvar_clamped.exp())

                # HMM term: clamp the per-sample log-likelihoods before negating/summing.
                hgmm_ll = out['hgmm_log_likelihood']
                hgmm_ll = torch.clamp(hgmm_ll, min=-1e6, max=1e6)
                hgmm_nll = -torch.sum(hgmm_ll)

                # Replace NaN/inf so a single bad batch cannot poison the optimizer state.
                kld = torch.nan_to_num(kld, nan=0.0, posinf=1e8, neginf=-1e8)
                hgmm_nll = torch.nan_to_num(hgmm_nll, nan=0.0, posinf=1e8, neginf=-1e8)

                # Linear KL annealing: ramps the regularization weight over
                # kl_anneal_epochs epochs (1.0 when annealing is disabled).
                anneal_w = min(1.0, epoch / kl_anneal_epochs) if kl_anneal_epochs > 0 else 1.0
                loss = recon_loss + anneal_w * (kld + hgmm_nll)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                optim.step()

                # Progress log ~5 times over the run.
                if epoch % max(1, epochs // 5) == 0:
                    print(f"Epoch {epoch}: recon={recon_loss.item():.1f}, kld={kld.item():.1f}, "
                          f"hmll={hgmm_nll.item():.1f}, anneal_w={anneal_w:.2f}")
        else:
            raise ValueError("model_type must be 'gmm','hmm', or 'mmm'")

        assigned_id = self.add_model(model, model_id)
        return assigned_id

    def export_model(self, model_id: str, filepath: str = None):
        """Return the model's state_dict; optionally also torch.save it to ``filepath``.

        Raises:
            KeyError: if ``model_id`` is unknown.
        """
        if model_id not in self.models:
            raise KeyError(f"Model '{model_id}' not found.")
        model = self.models[model_id]
        state = model.state_dict()
        if filepath:
            torch.save(state, filepath)
        return state

    def import_model(self, model_id: str, source):
        """Load parameters into an existing model from a filepath or a state dict.

        Raises:
            KeyError: if ``model_id`` is unknown.
            ValueError: if ``source`` is neither a str path nor a dict.
        """
        if model_id not in self.models:
            raise KeyError(f"Model '{model_id}' not found.")
        model = self.models[model_id]
        if isinstance(source, str):
            # SECURITY NOTE: torch.load unpickles; only load files from trusted sources.
            state = torch.load(source)
        elif isinstance(source, dict):
            state = source
        else:
            raise ValueError("source must be filepath or state dict")
        model.load_state_dict(state)

    def _select_data(self, mm, fn, data_ids=None, *args, **kwargs):
        """Apply ``fn(mm, id, ...)`` over selected sub-model ids of an MMModel.

        None -> all ids (returns dict); scalar id -> bare value; list/tuple -> dict.
        """
        all_keys = list(mm.check_data().keys())
        if data_ids is None:
            ids = all_keys
        elif isinstance(data_ids, (list, tuple)):
            ids = data_ids
        else:
            ids = [data_ids]
        out = {d: fn(mm, d, *args, **kwargs) for d in ids}
        # Scalar request: unwrap the single result for caller convenience.
        if not isinstance(data_ids, (list, tuple)) and data_ids is not None:
            return out[data_ids]
        return out

    def get_means(self, model_id: str, data_ids=None):
        """Means of the selected sub-models of the MMModel stored under ``model_id``."""
        mm = self.get_mmm(model_id)
        return self._select_data(
            mm,
            lambda m, d: m._get_submodel(d).get_means(),
            data_ids
        )

    def get_variances(self, model_id: str, data_ids=None):
        """Variances of the selected sub-models of the MMModel stored under ``model_id``."""
        mm = self.get_mmm(model_id)
        return self._select_data(
            mm,
            lambda m, d: m._get_submodel(d).get_variances(),
            data_ids
        )

    def get_weights(self, model_id: str, data_ids=None):
        """Mixture weights of the selected sub-models of the MMModel under ``model_id``."""
        mm = self.get_mmm(model_id)
        return self._select_data(
            mm,
            lambda m, d: m._get_submodel(d).get_weights(),
            data_ids
        )

    def get_log_likelihoods(self, model_id: str, X, data_ids=None):
        """Per-sample log-likelihoods of X under the selected sub-models."""
        mm = self.get_mmm(model_id)

        def fn(m, d):
            sub = m._get_submodel(d)
            return sub.get_log_likelihoods(X)

        return self._select_data(mm, fn, data_ids)

    def score(self, model_id: str, X, data_ids=None):
        """Average log-likelihood of X under the selected sub-models."""
        mm = self.get_mmm(model_id)

        def fn(m, d):
            sub = m._get_submodel(d)
            return sub.score(X)

        return self._select_data(mm, fn, data_ids)

    def get_mmm(self, model_id: str):
        """Return the stored model for ``model_id``.

        Raises:
            KeyError: if ``model_id`` is unknown.
        """
        if model_id not in self.models:
            raise KeyError(f"Model '{model_id}' not found.")
        return self.models[model_id]

    def save(self, path: str):
        """Pickle the entire manager (all models included) to ``path``."""
        torch.save(self, path)

    @classmethod
    def load(cls, path: str):
        """Load a manager saved by ``save``.

        SECURITY NOTE: ``weights_only=False`` unpickles arbitrary objects and can
        execute code during load -- only load files from trusted sources.
        """
        return torch.load(path, weights_only=False)
class WeightedMMM(MMM):
    """
    Enhanced Multi-Mixture Model with weighted predictions and GPU acceleration support.
    Supports training with reward/punishment signals: a baseline model is fitted
    normally, then a parallel "weighted" copy is trained so that samples carrying
    reward signals get higher likelihood and punished samples get lower likelihood.
    """
    def __init__(self, device=None):
        super().__init__()
        # Default to CUDA when available; callers may pass an explicit device.
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.weighted_models = {}     # model_id -> {'model', 'reward_weight', 'punishment_weight'}
        self.reward_signals = {}      # model_id -> reward signal tensor used at fit time
        self.punishment_signals = {}  # model_id -> punishment signal tensor used at fit time

    def to_device(self, model):
        """Move model to specified device (CPU/GPU)"""
        return model.to(self.device)

    def fit_with_weights(self,
                         data,
                         reward_signals,
                         punishment_signals,
                         model_type='gmm',
                         model_id=None,
                         reward_weight=1.0,
                         punishment_weight=1.0,
                         **kwargs):
        """
        Fit model with weighted predictions using reward and punishment signals.

        Fits a baseline model via ``fit_and_add``, builds an untrained weighted
        twin of the same architecture, records the signals, and trains the twin
        with the reward/punishment objective. Returns the baseline model id
        (also the key for the weighted twin in ``self.weighted_models``).

        Args:
            data: Input sensor data
            reward_signals: Positive reinforcement signals
            punishment_signals: Negative reinforcement signals
            model_type: Type of model ('gmm', 'hmm', or 'mmm')
            model_id: Optional model identifier
            reward_weight: Weight for reward signals
            punishment_weight: Weight for punishment signals
            **kwargs: Additional training parameters

        NOTE(review): ``torch.tensor(data, ...)`` will copy (and warn on) inputs
        that are already tensors -- confirm expected input types of callers.
        """
        data = torch.tensor(data, dtype=torch.float32).to(self.device)
        reward_signals = torch.tensor(reward_signals, dtype=torch.float32).to(self.device)
        punishment_signals = torch.tensor(punishment_signals, dtype=torch.float32).to(self.device)

        # Baseline: ordinary (unweighted) fit, registered under baseline_id.
        baseline_id = self.fit_and_add(data, model_type=model_type, model_id=model_id, **kwargs)
        baseline_model = self.models[baseline_id]

        # Fresh, untrained twin with matching architecture, moved to self.device.
        weighted_model = self._create_weighted_model(baseline_model, model_type)
        weighted_model = self.to_device(weighted_model)

        self.reward_signals[baseline_id] = reward_signals
        self.punishment_signals[baseline_id] = punishment_signals
        self.weighted_models[baseline_id] = {
            'model': weighted_model,
            'reward_weight': reward_weight,
            'punishment_weight': punishment_weight
        }

        self._train_weighted_model(baseline_id, data, reward_signals, punishment_signals, **kwargs)

        return baseline_id

    def _create_weighted_model(self, baseline_model, model_type):
        """Create a weighted version of the baseline model.

        Returns a NEW, untrained model mirroring the baseline's hyperparameters.

        NOTE(review): for 'gmm'/'hmm', ``fit_and_add`` stores an ``MMModel``
        wrapper as the baseline, which does not expose ``n_components`` /
        ``n_states`` etc. directly -- verify these attribute reads against the
        actual stored object.
        """
        if model_type == 'gmm':
            return GaussianMixture(
                n_components=baseline_model.n_components,
                n_features=baseline_model.n_features
            )
        elif model_type == 'hmm':
            return HiddenMarkov(
                n_states=baseline_model.n_states,
                n_mix=baseline_model.n_mix,
                n_features=baseline_model.n_features
            )
        elif model_type == 'mmm':
            # Recover the constructor hyperparameters from the baseline's layers.
            return VariationalRecurrentMarkovGaussianTransformer(
                input_dim=baseline_model.encoder.fc1.in_features,
                hidden_dim=baseline_model.encoder.fc1.out_features,
                z_dim=baseline_model.encoder.fc_mu.out_features,
                rnn_hidden=baseline_model.rn.rnn.hidden_size,
                num_states=baseline_model.rn.state_emissions.out_features,
                n_mix=baseline_model.hm.n_mix,
                trans_d_model=baseline_model.transformer.d_model,
                trans_nhead=baseline_model.transformer.nhead,
                trans_layers=baseline_model.transformer.num_encoder_layers,
                output_dim=baseline_model.transformer.output_proj.out_features
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    def _train_weighted_model(self, model_id, data, reward_signals, punishment_signals, **kwargs):
        """Train the weighted model using reward and punishment signals.

        Loss per epoch: ``-mean(log_probs * reward) * reward_weight
        + mean(log_probs * punishment) * punishment_weight`` -- i.e. push
        likelihood up on rewarded samples and down on punished ones.

        NOTE(review): for GaussianMixture/HiddenMarkov, if ``log_prob`` returns
        a non-tensor it is re-wrapped via ``torch.as_tensor``, which detaches it
        from the graph -- ``backward()`` would then update nothing. Confirm
        ``log_prob`` returns a differentiable tensor for those model types.
        """
        weighted_info = self.weighted_models[model_id]
        model = weighted_info['model']
        reward_weight = weighted_info['reward_weight']
        punishment_weight = weighted_info['punishment_weight']

        # Fall back to self.device if the model has no trainable parameters.
        device = next(model.parameters()).device if any(p.requires_grad for p in model.parameters()) else self.device
        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs.get('lr', 1e-4))
        epochs = kwargs.get('epochs', 100)

        # Signals are constants for the optimization: detach and pin to device once.
        reward_signals = torch.as_tensor(reward_signals, dtype=torch.float32, device=device).detach()
        punishment_signals = torch.as_tensor(punishment_signals, dtype=torch.float32, device=device).detach()

        for epoch in range(epochs):
            model.train()
            optimizer.zero_grad()

            if isinstance(model, (GaussianMixture, HiddenMarkov)):
                log_probs = model.log_prob(data)
                if not torch.is_tensor(log_probs):
                    log_probs = torch.as_tensor(log_probs, dtype=torch.float32, device=device)
            else:
                outputs = model(data)
                log_probs = outputs['hgmm_log_likelihood']  # expected shape (B,)

            # Collapse any extra dims to one score per sample.
            if log_probs.dim() > 1:
                log_probs = log_probs.view(log_probs.size(0), -1).mean(dim=1)
            log_probs = log_probs.to(device).type(torch.float32)

            N = log_probs.numel()
            def _broadcast_signal(sig):
                # Accept scalar, per-sample, or otherwise expandable signals.
                if sig.numel() == 1:
                    return sig.expand(N)
                if sig.numel() == N:
                    return sig.view(N)
                try:
                    return sig.expand(N)
                except Exception:
                    raise ValueError(f"Signal of length {sig.numel()} cannot be broadcast to {N} samples")

            r = _broadcast_signal(reward_signals)
            p = _broadcast_signal(punishment_signals)

            reward_loss = -torch.mean(log_probs * r) * reward_weight
            punishment_loss = torch.mean(log_probs * p) * punishment_weight
            total_loss = reward_loss + punishment_loss

            # Numerical safety valve: skip the step and shrink LR on NaN/inf loss.
            if not torch.isfinite(total_loss):
                print("Warning: non-finite total_loss detected; skipping update and reducing LR.")
                for g in optimizer.param_groups:
                    g['lr'] = max(1e-8, g['lr'] * 0.1)
                continue

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Progress log ~5 times over the run.
            if epoch % max(1, epochs // 5) == 0:
                print(f"Epoch {epoch}: reward_loss={reward_loss.item():.6f}, punishment_loss={punishment_loss.item():.6f}")

    def predict_anomalies(self, data, model_id, threshold=0.95):
        """
        Predict anomalies using both baseline and weighted models.

        Args:
            data: Input sensor data
            model_id: Model identifier
            threshold: Anomaly detection threshold

        Returns:
            dict containing:
                - baseline_predictions: Anomaly predictions from baseline model
                - weighted_predictions: Anomaly predictions from weighted model
                - confidence_scores: Confidence scores for predictions

        NOTE(review): for 'gmm'/'hmm' the baseline stored by ``fit_and_add`` is
        an ``MMModel`` wrapper with no ``log_prob`` method visible in this file;
        also ``threshold`` is compared directly against log-probabilities (which
        are typically negative), so 0.95 flags nearly everything -- verify both.
        """
        data = torch.tensor(data, dtype=torch.float32).to(self.device)

        # A sample is flagged anomalous when its log-probability falls below threshold.
        baseline_model = self.models[model_id]
        baseline_log_probs = baseline_model.log_prob(data)
        baseline_predictions = (baseline_log_probs < threshold).cpu().numpy()

        weighted_model = self.weighted_models[model_id]['model']
        weighted_log_probs = weighted_model.log_prob(data)
        weighted_predictions = (weighted_log_probs < threshold).cpu().numpy()

        # Sigmoid squashes log-probs into (0, 1) as a rough confidence proxy.
        confidence_scores = {
            'baseline': torch.sigmoid(baseline_log_probs).cpu().numpy(),
            'weighted': torch.sigmoid(weighted_log_probs).cpu().numpy()
        }

        return {
            'baseline_predictions': baseline_predictions,
            'weighted_predictions': weighted_predictions,
            'confidence_scores': confidence_scores
        }

    def evaluate_models(self, test_data, test_labels, model_id):
        """
        Evaluate and compare baseline and weighted models.

        Args:
            test_data: Test sensor data
            test_labels: Ground truth labels (array-like of 0/1)
            model_id: Model identifier

        Returns:
            dict containing evaluation metrics for both models
            ('accuracy', 'precision', 'recall', 'f1_score', 'false_alarm_rate',
            keyed under 'baseline' and 'weighted')
        """
        predictions = self.predict_anomalies(test_data, model_id)

        def calculate_metrics(preds, labels):
            # Confusion-matrix counts (1 = anomaly/positive class).
            tp = np.sum((preds == 1) & (labels == 1))
            fp = np.sum((preds == 1) & (labels == 0))
            tn = np.sum((preds == 0) & (labels == 0))
            fn = np.sum((preds == 0) & (labels == 1))

            # Guard each ratio against zero denominators.
            accuracy = (tp + tn) / (tp + tn + fp + fn)
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            return {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1_score': f1,
                'false_alarm_rate': fp / (fp + tn) if (fp + tn) > 0 else 0
            }

        baseline_metrics = calculate_metrics(predictions['baseline_predictions'], test_labels)
        weighted_metrics = calculate_metrics(predictions['weighted_predictions'], test_labels)

        return {
            'baseline': baseline_metrics,
            'weighted': weighted_metrics
        }
mmm.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6defbe66414e745bab36dde7cb0684e46f25969429daa5e66780c7fb21c173fd
3
+ size 5222802