|
|
|
|
|
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import librosa
|
|
|
import asyncio
|
|
|
import tempfile
|
|
|
import os
|
|
|
import time
|
|
|
import traceback
|
|
|
from typing import AsyncGenerator, Dict, Any, Optional, Union, Iterable
|
|
|
|
|
|
import speech_recognition as sr
|
|
|
from MMM import MMM
|
|
|
|
|
|
|
|
|
class Speaker_ID:
    """Speaker identification on top of an MMM model manager.

    A base embedding model (keyed by ``base_model_id`` in the manager) turns
    raw audio into a fixed-size embedding; per-speaker models enrolled via
    :meth:`enroll_speaker` are then scored against new embeddings in
    :meth:`identify`.
    """

    def __init__(
        self,
        mmm_manager,
        base_model_id: str = "unknown",
        device: Union[str, torch.device, None] = None,
        seq_len: int = 1200,
        sr: int = 1200,
    ):
        """Bind the manager, pick a device, and validate the base model.

        Args:
            mmm_manager: object exposing ``.models`` (dict of id -> model)
                plus ``.fit_and_add`` / ``.score`` (duck-typed MMM manager).
            base_model_id: key of the embedding model inside the manager.
            device: explicit torch device; auto-selects CUDA when None.
            seq_len: fixed sample count fed to the base model.
            sr: sample rate used when loading audio files.
                NOTE(review): 1200 Hz is unusually low for speech — confirm
                this matches how the base model was trained.

        Raises:
            ValueError: if ``mmm_manager`` has no ``.models`` attribute.
            KeyError: if ``base_model_id`` is not present in the manager.
        """
        self.mmm = mmm_manager
        self.base_model_id = base_model_id
        self.seq_len = int(seq_len)
        self.sr = int(sr)

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        if not hasattr(self.mmm, "models"):
            raise ValueError("Provided mmm_manager does not look like an MMM manager (missing .models).")

        if self.base_model_id not in self.mmm.models:
            available = list(self.mmm.models.keys())
            raise KeyError(f"Base model id '{self.base_model_id}' not found. Available keys: {available}")

        # Move the shared embedding model to the chosen device once, in eval mode.
        self.base_model = self.mmm.models[self.base_model_id].to(self.device)
        self.base_model.eval()

    def _audio_to_tensor(self, wav_path: str) -> torch.Tensor:
        """Load an audio file and return a peak-normalized (seq_len, 1) tensor.

        The waveform is resampled to ``self.sr``, scaled to [-1, 1] by its
        peak, and zero-padded / truncated to exactly ``self.seq_len`` samples.

        Raises:
            RuntimeError: if the file decodes to zero samples.
        """
        y, _ = librosa.load(str(wav_path), sr=self.sr, mono=True)
        y = y.astype(np.float32)
        if y.size == 0:
            raise RuntimeError(f"Empty audio file: {wav_path}")
        maxv = float(np.max(np.abs(y)))
        if maxv > 0:
            y = y / maxv
        if y.shape[0] < self.seq_len:
            y = np.pad(y, (0, self.seq_len - y.shape[0]))
        else:
            y = y[: self.seq_len]
        return torch.from_numpy(y).unsqueeze(-1)

    def _ensure_tensor(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Coerce a feature array/tensor into a float32 2-D (time, features) tensor.

        1-D input gains a trailing feature axis of size 1; anything other
        than 1-D/2-D is rejected.

        Raises:
            TypeError: for inputs that are neither numpy arrays nor tensors.
            ValueError: for tensors with more than two dimensions.
        """
        if isinstance(features, np.ndarray):
            t = torch.from_numpy(features)
        elif torch.is_tensor(features):
            t = features.clone()
        else:
            raise TypeError("audio_features must be numpy array or torch tensor or audio file path")

        if t.dim() == 1:
            t = t.unsqueeze(-1)
        if t.dim() == 2:
            return t.float()
        raise ValueError(f"Unexpected features tensor shape: {t.shape}")

    def generate_embedding(self, audio_input: Union[str, np.ndarray, torch.Tensor]) -> np.ndarray:
        """Run the base model on the input and return a 1-D numpy embedding.

        ``audio_input`` may be a file path (loaded via librosa) or a ready
        feature array/tensor. A 2-D (time, features) input is reshaped to
        (time, 1, features) before the forward pass, and the model output is
        averaged over dim 0 to collapse the time axis.

        Dict outputs are probed for "mu", then "z", then "reconstruction";
        a plain tensor output is used directly.

        Raises:
            KeyError: if the model output offers none of the known keys and
                is not a tensor.
        """
        if isinstance(audio_input, str):
            x = self._audio_to_tensor(audio_input)
        else:
            x = self._ensure_tensor(audio_input)
        x = x.to(self.device)
        if x.dim() == 2:
            # Insert a batch axis of size 1: (time, 1, features).
            x = x.unsqueeze(1)

        with torch.no_grad():
            out = self.base_model(x)

        if isinstance(out, dict):
            if "mu" in out:
                mu = out["mu"]
                emb_bz = mu.mean(dim=0)
                emb = emb_bz.squeeze(0).cpu().numpy()
                return emb
            if "z" in out:
                z = out["z"].mean(dim=0).squeeze(0).cpu().numpy()
                return z
            if "reconstruction" in out:
                recon = out["reconstruction"].mean(dim=0).squeeze(0).cpu().numpy()
                return recon

        if torch.is_tensor(out):
            arr = out.mean(dim=0).squeeze(0).cpu().numpy()
            return arr

        raise KeyError("Base model forward did not return 'mu', 'z', 'reconstruction' or a tensor to use as embedding.")

    def enroll_speaker(
        self,
        speaker_id: str,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        model_type: str = "mmm",
        n_components: int = 4,
        epochs: int = 50,
        lr: float = 1e-3,
        seq_len_for_mmm: Optional[int] = None,
        **fit_kwargs,
    ) -> str:
        """Fit a per-speaker model from one utterance and register it.

        The embedding of ``audio_input`` is either fitted directly as a GMM
        ("gmm") or tiled into a constant sequence for the "hmm"/"mmm" paths.

        Args:
            speaker_id: model id to register in the manager.
            audio_input: path or feature array/tensor (see generate_embedding).
            model_type: one of "gmm", "hmm", "mmm".
            n_components: mixture components for the GMM path.
            epochs: training epochs forwarded to ``fit_and_add``.
            lr: learning rate forwarded to ``fit_and_add``.
            seq_len_for_mmm: sequence length for the tiled embedding
                (defaults to ``self.seq_len``).
            **fit_kwargs: extra options forwarded to ``fit_and_add``;
                recognized architecture keys (num_states, n_mix,
                trans_d_model, trans_nhead, trans_layers) override the
                built-in defaults.

        Returns:
            The ``speaker_id`` that was enrolled.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        model_type = model_type.lower()
        if model_type not in ("gmm", "hmm", "mmm"):
            raise ValueError("model_type must be 'gmm', 'hmm', or 'mmm'")

        emb = self.generate_embedding(audio_input)
        if model_type == "gmm":
            X = np.asarray(emb, dtype=np.float32)[None, :]
            self.mmm.fit_and_add(
                data=X,
                model_type="gmm",
                model_id=speaker_id,
                n_components=n_components,
                lr=lr,
                epochs=epochs,
                **fit_kwargs,
            )
        else:
            T = int(seq_len_for_mmm or self.seq_len)
            z = torch.tensor(emb, dtype=torch.float32, device=self.device)
            # Tile the single embedding into a constant (T, 1, dim) sequence.
            seq = z.unsqueeze(0).repeat(T, 1)
            seq = seq.unsqueeze(1)
            # BUG FIX: architecture keys are popped out of a copy of
            # fit_kwargs before the final **-expansion. Previously each key
            # was passed both explicitly (via fit_kwargs.get) and again via
            # **fit_kwargs, so a caller supplying e.g. num_states raised
            # "got multiple values for keyword argument 'num_states'".
            extra = dict(fit_kwargs)
            arch_kwargs = dict(
                input_dim=emb.shape[-1],
                output_dim=emb.shape[-1],
                hidden_dim=emb.shape[-1] * 2,
                z_dim=min(256, emb.shape[-1]),
                rnn_hidden=emb.shape[-1],
                num_states=extra.pop("num_states", 8),
                n_mix=extra.pop("n_mix", 2),
                trans_d_model=extra.pop("trans_d_model", 64),
                trans_nhead=extra.pop("trans_nhead", 4),
                trans_layers=extra.pop("trans_layers", 2),
            )
            self.mmm.fit_and_add(
                data=seq,
                model_type="mmm" if model_type == "mmm" else "hmm",
                model_id=speaker_id,
                lr=lr,
                epochs=epochs,
                **arch_kwargs,
                **extra,
            )

        return speaker_id

    def identify(
        self,
        audio_input: Union[str, np.ndarray, torch.Tensor],
        unknown_label_confidence_margin: float = 0.0,
    ):
        """Score the input against every model in the manager.

        Returns:
            Tuple ``(best_model_id, best_score, scores)`` where ``scores``
            maps every scorable model id to a float. If no model could be
            scored, ``(base_model_id, nan, {})`` is returned. When a positive
            ``unknown_label_confidence_margin`` is given, the winner must
            beat the base ("unknown") model's score by more than that margin,
            otherwise the base label is returned.
        """
        emb = self.generate_embedding(audio_input)
        emb_np = np.asarray(emb, dtype=np.float32)
        X_try = emb_np[None, :]

        scores: Dict[str, float] = {}
        for model_id in list(self.mmm.models.keys()):
            try:
                sc = self.mmm.score(model_id, X_try)
                if isinstance(sc, dict):
                    # Aggregate any numeric values found in the score dict.
                    vals = []
                    for v in sc.values():
                        try:
                            vals.append(float(np.asarray(v).mean()))
                        except Exception:
                            pass
                    score_val = float(np.mean(vals)) if vals else float("nan")
                else:
                    try:
                        score_val = float(np.asarray(sc).mean())
                    except Exception:
                        score_val = float(sc)
                scores[model_id] = score_val
            except Exception:
                # Sequence models may reject 2-D input: retry with the
                # embedding tiled to (seq_len, 1, dim).
                try:
                    T = self.seq_len
                    seq = np.tile(emb_np[None, :], (T, 1, 1))
                    sc = self.mmm.score(model_id, seq)
                    try:
                        scores[model_id] = float(np.asarray(sc).mean())
                    except Exception:
                        scores[model_id] = float(sc)
                except Exception:
                    continue

        if not scores:
            return self.base_model_id, float("nan"), {}

        best_model, best_score = max(scores.items(), key=lambda kv: kv[1])

        if best_model != self.base_model_id and unknown_label_confidence_margin > 0.0:
            unknown_score = scores.get(self.base_model_id, float("-inf"))
            if best_score <= unknown_score + unknown_label_confidence_margin:
                return self.base_model_id, unknown_score, scores

        return best_model, best_score, scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def ASI(
    phrase_time_limit: Optional[float] = 3.0,
    queue_maxsize: int = 8,
    mmm_pt_path: str = "models/MMM/mmm.pt",
) -> AsyncGenerator[Dict[str, Any], None]:
    """Asynchronously stream speaker-identification results from the microphone.

    Loads the MMM manager, opens the default microphone via
    ``speech_recognition``, and yields one dict per captured phrase: either
    ``{"speaker", "score", "scores", "path", "timestamp"}`` on success or
    ``{"error", "traceback", "path", "timestamp"}`` on failure.

    Args:
        phrase_time_limit: max seconds per captured phrase (forwarded to
            ``listen_in_background``).
        queue_maxsize: bound on buffered, not-yet-processed audio chunks;
            chunks arriving while the queue is full are dropped.
        mmm_pt_path: path of the serialized MMM manager to load.

    Raises:
        RuntimeError: if the microphone cannot be opened.
    """
    mgr = MMM.load(mmm_pt_path)
    speaker_system = Speaker_ID(mmm_manager=mgr, base_model_id="unknown", seq_len=1200, sr=1200)

    loop = asyncio.get_running_loop()
    audio_q: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)

    recognizer = sr.Recognizer()
    try:
        mic = sr.Microphone()
    except Exception as e:
        raise RuntimeError("Could not open microphone. Check drivers / permissions.") from e

    def _bg_callback(recognizer_obj: sr.Recognizer, audio: sr.AudioData) -> None:
        # Runs on the speech_recognition background thread: hand the WAV
        # bytes to the asyncio loop thread-safely. Best-effort — a full
        # queue or a closing loop simply drops the chunk.
        try:
            wav_bytes = audio.get_wav_data()
            try:
                loop.call_soon_threadsafe(audio_q.put_nowait, wav_bytes)
            except Exception:
                pass
        except Exception:
            traceback.print_exc()

    def _write_temp_wav(b: bytes) -> str:
        # Persist one captured chunk so it can be loaded by file path.
        # (Hoisted out of the consumer loop — it was redefined every
        # iteration before.)
        tf = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            tf.write(b)
            tf.flush()
            return tf.name
        finally:
            tf.close()

    stop_listening = recognizer.listen_in_background(mic, _bg_callback, phrase_time_limit=phrase_time_limit)

    try:
        while True:
            try:
                wav_bytes = await audio_q.get()
            except asyncio.CancelledError:
                break

            if wav_bytes is None:
                continue

            tmp_path = await loop.run_in_executor(None, _write_temp_wav, wav_bytes)

            try:
                result = await loop.run_in_executor(None, speaker_system.identify, tmp_path)
                best_speaker, best_score, scores = result
                yield {
                    "speaker": best_speaker,
                    "score": best_score,
                    "scores": scores,
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            except Exception as e:
                yield {
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                    "path": tmp_path,
                    "timestamp": time.time(),
                }
            finally:
                # NOTE(review): the temp file is removed as soon as the
                # consumer resumes the generator, so the yielded "path" is
                # only valid inside the consumer's loop body — confirm
                # callers do not hold on to it.
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass

    finally:
        # Always stop the background listener thread, even on cancellation.
        try:
            stop_listening(wait_for_stop=False)
        except Exception:
            pass
|
|
|
|
|
|
|
|
|
async def _main_cli():
    """Console driver: stream speaker-ID results from ASI and print them."""
    async for result in ASI(phrase_time_limit=3.0):
        if "error" in result:
            print("ID error:", result["error"])
            continue
        ts = time.ctime(result["timestamp"])
        print(f"[{ts}] Predicted: {result['speaker']} (score={result['score']})")
        print("All scores:", result["scores"])
|
|
|
|
|
|
|
|
|
# Script entry point: run the live microphone identification loop until
# interrupted.
if __name__ == "__main__":
    asyncio.run(_main_cli())
|
|
|
|