# Upload provenance (was raw web-page residue, which broke parsing):
# HiMind — "Upload 3 files", commit 81fa0ce (verified)
# Speaker_ID.py
# By Chance Brownfield
import torch
import numpy as np
import librosa
import asyncio
import tempfile
import os
import time
import traceback
from typing import AsyncGenerator, Dict, Any, Optional, Union, Iterable
import speech_recognition as sr
from MMM import MMM
class Speaker_ID:
    """Speaker enrollment and identification on top of an MMM model manager.

    A designated "base" model (default id ``"unknown"``) produces embeddings
    for every utterance; per-speaker models registered in the same manager
    are then fit on / scored against those embeddings.
    """

    def __init__(
        self,
        mmm_manager,
        base_model_id: str = "unknown",
        device: Union[str, torch.device, None] = None,
        seq_len: int = 1200,
        sr: int = 1200,
    ):
        """Bind to an MMM manager and pin the base embedding model.

        Args:
            mmm_manager: object exposing a ``.models`` dict of id -> model.
            base_model_id: key of the embedding model inside ``mmm_manager.models``.
            device: explicit torch device; when None, CUDA is used if available.
            seq_len: fixed waveform length (samples) fed to the base model.
            sr: sample rate used when loading audio files.

        Raises:
            ValueError: if ``mmm_manager`` has no ``.models`` attribute.
            KeyError: if ``base_model_id`` is not registered in the manager.
        """
        self.mmm = mmm_manager
        self.base_model_id = base_model_id
        self.seq_len = int(seq_len)
        self.sr = int(sr)
        # Explicit device wins; otherwise prefer CUDA when present.
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Validate the manager before touching its registry.
        if not hasattr(self.mmm, "models"):
            raise ValueError("Provided mmm_manager does not look like an MMM manager (missing .models).")
        if self.base_model_id not in self.mmm.models:
            available = list(self.mmm.models.keys())
            raise KeyError(f"Base model id '{self.base_model_id}' not found. Available keys: {available}")
        # Keep the embedding model resident on the chosen device, in eval mode.
        self.base_model = self.mmm.models[self.base_model_id].to(self.device)
        self.base_model.eval()
