""" Vietnamese Speaker Profiling - Multi-Model Gradio Interface Supports: Vietnamese Wav2Vec2 and PhoWhisper encoders """ import os import torch import librosa import numpy as np import gradio as gr from pathlib import Path from safetensors.torch import load_file as load_safetensors # Model configurations MODELS_CONFIG = { "Wav2Vec2 Vietnamese": { "path": "model/vulehuubinh", "encoder_name": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020", "is_whisper": False, "description": "Vietnamese Wav2Vec2 pretrained model - Fast inference" }, "PhoWhisper": { "path": "model/pho", "encoder_name": "vinai/PhoWhisper-base", "is_whisper": True, "description": "Vietnamese Whisper model - Higher accuracy" } } # Labels - IMPORTANT: Must match training order! # Model was trained with Female=0, Male=1 GENDER_LABELS = { 0: "Female", 1: "Male" } DIALECT_LABELS = { 0: "North", 1: "Central", 2: "South" } class MultiModelProfiler: """Speaker Profiler supporting multiple encoder models.""" def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.sampling_rate = 16000 self.max_duration = 5 # seconds for non-whisper models self.models = {} self.processors = {} self.current_model = None print(f"Using device: {self.device}") # Pre-load all models self._load_all_models() def _load_all_models(self): """Load all available models.""" for model_name, config in MODELS_CONFIG.items(): model_path = Path(config["path"]) if model_path.exists(): print(f"Loading {model_name}...") self._load_single_model(model_name, config) else: print(f"Model not found: {model_path}") def _load_single_model(self, model_name: str, config: dict): """Load a specific model.""" try: model_path = Path(config["path"]) is_whisper = config["is_whisper"] encoder_name = config["encoder_name"] # Load processor if is_whisper: from transformers import WhisperFeatureExtractor processor = WhisperFeatureExtractor.from_pretrained(encoder_name) else: from transformers import Wav2Vec2FeatureExtractor processor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_name) # Load model - use MultiTaskSpeakerModel from src.models import MultiTaskSpeakerModel # Load checkpoint first to detect head_hidden_dim checkpoint_path = model_path / "model.safetensors" pt_path = model_path / "best_model.pt" state_dict = None if checkpoint_path.exists(): state_dict = load_safetensors(str(checkpoint_path)) elif pt_path.exists(): checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False) if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint # Auto-detect head_hidden_dim from checkpoint head_hidden_dim = 256 # default if state_dict is not None and "gender_head.0.weight" in state_dict: # gender_head.0.weight has shape [head_hidden_dim, hidden_size] head_hidden_dim = state_dict["gender_head.0.weight"].shape[0] print(f"Detected head_hidden_dim: {head_hidden_dim}") model = MultiTaskSpeakerModel( model_name=encoder_name, num_genders=2, num_dialects=3, dropout=0.1, head_hidden_dim=head_hidden_dim, freeze_encoder=True ) # Load checkpoint weights if state_dict is not None: model.load_state_dict(state_dict) print(f"Loaded checkpoint: {checkpoint_path if checkpoint_path.exists() else pt_path}") model.to(self.device) model.eval() self.models[model_name] = model self.processors[model_name] = processor if self.current_model is None: self.current_model = model_name print(f"{model_name} loaded successfully") except Exception as e: print(f"Error loading {model_name}: {e}") import traceback 
    def predict(self, audio_path: str, model_name: str):
        """Predict gender and dialect from audio."""
        if model_name not in self.models:
            available = list(self.models.keys())
            if not available:
                return "No models available", "No models available"
            model_name = available[0]

        try:
            model = self.models[model_name]
            processor = self.processors[model_name]
            is_whisper = MODELS_CONFIG[model_name]["is_whisper"]

            # Set max duration based on model type
            if is_whisper:
                max_duration = 30  # Whisper requires 30 seconds
            else:
                max_duration = self.max_duration

            # Load audio using librosa
            waveform, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=True)

            # Trim to max duration
            max_samples = int(max_duration * self.sampling_rate)
            if len(waveform) > max_samples:
                waveform = waveform[:max_samples]

            # Process based on model type
            if is_whisper:
                # Whisper requires exactly 30 seconds - pad if needed
                whisper_length = self.sampling_rate * 30
                if len(waveform) < whisper_length:
                    waveform = np.pad(waveform, (0, whisper_length - len(waveform)))
                inputs = processor(
                    waveform,
                    sampling_rate=self.sampling_rate,
                    return_tensors="pt"
                )
                input_tensor = inputs.input_features.to(self.device)
            else:
                # Wav2Vec2 uses raw waveform
                inputs = processor(
                    waveform,
                    sampling_rate=self.sampling_rate,
                    return_tensors="pt",
                    padding=True
                )
                input_tensor = inputs.input_values.to(self.device)

            # Inference
            with torch.no_grad():
                outputs = model(input_tensor)
                gender_logits = outputs['gender_logits']
                dialect_logits = outputs['dialect_logits']

            gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
            dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]

            gender_idx = int(np.argmax(gender_probs))
            dialect_idx = int(np.argmax(dialect_probs))
            gender_conf = float(gender_probs[gender_idx]) * 100
            dialect_conf = float(dialect_probs[dialect_idx]) * 100

            gender_result = f"{GENDER_LABELS[gender_idx]} ({gender_conf:.1f}%)"
            dialect_result = f"{DIALECT_LABELS[dialect_idx]} ({dialect_conf:.1f}%)"
            return gender_result, dialect_result
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}", f"Error: {str(e)}"

    def get_available_models(self):
        """Get list of available models."""
        return list(self.models.keys())


def create_interface():
    """Create Gradio interface with model selection."""
    profiler = MultiModelProfiler()
    available_models = profiler.get_available_models()
    if not available_models:
        available_models = ["No models available"]

    def predict_wrapper(audio, model_name):
        if audio is None:
            return "Please upload audio", "Please upload audio"
        return profiler.predict(audio, model_name)

    # Create model info text
    model_info = ""
    for name, config in MODELS_CONFIG.items():
        status = "[OK]" if name in profiler.models else "[X]"
        model_info += f"{status} **{name}**: {config['description']}\n"

    with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
        gr.Markdown(
            """
            # Vietnamese Speaker Profiling

            Analyze Vietnamese speech to predict **Gender** and **Dialect Region**.
            Supports multiple AI models - choose the one that works best for you!
            """
        )
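
        # Layout: input column (audio upload + model picker) on the left,
        # prediction results and the dialect legend on the right.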
""" ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Input") audio_input = gr.Audio( label="Upload or Record Audio", type="filepath" ) model_dropdown = gr.Dropdown( choices=available_models, value=available_models[0] if available_models else None, label="Select Model", info="Choose the AI model for analysis" ) submit_btn = gr.Button("Analyze", variant="primary", size="lg") gr.Markdown("### Available Models") gr.Markdown(model_info) with gr.Column(scale=1): gr.Markdown("### Results") gender_output = gr.Textbox(label="Gender", interactive=False) dialect_output = gr.Textbox(label="Dialect Region", interactive=False) gr.Markdown( """ ### Dialect Regions - **North**: Hanoi and surrounding areas - **Central**: Hue, Da Nang, and Central Vietnam - **South**: Ho Chi Minh City and Southern Vietnam """ ) submit_btn.click( fn=predict_wrapper, inputs=[audio_input, model_dropdown], outputs=[gender_output, dialect_output] ) gr.Markdown( """ --- *Vietnamese Speech Processing Research* """ ) return demo if __name__ == "__main__": demo = create_interface() demo.launch(server_name="0.0.0.0", server_port=7860, share=False)