# Speaker-ID / app.py
# HiMind's picture
# initial upload
# 1c276c3 verified
# app.py
# By Chance Brownfield
import gradio as gr
import torch
import numpy as np
from pathlib import Path
from MMM import MMM
from Speaker_ID import Speaker_ID
# Saved MMM manager that load_manager() reads at startup.
BASE_MODEL_PATH = "models/MMM/mmm.pt"
# Preferred key in the manager's `models`; load_manager() falls back to the
# first available key when this one is absent.
REQUESTED_BASE_MODEL_ID = "unknown"
# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_manager(path: str):
    """Load an MMM manager from *path* and pick the model id to use as the base.

    Args:
        path: Filesystem path of the saved manager; it is passed to ``MMM.load``.

    Returns:
        ``(manager, base_id)``. ``base_id`` is ``REQUESTED_BASE_MODEL_ID`` when
        the manager's ``models`` contains it, otherwise the first key of
        ``manager.models``.

    Raises:
        FileNotFoundError: *path* does not exist.
        RuntimeError: the loaded object has no ``models`` attribute, or its
            ``models`` is empty.
    """
    manager_file = Path(path)
    if not manager_file.exists():
        raise FileNotFoundError(f"MMM file not found: {manager_file}")

    manager = MMM.load(str(manager_file))
    if not hasattr(manager, "models"):
        raise RuntimeError("Loaded object is not an MMM manager (missing .models).")

    if REQUESTED_BASE_MODEL_ID in manager.models:
        return manager, REQUESTED_BASE_MODEL_ID

    # Requested id is not present: fall back to whichever model is listed first.
    candidates = list(manager.models.keys())
    if not candidates:
        raise RuntimeError(f"No models found inside manager {path}")
    return manager, candidates[0]
# Load the manager once at import time; identify_speaker() reads its models
# on every request.
mgr, BASE_MODEL_ID = load_manager(BASE_MODEL_PATH)
print(f"Loaded MMM manager. Using base model id: {BASE_MODEL_ID}")

# One module-wide Speaker_ID instance, shared by every UI request.
speaker_system = Speaker_ID(
    device=DEVICE,
    mmm_manager=mgr,
    base_model_id=BASE_MODEL_ID,
    # NOTE(review): seq_len and sr are both 1200; 1200 Hz would be an unusually
    # low audio sample rate -- confirm these match how the MMM was trained.
    seq_len=1200,
    sr=1200,
)
def _reset_speaker_system():
    """Re-point the shared ``speaker_system`` at a manager holding only the base model.

    Builds a fresh ``MMM`` that contains just ``mgr.models[BASE_MODEL_ID]`` and
    installs it as ``speaker_system.mmm``. It then moves the base model to
    ``DEVICE`` and switches it to eval mode.

    NOTE(review): this presumably discards speakers enrolled by an earlier
    request. That assumes ``enroll_speaker`` stores its models in
    ``speaker_system.mmm``; confirm against Speaker_ID.
    """
    base_only = MMM()
    base_only.models[BASE_MODEL_ID] = mgr.models[BASE_MODEL_ID]
    speaker_system.mmm = base_only
    speaker_system.base_model = base_only.models[BASE_MODEL_ID].to(DEVICE)
    speaker_system.base_model.eval()


def identify_speaker(spk1_audio, spk2_audio, spk3_audio, query_audio):
    """Enroll three speakers, then identify who is speaking in the query audio.

    Args:
        spk1_audio, spk2_audio, spk3_audio: Enrollment clips as file paths
            (the ``gr.Audio`` inputs use ``type="filepath"``), or ``None`` when
            the user left a slot empty.
        query_audio: File path of the clip to identify, or ``None``.

    Returns:
        On success, ``(predicted_speaker, scores)``. ``predicted_speaker`` is the
        first value returned by ``speaker_system.identify``. ``scores`` maps
        ``"Speaker_A"``, ``"Speaker_B"`` and ``"Speaker_C"`` to floats, with NaN
        for any speaker ``identify`` returned no score for.

        If an input is missing or processing raises, returns
        ``(message, {})`` instead. For an exception, the full traceback is
        written to the server's stderr rather than shown in the UI.
    """
    enroll_files = {
        "Speaker_A": spk1_audio,
        "Speaker_B": spk2_audio,
        "Speaker_C": spk3_audio,
    }
    if any(path is None for path in enroll_files.values()):
        return "Missing one or more enrollment audio files", {}
    if query_audio is None:
        return "Missing query audio file", {}

    try:
        _reset_speaker_system()
        for speaker_id, audio_path in enroll_files.items():
            speaker_system.enroll_speaker(
                speaker_id=speaker_id,
                audio_input=audio_path,
                model_type="gmm",
                n_components=4,
                epochs=50,
            )
        best_speaker, _best_score, scores = speaker_system.identify(query_audio)
        # Report exactly the three enrolled labels; NaN marks a missing score.
        filtered_scores = {
            name: float(scores.get(name, float("nan"))) for name in enroll_files
        }
    except Exception as exc:  # UI boundary: never let an error escape to Gradio
        import traceback

        # Keep the full traceback in the server log. The user only sees a
        # one-line summary, so server paths and internals are not exposed.
        traceback.print_exc()
        return f"Error during identify: {exc}", {}
    return best_speaker, filtered_scores
# Intro text shown above the inputs (rendered as Markdown).
_INTRO_MARKDOWN = """
# 🧠 Multi-Mixture Speaker Identification (MMM)
Enroll three speakers (audio files), then identify a query file.
Uses the trained audio MMM to compute embeddings on-the-fly and store them internally.
"""

with gr.Blocks(title="MMM Speaker Identification Demo") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # The three enrollment clips sit side by side in one row.
    with gr.Row():
        spk1 = gr.Audio(label="Speaker A (enroll)", type="filepath")
        spk2 = gr.Audio(label="Speaker B (enroll)", type="filepath")
        spk3 = gr.Audio(label="Speaker C (enroll)", type="filepath")
    query = gr.Audio(label="Query Audio", type="filepath")

    run_btn = gr.Button("Identify Speaker")
    output_label = gr.Label(label="Predicted Speaker")
    output_scores = gr.JSON(label="Scores (log-likelihoods)")

    # The input order must match identify_speaker's parameters:
    # (spk1, spk2, spk3, query) -> (label, scores).
    run_btn.click(
        fn=identify_speaker,
        inputs=[spk1, spk2, spk3, query],
        outputs=[output_label, output_scores],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()