def _audio_to_tensor(self, wav_path: str) -> torch.Tensor:
y, _ = librosa.load(str(wav_path), sr=self.sr, mono=True)
y = y.astype(np.float32)
if y.size == 0:
raise RuntimeError(f"Empty audio file: {wav_path}")
maxv = float(np.max(np.abs(y)))
if maxv > 0:
y = y / maxv
if y.shape[0] < self.seq_len:
y = np.pad(y, (0, self.seq_len - y.shape[0]))
else:
y = y[: self.seq_len]
return torch.from_numpy(y).unsqueeze(-1)
def _ensure_tensor(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if isinstance(features, np.ndarray):
t = torch.from_numpy(features)
elif torch.is_tensor(features):
t = features.clone()
else:
raise TypeError("audio_features must be numpy array or torch tensor or audio file path")
if t.dim() == 1:
t = t.unsqueeze(-1)
if t.dim() == 2:
return t.float()
raise ValueError(f"Unexpected features tensor shape: {t.shape}")
def generate_embedding(self, audio_input: Union[str, np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(audio_input, str):
x = self._audio_to_tensor(audio_input)
else:
x = self._ensure_tensor(audio_input)
x = x.to(self.device)
if x.dim() == 2:
x = x.unsqueeze(1)
with torch.no_grad():
out = self.base_model(x)
if isinstance(out, dict):
if "mu" in out:
mu = out["mu"]
emb_bz = mu.mean(dim=0)
emb = emb_bz.squeeze(0).cpu().numpy()
return emb
if "z" in out:
z = out["z"].mean(dim=0).squeeze(0).cpu().numpy()
return z
if "reconstruction" in out:
recon = out["reconstruction"].mean(dim=0).squeeze(0).cpu().numpy()
return recon
if torch.is_tensor(out):
arr = out.mean(dim=0).squeeze(0).cpu().numpy()
return arr
raise KeyError("Base model forward did not return 'mu', 'z', 'reconstruction' or a tensor to use as embedding.")
def enroll_speaker(
self,
speaker_id: str,
audio_input: Union[str, np.ndarray, torch.Tensor],
model_type: str = "mmm",
n_components: int = 4,
epochs: int = 50,
lr: float = 1e-3,
seq_len_for_mmm: int = None,
**fit_kwargs,
) -> str:
model_type = model_type.lower()
if model_type not in ("gmm", "hmm", "mmm"):
raise ValueError("model_type must be 'gmm', 'hmm', or 'mmm'")
emb = self.generate_embedding(audio_input) # (Z,)
if model_type == "gmm":
X = np.asarray(emb, dtype=np.float32)[None, :] # (1, Z)
self.mmm.fit_and_add(
data=X,
model_type="gmm",
model_id=speaker_id,
n_components=n_components,
lr=lr,
epochs=epochs,
**fit_kwargs,
)
else:
T = int(seq_len_for_mmm or self.seq_len)
z = torch.tensor(emb, dtype=torch.float32, device=self.device)
seq = z.unsqueeze(0).repeat(T, 1)
seq = seq.unsqueeze(1)
self.mmm.fit_and_add(
data=seq,
model_type="mmm" if model_type == "mmm" else "hmm",
model_id=speaker_id,
input_dim=emb.shape[-1],
output_dim=emb.shape[-1],
hidden_dim=emb.shape[-1] * 2,
z_dim=min(256, emb.shape[-1]),
rnn_hidden=emb.shape[-1],
num_states=fit_kwargs.get("num_states", 8),
n_mix=fit_kwargs.get("n_mix", 2),
trans_d_model=fit_kwargs.get("trans_d_model", 64),
trans_nhead=fit_kwargs.get("trans_nhead", 4),
trans_layers=fit_kwargs.get("trans_layers", 2),
lr=lr,
epochs=epochs,
**fit_kwargs,
)
return speaker_id
def identify(
self,
audio_input: Union[str, np.ndarray, torch.Tensor],
unknown_label_confidence_margin: float = 0.0,
):
emb = self.generate_embedding(audio_input)
emb_np = np.asarray(emb, dtype=np.float32)
X_try = emb_np[None, :]
scores: Dict[str, float] = {}
for model_id in list(self.mmm.models.keys()):
try:
sc = self.mmm.score(model_id, X_try)
if isinstance(sc, dict):
vals = []
for v in sc.values():
try:
vals.append(float(np.asarray(v).mean()))
except Exception:
pass
score_val = float(np.mean(vals)) if vals else float("nan")
else:
try:
score_val = float(np.asarray(sc).mean())
except Exception:
score_val = float(sc)
scores[model_id] = score_val
except Exception:
try:
T = self.seq_len
seq = np.tile(emb_np[None, :], (T, 1, 1))
sc = self.mmm.score(model_id, seq)
try:
scores[model_id] = float(np.asarray(sc).mean())
except Exception:
scores[model_id] = float(sc)
except Exception:
continue
if not scores:
return self.base_model_id, float("nan"), {}
best_model, best_score = max(scores.items(), key=lambda kv: kv[1])
if best_model != self.base_model_id and unknown_label_confidence_margin > 0.0:
unknown_score = scores.get(self.base_model_id, float("-inf"))
if best_score <= unknown_score + unknown_label_confidence_margin:
return self.base_model_id, unknown_score, scores
return best_model, best_score, scores
# -------- Automatic Speaker Identification --------
async def ASI(
    phrase_time_limit: Optional[float] = 3.0,
    queue_maxsize: int = 8,
    mmm_pt_path: str = "models/MMM/mmm.pt",
) -> AsyncGenerator[Dict[str, Any], None]:
    """Continuously capture microphone phrases and yield speaker-ID results.

    Loads an MMM manager from ``mmm_pt_path``, listens on the default
    microphone in a background thread, and for every captured phrase yields
    either a result dict (``speaker``, ``score``, ``scores``, ``path``,
    ``timestamp``) or an error dict (``error``, ``traceback``, ``path``,
    ``timestamp``). Runs until the consumer cancels the generator.

    Raises:
        RuntimeError: if the microphone cannot be opened.
    """
    mgr = MMM.load(mmm_pt_path)
    # NOTE(review): seq_len=1200 and sr=1200 mirror the Speaker_ID defaults;
    # a 1200 Hz sample rate looks unusually low — confirm against the model.
    speaker_system = Speaker_ID(mmm_manager=mgr, base_model_id="unknown", seq_len=1200, sr=1200)
    loop = asyncio.get_running_loop()
    audio_q: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
    recognizer = sr.Recognizer()
    try:
        mic = sr.Microphone()
    except Exception as e:
        raise RuntimeError("Could not open microphone. Check drivers / permissions.") from e

    def _bg_callback(recognizer_obj: sr.Recognizer, audio: sr.AudioData) -> None:
        # Runs on the listener's background thread: hand the raw WAV bytes to
        # the event loop thread-safely. A full queue (put_nowait raising via
        # call_soon_threadsafe scheduling) silently drops the phrase.
        try:
            wav_bytes = audio.get_wav_data()
            try:
                loop.call_soon_threadsafe(audio_q.put_nowait, wav_bytes)
            except Exception:
                pass
        except Exception:
            traceback.print_exc()

    stop_listening = recognizer.listen_in_background(mic, _bg_callback, phrase_time_limit=phrase_time_limit)
    try:
        while True:
            try:
                wav_bytes = await audio_q.get()
            except asyncio.CancelledError:
                # Consumer cancelled us; fall through to stop the listener.
                break
            if wav_bytes is None:
                continue

            def _write_temp_wav(b: bytes) -> str:
                # Persist the phrase to a temp .wav (delete=False so the path
                # survives close) — identify() loads audio from a file path.
                tf = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                try:
                    tf.write(b)
                    tf.flush()
                    return tf.name
                finally:
                    tf.close()

            # File I/O and model scoring are blocking; keep them off the loop.
            tmp_path = await loop.run_in_executor(None, _write_temp_wav, wav_bytes)
            try:
                result = await loop.run_in_executor(None, speaker_system.identify, tmp_path)
                best_speaker, best_score, scores = result
                yield {
                    "speaker": best_speaker,
                    "score": best_score,
                    "scores": scores,
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            except Exception as e:
                # Surface per-phrase failures to the consumer instead of dying.
                yield {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            finally:
                # Best-effort cleanup of the temp file; NOTE the path yielded
                # above may already be deleted by the time the consumer reads it.
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass
    finally:
        # Always stop the background listener thread on exit.
        try:
            stop_listening(wait_for_stop=False)
        except Exception:
            pass
async def _main_cli():
    """Console demo: stream ASI results and print one line per phrase."""
    async for res in ASI(phrase_time_limit=3.0):
        err = res.get("error")
        if err is not None:
            print("ID error:", err)
            continue
        ts = time.ctime(res["timestamp"])
        print(f"[{ts}] Predicted: {res['speaker']} (score={res['score']})")
        print("All scores:", res["scores"])


if __name__ == "__main__":
    asyncio.run(_main_cli())