# Speaker_ID.py
# By Chance Brownfield
import asyncio
import os
import tempfile
import time
import traceback
from typing import Any, AsyncGenerator, Dict, Iterable, Optional, Union

import librosa
import numpy as np
import speech_recognition as sr
import torch

from MMM import MMM


class Speaker_ID:
    """Speaker identification on top of an MMM model manager.

    A "base" model stored in the manager under ``base_model_id`` turns audio
    into a fixed-size embedding. Each enrolled speaker is registered as a
    separate model in the same manager via ``fit_and_add``. ``identify``
    scores an embedding against every model in the manager and picks the
    highest score.
    """

    def __init__(
        self,
        mmm_manager,
        base_model_id: str = "unknown",
        device: Union[str, torch.device, None] = None,
        seq_len: int = 1200,
        sr: int = 1200,
    ):
        """Bind to a manager and move its base model to ``device`` in eval mode.

        Args:
            mmm_manager: Manager object. It must expose a ``models`` mapping
                (id -> torch module) and is also used through its
                ``fit_and_add`` and ``score`` methods.
            base_model_id: Key in ``mmm_manager.models`` of the embedding model.
            device: Torch device. Defaults to CUDA when available, else CPU.
            seq_len: Number of samples each audio file is padded/truncated to.
            sr: Sample rate (Hz) that audio files are resampled to. NOTE: this
                parameter shadows the module-level ``speech_recognition as sr``
                alias, but only inside this method.

        Raises:
            ValueError: ``mmm_manager`` has no ``models`` attribute.
            KeyError: ``base_model_id`` is not a key of ``mmm_manager.models``.
        """
        self.mmm = mmm_manager
        self.base_model_id = base_model_id
        self.seq_len = int(seq_len)
        self.sr = int(sr)

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        if not hasattr(self.mmm, "models"):
            raise ValueError("Provided mmm_manager does not look like an MMM manager (missing .models).")
        if self.base_model_id not in self.mmm.models:
            available = list(self.mmm.models.keys())
            raise KeyError(f"Base model id '{self.base_model_id}' not found. Available keys: {available}")

        self.base_model = self.mmm.models[self.base_model_id].to(self.device)
        self.base_model.eval()

    def _audio_to_tensor(self, wav_path: str) -> torch.Tensor:
        """Load an audio file as a peak-normalized ``(seq_len, 1)`` float32 tensor.

        The file is resampled to ``self.sr`` and downmixed to mono. It is then
        scaled so its largest absolute sample is 1 (all-zero audio is left
        unscaled) and zero-padded or truncated to exactly ``self.seq_len``
        samples.

        Raises:
            RuntimeError: The decoded audio contains no samples.
        """
        y, _ = librosa.load(str(wav_path), sr=self.sr, mono=True)
        y = y.astype(np.float32)
        if y.size == 0:
            raise RuntimeError(f"Empty audio file: {wav_path}")
        peak = float(np.max(np.abs(y)))
        if peak > 0:
            y = y / peak
        if y.shape[0] < self.seq_len:
            y = np.pad(y, (0, self.seq_len - y.shape[0]))
        else:
            y = y[: self.seq_len]
        return torch.from_numpy(y).unsqueeze(-1)

    def _ensure_tensor(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Coerce precomputed features to a 2-D float32 tensor.

        1-D input of length T becomes ``(T, 1)``. 2-D input is kept as-is.
        Torch tensors are cloned. NumPy arrays go through
        ``torch.from_numpy``, which shares memory with the array.

        Raises:
            TypeError: ``features`` is neither an ndarray nor a tensor.
            ValueError: The input has more than 2 dimensions.
        """
        if isinstance(features, np.ndarray):
            t = torch.from_numpy(features)
        elif torch.is_tensor(features):
            t = features.clone()
        else:
            raise TypeError("audio_features must be numpy array or torch tensor or audio file path")
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        if t.dim() == 2:
            return t.float()
        raise ValueError(f"Unexpected features tensor shape: {t.shape}")

    def generate_embedding(self, audio_input: Union[str, np.ndarray, torch.Tensor]) -> np.ndarray:
        """Run the base model on ``audio_input`` and return a NumPy embedding.

        ``audio_input`` may be an audio file path or precomputed features
        (see ``_ensure_tensor``). The 2-D input ``(T, F)`` is reshaped to
        ``(T, 1, F)`` before the forward pass. NOTE(review): presumably the
        model expects (seq, batch, feature); confirm against the model code.

        If the model returns a dict, the first of ``"mu"``, ``"z"`` or
        ``"reconstruction"`` present is used; a plain tensor output is used
        directly. The chosen tensor is averaged over dim 0, then dim 0 is
        squeezed.

        Raises:
            KeyError: The model output contains none of the usable values.
        """
        if isinstance(audio_input, str):
            x = self._audio_to_tensor(audio_input)
        else:
            x = self._ensure_tensor(audio_input)
        x = x.to(self.device)
        if x.dim() == 2:
            x = x.unsqueeze(1)

        with torch.no_grad():
            out = self.base_model(x)

        if isinstance(out, dict):
            for key in ("mu", "z", "reconstruction"):
                if key in out:
                    return out[key].mean(dim=0).squeeze(0).cpu().numpy()
        if torch.is_tensor(out):
            return out.mean(dim=0).squeeze(0).cpu().numpy()
        raise KeyError("Base model forward did not return 'mu', 'z', 'reconstruction' or a tensor to use as embedding.")

    def enroll_speaker(
        self,
        speaker_id: str,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        model_type: str = "mmm",
        n_components: int = 4,
        epochs: int = 50,
        lr: float = 1e-3,
        seq_len_for_mmm: Optional[int] = None,
        **fit_kwargs,
    ) -> str:
        """Fit a per-speaker model on one utterance and register it under ``speaker_id``.

        - ``"gmm"``: the embedding is passed as a single row ``(1, Z)``
          together with ``n_components``.
        - ``"hmm"`` / ``"mmm"``: the embedding is repeated
          ``seq_len_for_mmm or self.seq_len`` times into a ``(T, 1, Z)``
          tensor on ``self.device``. The architecture sizes are derived from
          Z. Every architecture keyword (``input_dim``, ``num_states``,
          ``n_mix``, ``trans_*``, ...) may be overridden through
          ``fit_kwargs``.

        Any remaining ``fit_kwargs`` are forwarded to ``mmm.fit_and_add``.

        Returns:
            ``speaker_id``.

        Raises:
            ValueError: ``model_type`` is not 'gmm', 'hmm' or 'mmm'
                (case-insensitive).
        """
        model_type = model_type.lower()
        if model_type not in ("gmm", "hmm", "mmm"):
            raise ValueError("model_type must be 'gmm', 'hmm', or 'mmm'")

        emb = self.generate_embedding(audio_input)  # (Z,)

        if model_type == "gmm":
            X = np.asarray(emb, dtype=np.float32)[None, :]  # (1, Z)
            self.mmm.fit_and_add(
                data=X,
                model_type="gmm",
                model_id=speaker_id,
                n_components=n_components,
                lr=lr,
                epochs=epochs,
                **fit_kwargs,
            )
            return speaker_id

        dim = emb.shape[-1]
        T = int(seq_len_for_mmm or self.seq_len)
        z = torch.tensor(emb, dtype=torch.float32, device=self.device)
        seq = z.unsqueeze(0).repeat(T, 1).unsqueeze(1)  # (T, 1, Z)

        model_kwargs = {
            "input_dim": dim,
            "output_dim": dim,
            "hidden_dim": dim * 2,
            "z_dim": min(256, dim),
            "rnn_hidden": dim,
            "num_states": 8,
            "n_mix": 2,
            "trans_d_model": 64,
            "trans_nhead": 4,
            "trans_layers": 2,
        }
        # Merge the defaults with fit_kwargs in one dict, caller values first.
        # Each keyword then reaches fit_and_add exactly once, so supplying e.g.
        # num_states in fit_kwargs cannot raise "multiple values" TypeError.
        model_kwargs.update(fit_kwargs)
        self.mmm.fit_and_add(
            data=seq,
            model_type=model_type,
            model_id=speaker_id,
            lr=lr,
            epochs=epochs,
            **model_kwargs,
        )
        return speaker_id

    @staticmethod
    def _collapse_score(sc: Any) -> float:
        """Reduce a value returned by ``mmm.score`` to one float.

        A dict is collapsed to the mean of its values' means, skipping values
        that cannot be converted; an empty result gives NaN. Anything else
        goes through ``np.asarray(...).mean()``, falling back to ``float()``.
        """
        if isinstance(sc, dict):
            vals = []
            for v in sc.values():
                try:
                    vals.append(float(np.asarray(v).mean()))
                except Exception:
                    pass
            return float(np.mean(vals)) if vals else float("nan")
        try:
            return float(np.asarray(sc).mean())
        except Exception:
            return float(sc)

    def _score_model(self, model_id: str, emb_np: np.ndarray) -> Optional[float]:
        """Score one embedding against one model, or return None if it can't be scored.

        The manager's models take differently shaped input. So this first
        tries a single-row batch ``(1, Z)``. If that raises, it retries with
        the embedding tiled to ``(seq_len, 1, Z)``. The broad ``except``
        clauses are deliberate probing, not error hiding.
        """
        try:
            return self._collapse_score(self.mmm.score(model_id, emb_np[None, :]))
        except Exception:
            pass
        try:
            seq = np.tile(emb_np[None, :], (self.seq_len, 1, 1))
            return self._collapse_score(self.mmm.score(model_id, seq))
        except Exception:
            return None

    def identify(
        self,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        unknown_label_confidence_margin: float = 0.0,
    ):
        """Score ``audio_input`` against every model in the manager and return the best.

        Higher scores are treated as better. When
        ``unknown_label_confidence_margin > 0``, a non-base winner must beat
        the base model's score by more than the margin; otherwise the base
        label is returned.

        Returns:
            ``(label, score, scores)``, where ``scores`` maps each scorable
            model id to its float score. ``label`` is ``base_model_id`` with a
            NaN score when no model produced a non-NaN score.
        """
        emb_np = np.asarray(self.generate_embedding(audio_input), dtype=np.float32)

        scores: Dict[str, float] = {}
        for model_id in list(self.mmm.models.keys()):
            score_val = self._score_model(model_id, emb_np)
            if score_val is not None:
                scores[model_id] = score_val

        # NaN never compares greater than anything. If one is left in the pool
        # and comes first, max() returns it as the "best", so rank non-NaN
        # scores only.
        ranked = {k: v for k, v in scores.items() if not np.isnan(v)}
        if not ranked:
            return self.base_model_id, float("nan"), scores

        best_model, best_score = max(ranked.items(), key=lambda kv: kv[1])
        if best_model != self.base_model_id and unknown_label_confidence_margin > 0.0:
            unknown_score = ranked.get(self.base_model_id, float("-inf"))
            if best_score <= unknown_score + unknown_label_confidence_margin:
                return self.base_model_id, unknown_score, scores
        return best_model, best_score, scores


# -------- Automatic Speaker Identification --------
def _write_temp_wav(wav_bytes: bytes) -> str:
    """Write ``wav_bytes`` to a new ``.wav`` temp file that is NOT auto-deleted; return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
        tf.write(wav_bytes)
        return tf.name


async def ASI(
    phrase_time_limit: Optional[float] = 3.0,
    queue_maxsize: int = 8,
    mmm_pt_path: str = "models/MMM/mmm.pt",
) -> AsyncGenerator[Dict[str, Any], None]:
    """Continuously identify speakers from the default microphone.

    Loads the manager from ``mmm_pt_path`` and listens in the background
    (phrases capped at ``phrase_time_limit`` seconds). Each captured phrase is
    written to a temp WAV and run through ``Speaker_ID.identify`` in the
    default executor. Up to ``queue_maxsize`` phrases are buffered; further
    phrases are dropped while the buffer is full.

    Yields:
        On success: ``{"speaker", "score", "scores", "path", "timestamp"}``.
        On identification failure:
        ``{"error", "traceback", "path", "timestamp"}``.
        ``path`` is removed once the consumer resumes the generator.

    Raises:
        RuntimeError: The microphone could not be opened.
    """
    mgr = MMM.load(mmm_pt_path)
    speaker_system = Speaker_ID(mmm_manager=mgr, base_model_id="unknown", seq_len=1200, sr=1200)

    loop = asyncio.get_running_loop()
    audio_q: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)

    recognizer = sr.Recognizer()
    try:
        mic = sr.Microphone()
    except Exception as e:
        raise RuntimeError("Could not open microphone. Check drivers / permissions.") from e

    def _enqueue(wav_bytes: bytes) -> None:
        # Runs on the event-loop thread via call_soon_threadsafe. QueueFull
        # raised here would escape as a loop-callback error, so drop the
        # phrase explicitly instead.
        try:
            audio_q.put_nowait(wav_bytes)
        except asyncio.QueueFull:
            pass

    def _bg_callback(recognizer_obj: sr.Recognizer, audio: sr.AudioData) -> None:
        # Invoked by listen_in_background off the event-loop thread, hence
        # the thread-safe hand-off to _enqueue.
        try:
            wav_bytes = audio.get_wav_data()
        except Exception:
            traceback.print_exc()
            return
        try:
            loop.call_soon_threadsafe(_enqueue, wav_bytes)
        except RuntimeError:
            # The event loop is already closed; there is no consumer left.
            pass

    stop_listening = recognizer.listen_in_background(mic, _bg_callback, phrase_time_limit=phrase_time_limit)
    try:
        while True:
            # CancelledError is intentionally not caught. It must reach the
            # caller for task cancellation/timeouts to work; the outer
            # finally still stops the listener.
            wav_bytes = await audio_q.get()
            if wav_bytes is None:
                continue

            tmp_path = await loop.run_in_executor(None, _write_temp_wav, wav_bytes)
            try:
                try:
                    best_speaker, best_score, scores = await loop.run_in_executor(
                        None, speaker_system.identify, tmp_path
                    )
                except Exception as e:
                    event = {
                        "error": str(e),
                        "traceback": traceback.format_exc(),
                        "path": tmp_path,
                        "timestamp": time.time(),
                    }
                else:
                    event = {
                        "speaker": best_speaker,
                        "score": best_score,
                        "scores": scores,
                        "path": tmp_path,
                        "timestamp": time.time(),
                    }
                # Yield outside the except handler, so an exception thrown
                # into the generator here is not turned into an "error" event.
                yield event
            finally:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
    finally:
        try:
            stop_listening(wait_for_stop=False)
        except Exception:
            # Best-effort teardown; never mask the exception that ended the loop.
            pass


async def _main_cli():
    """Print each identification result produced by ``ASI`` until interrupted."""
    async for res in ASI(phrase_time_limit=3.0):
        if "error" in res:
            print("ID error:", res["error"])
        else:
            ts = time.ctime(res["timestamp"])
            print(f"[{ts}] Predicted: {res['speaker']} (score={res['score']})")
            print("All scores:", res["scores"])


if __name__ == "__main__":
    asyncio.run(_main_cli())