# NOTE: the Hugging Face Spaces page for this app displayed a "Runtime error" banner.
# app.py
# By, Chance Brownfield
from pathlib import Path

import gradio as gr
import numpy as np
import torch

from MMM import MMM
from Speaker_ID import Speaker_ID
# Path to the serialized MMM manager checkpoint on disk.
BASE_MODEL_PATH = "models/MMM/mmm.pt"  # your manager file
# Preferred base-model key inside the manager; load_manager falls back to
# the first registered key when this one is absent.
REQUESTED_BASE_MODEL_ID = "unknown"  # prefer this if present; will fall back to first model key
# Run on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_manager(path: str):
    """Load the MMM manager checkpoint and choose a base model id.

    Prefers ``REQUESTED_BASE_MODEL_ID`` when it exists in the manager's
    model table; otherwise falls back to the first registered model key.

    Args:
        path: Filesystem path to the serialized manager.

    Returns:
        A ``(manager, base_model_id)`` tuple.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        RuntimeError: If the loaded object is not a manager, or it
            contains no models at all.
    """
    checkpoint = Path(path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"MMM file not found: {checkpoint}")

    manager = MMM.load(str(checkpoint))
    if not hasattr(manager, "models"):
        raise RuntimeError("Loaded object is not an MMM manager (missing .models).")

    # Guard clause: the requested id wins whenever it is present.
    if REQUESTED_BASE_MODEL_ID in manager.models:
        return manager, REQUESTED_BASE_MODEL_ID

    model_ids = list(manager.models)
    if not model_ids:
        raise RuntimeError(f"No models found inside manager {path}")
    return manager, model_ids[0]
# Load the manager once at import time; fail fast if the checkpoint is missing.
mgr, BASE_MODEL_ID = load_manager(BASE_MODEL_PATH)
print(f"Loaded MMM manager. Using base model id: {BASE_MODEL_ID}")

# Shared Speaker_ID instance reused across all Gradio requests.
# NOTE(review): seq_len == sr == 1200 — confirm these match what the
# Speaker_ID front-end expects; an audio sample rate of 1200 Hz is
# unusually low for speech, so this may be intentional feature framing
# rather than a raw sample rate.
speaker_system = Speaker_ID(
    mmm_manager=mgr,
    base_model_id=BASE_MODEL_ID,
    device=DEVICE,
    seq_len=1200,
    sr=1200,
)
def identify_speaker(spk1_audio, spk2_audio, spk3_audio, query_audio):
    """
    Enroll three speakers (files provided as filepaths by Gradio) and identify the query audio.
    Returns (predicted_speaker, scores_dict) or (error_message, {}).
    """
    try:
        enrollments = {
            "Speaker_A": spk1_audio,
            "Speaker_B": spk2_audio,
            "Speaker_C": spk3_audio,
        }

        # Guard clauses: all four clips must be supplied before any work.
        if any(clip is None for clip in enrollments.values()):
            return "Missing one or more enrollment audio files", {}
        if query_audio is None:
            return "Missing query audio file", {}

        # Rebuild a manager that holds only the shared base model so each
        # request starts from a clean slate (no leftover per-speaker models).
        fresh_manager = MMM()
        fresh_manager.models[BASE_MODEL_ID] = mgr.models[BASE_MODEL_ID]
        speaker_system.mmm = fresh_manager
        speaker_system.base_model = fresh_manager.models[BASE_MODEL_ID].to(DEVICE)
        speaker_system.base_model.eval()

        # Fit one small GMM per enrollment clip.
        for name, clip in enrollments.items():
            speaker_system.enroll_speaker(
                speaker_id=name,
                audio_input=clip,
                model_type="gmm",
                n_components=4,
                epochs=50,
            )

        best_speaker, _best_score, scores = speaker_system.identify(query_audio)

        # Surface scores only for the three speakers enrolled above.
        filtered = {
            name: float(scores.get(name, float("nan"))) for name in enrollments
        }
        return best_speaker, filtered
    except Exception as exc:
        # Top-level boundary for the Gradio callback: report the failure in
        # the UI instead of crashing the request.
        import traceback

        return f"Error during identify: {exc}\n{traceback.format_exc()}", {}
with gr.Blocks(title="MMM Speaker Identification Demo") as demo:
    gr.Markdown(
        """
        # 🧠 Multi-Mixture Speaker Identification (MMM)
        Enroll three speakers (audio files), then identify a query file.
        Uses the trained audio MMM to compute embeddings on-the-fly and store them internally.
        """
    )

    # One upload widget per enrollment speaker, laid out side by side.
    with gr.Row():
        enroll_a = gr.Audio(label="Speaker A (enroll)", type="filepath")
        enroll_b = gr.Audio(label="Speaker B (enroll)", type="filepath")
        enroll_c = gr.Audio(label="Speaker C (enroll)", type="filepath")

    # NOTE(review): the mangled source made it ambiguous whether the query
    # widget sat inside the row above — placed below it here; confirm layout.
    query_clip = gr.Audio(label="Query Audio", type="filepath")
    identify_btn = gr.Button("Identify Speaker")
    predicted = gr.Label(label="Predicted Speaker")
    score_json = gr.JSON(label="Scores (log-likelihoods)")

    identify_btn.click(
        fn=identify_speaker,
        inputs=[enroll_a, enroll_b, enroll_c, query_clip],
        outputs=[predicted, score_json],
    )

if __name__ == "__main__":
    demo.launch()