"""
Vietnamese Speaker Profiling - Multi-Model Gradio Interface
Supports: Vietnamese Wav2Vec2 and PhoWhisper encoders
"""
import torch
import librosa
import numpy as np
import gradio as gr
from pathlib import Path
from safetensors.torch import load_file as load_safetensors

# Model configurations
MODELS_CONFIG = {
"Wav2Vec2 Vietnamese": {
"path": "model/vulehuubinh",
"encoder_name": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020",
"is_whisper": False,
"description": "Vietnamese Wav2Vec2 pretrained model - Fast inference"
},
"PhoWhisper": {
"path": "model/pho",
"encoder_name": "vinai/PhoWhisper-base",
"is_whisper": True,
"description": "Vietnamese Whisper model - Higher accuracy"
}
}
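
# Each model directory is expected to contain either a `model.safetensors`
# file or a `best_model.pt` checkpoint (see _load_single_model below).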

# Labels - IMPORTANT: Must match training order!
# Model was trained with Female=0, Male=1
GENDER_LABELS = {
0: "Female",
1: "Male"
}
DIALECT_LABELS = {
0: "North",
1: "Central",
2: "South"
}


class MultiModelProfiler:
"""Speaker Profiler supporting multiple encoder models."""

    def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.sampling_rate = 16000
self.max_duration = 5 # seconds for non-whisper models
self.models = {}
self.processors = {}
self.current_model = None
print(f"Using device: {self.device}")
# Pre-load all models
self._load_all_models()

    def _load_all_models(self):
"""Load all available models."""
for model_name, config in MODELS_CONFIG.items():
model_path = Path(config["path"])
if model_path.exists():
print(f"Loading {model_name}...")
self._load_single_model(model_name, config)
else:
print(f"Model not found: {model_path}")

    def _load_single_model(self, model_name: str, config: dict):
"""Load a specific model."""
try:
model_path = Path(config["path"])
is_whisper = config["is_whisper"]
encoder_name = config["encoder_name"]
# Load processor
if is_whisper:
from transformers import WhisperFeatureExtractor
processor = WhisperFeatureExtractor.from_pretrained(encoder_name)
else:
from transformers import Wav2Vec2FeatureExtractor
processor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_name)
# Load model - use MultiTaskSpeakerModel
from src.models import MultiTaskSpeakerModel
# Load checkpoint first to detect head_hidden_dim
checkpoint_path = model_path / "model.safetensors"
pt_path = model_path / "best_model.pt"
state_dict = None
if checkpoint_path.exists():
state_dict = load_safetensors(str(checkpoint_path))
elif pt_path.exists():
checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
# Auto-detect head_hidden_dim from checkpoint
head_hidden_dim = 256 # default
            if state_dict is not None and "gender_head.0.weight" in state_dict:
                # The head's first layer is Linear(hidden_size, head_hidden_dim),
                # so its weight matrix has shape [head_hidden_dim, hidden_size]
                head_hidden_dim = state_dict["gender_head.0.weight"].shape[0]
print(f"Detected head_hidden_dim: {head_hidden_dim}")
model = MultiTaskSpeakerModel(
model_name=encoder_name,
num_genders=2,
num_dialects=3,
dropout=0.1,
head_hidden_dim=head_hidden_dim,
freeze_encoder=True
)
            # Load checkpoint weights
            if state_dict is not None:
                model.load_state_dict(state_dict)
                print(f"Loaded checkpoint: {checkpoint_path if checkpoint_path.exists() else pt_path}")
            else:
                print(f"Warning: no checkpoint found in {model_path}; classification heads are randomly initialized")
model.to(self.device)
model.eval()
self.models[model_name] = model
self.processors[model_name] = processor
if self.current_model is None:
self.current_model = model_name
print(f"{model_name} loaded successfully")
except Exception as e:
print(f"Error loading {model_name}: {e}")
import traceback
traceback.print_exc()
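
    # Note: a model that fails to load is skipped entirely, so self.models (and
    # the Gradio dropdown built from get_available_models) only lists models
    # that loaded successfully.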

    def predict(self, audio_path: str, model_name: str):
"""Predict gender and dialect from audio."""
if model_name not in self.models:
available = list(self.models.keys())
if not available:
return "No models available", "No models available"
model_name = available[0]
try:
model = self.models[model_name]
processor = self.processors[model_name]
is_whisper = MODELS_CONFIG[model_name]["is_whisper"]
# Set max duration based on model type
if is_whisper:
max_duration = 30 # Whisper requires 30 seconds
else:
max_duration = self.max_duration
# Load audio using librosa
waveform, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
# Trim to max duration
max_samples = int(max_duration * self.sampling_rate)
if len(waveform) > max_samples:
waveform = waveform[:max_samples]
# Process based on model type
if is_whisper:
                # Whisper expects fixed 30-second inputs, so zero-pad shorter clips
                # (WhisperFeatureExtractor would also pad to 30 s by default, but
                # padding explicitly keeps the behavior obvious)
whisper_length = self.sampling_rate * 30
if len(waveform) < whisper_length:
waveform = np.pad(waveform, (0, whisper_length - len(waveform)))
inputs = processor(
waveform,
sampling_rate=self.sampling_rate,
return_tensors="pt"
)
input_tensor = inputs.input_features.to(self.device)
else:
# Wav2Vec2 uses raw waveform
inputs = processor(
waveform,
sampling_rate=self.sampling_rate,
return_tensors="pt",
padding=True
)
input_tensor = inputs.input_values.to(self.device)
# Inference
with torch.no_grad():
outputs = model(input_tensor)
gender_logits = outputs['gender_logits']
dialect_logits = outputs['dialect_logits']
gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]
gender_idx = int(np.argmax(gender_probs))
dialect_idx = int(np.argmax(dialect_probs))
gender_conf = float(gender_probs[gender_idx]) * 100
dialect_conf = float(dialect_probs[dialect_idx]) * 100
gender_result = f"{GENDER_LABELS[gender_idx]} ({gender_conf:.1f}%)"
dialect_result = f"{DIALECT_LABELS[dialect_idx]} ({dialect_conf:.1f}%)"
return gender_result, dialect_result
except Exception as e:
import traceback
traceback.print_exc()
return f"Error: {str(e)}", f"Error: {str(e)}"

    def get_available_models(self):
"""Get list of available models."""
return list(self.models.keys())
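
# Illustrative programmatic use without the Gradio UI; "sample.wav" is a
# hypothetical audio file (any format librosa can read):
#
#   profiler = MultiModelProfiler()
#   gender, dialect = profiler.predict("sample.wav", profiler.get_available_models()[0])
#   print(gender, dialect)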


def create_interface():
"""Create Gradio interface with model selection."""
profiler = MultiModelProfiler()
available_models = profiler.get_available_models()
if not available_models:
available_models = ["No models available"]

    def predict_wrapper(audio, model_name):
if audio is None:
return "Please upload audio", "Please upload audio"
return profiler.predict(audio, model_name)
# Create model info text
model_info = ""
for name, config in MODELS_CONFIG.items():
status = "[OK]" if name in profiler.models else "[X]"
model_info += f"{status} **{name}**: {config['description']}\n"
with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
gr.Markdown(
"""
# Vietnamese Speaker Profiling
Analyze Vietnamese speech to predict **Gender** and **Dialect Region**.
Supports multiple AI models - choose the one that works best for you!
"""
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Input")
audio_input = gr.Audio(
label="Upload or Record Audio",
type="filepath"
)
model_dropdown = gr.Dropdown(
choices=available_models,
value=available_models[0] if available_models else None,
label="Select Model",
info="Choose the AI model for analysis"
)
submit_btn = gr.Button("Analyze", variant="primary", size="lg")
gr.Markdown("### Available Models")
gr.Markdown(model_info)
with gr.Column(scale=1):
gr.Markdown("### Results")
gender_output = gr.Textbox(label="Gender", interactive=False)
dialect_output = gr.Textbox(label="Dialect Region", interactive=False)
gr.Markdown(
"""
### Dialect Regions
- **North**: Hanoi and surrounding areas
- **Central**: Hue, Da Nang, and Central Vietnam
- **South**: Ho Chi Minh City and Southern Vietnam
"""
)
submit_btn.click(
fn=predict_wrapper,
inputs=[audio_input, model_dropdown],
outputs=[gender_output, dialect_output]
)
gr.Markdown(
"""
---
*Vietnamese Speech Processing Research*
"""
)
return demo


if __name__ == "__main__":
demo = create_interface()
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)