"""
Vietnamese Speaker Profiling - Multi-Model Gradio Interface

Supports: Vietnamese Wav2Vec2 and PhoWhisper encoders
"""

import traceback
from pathlib import Path

import torch
import librosa
import numpy as np
import gradio as gr
from safetensors.torch import load_file as load_safetensors
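
# Model registry: "path" is a local fine-tuned checkpoint directory and
# "encoder_name" is the Hugging Face hub ID of the backbone encoder. Any
# entry whose "path" exists on disk is loaded at startup.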
MODELS_CONFIG = {
    "Wav2Vec2 Vietnamese": {
        "path": "model/vulehuubinh",
        "encoder_name": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020",
        "is_whisper": False,
        "description": "Vietnamese Wav2Vec2 pretrained model - Fast inference",
    },
    "PhoWhisper": {
        "path": "model/pho",
        "encoder_name": "vinai/PhoWhisper-base",
        "is_whisper": True,
        "description": "Vietnamese Whisper model - Higher accuracy",
    },
}

GENDER_LABELS = {
    0: "Female",
    1: "Male",
}
DIALECT_LABELS = {
    0: "North",
    1: "Central",
    2: "South",
}
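
# NOTE: these index -> label mappings are assumed to match the class order
# used when the checkpoints were trained; adjust them if your models were
# trained with a different label ordering.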


class MultiModelProfiler:
    """Speaker profiler supporting multiple encoder models."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sampling_rate = 16000
        self.max_duration = 5  # seconds kept for Wav2Vec2 input; Whisper uses 30
        self.models = {}
        self.processors = {}
        self.current_model = None

        print(f"Using device: {self.device}")

        self._load_all_models()

    def _load_all_models(self):
        """Load every model in MODELS_CONFIG whose checkpoint directory exists."""
        for model_name, config in MODELS_CONFIG.items():
            model_path = Path(config["path"])
            if model_path.exists():
                print(f"Loading {model_name}...")
                self._load_single_model(model_name, config)
            else:
                print(f"Model not found: {model_path}")

    def _load_single_model(self, model_name: str, config: dict):
        """Load one encoder checkpoint plus its matching feature extractor."""
        try:
            model_path = Path(config["path"])
            is_whisper = config["is_whisper"]
            encoder_name = config["encoder_name"]

            # Pick the feature extractor that matches the encoder family.
            if is_whisper:
                from transformers import WhisperFeatureExtractor
                processor = WhisperFeatureExtractor.from_pretrained(encoder_name)
            else:
                from transformers import Wav2Vec2FeatureExtractor
                processor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_name)

            from src.models import MultiTaskSpeakerModel

            # Prefer the safetensors checkpoint; fall back to a legacy .pt
            # file, which may wrap the weights under "model_state_dict".
            checkpoint_path = model_path / "model.safetensors"
            pt_path = model_path / "best_model.pt"
            state_dict = None

            if checkpoint_path.exists():
                state_dict = load_safetensors(str(checkpoint_path))
            elif pt_path.exists():
                checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
                if "model_state_dict" in checkpoint:
                    state_dict = checkpoint["model_state_dict"]
                else:
                    state_dict = checkpoint
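
            # SECURITY NOTE: the .pt fallback above uses weights_only=False,
            # which unpickles arbitrary objects; only load checkpoints from
            # trusted sources.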

            # Infer the head width from the checkpoint: the first Linear in
            # gender_head has a weight of shape (head_hidden_dim, encoder_dim).
            head_hidden_dim = 256
            if state_dict is not None and "gender_head.0.weight" in state_dict:
                head_hidden_dim = state_dict["gender_head.0.weight"].shape[0]
                print(f"Detected head_hidden_dim: {head_hidden_dim}")

            model = MultiTaskSpeakerModel(
                model_name=encoder_name,
                num_genders=2,
                num_dialects=3,
                dropout=0.1,
                head_hidden_dim=head_hidden_dim,
                freeze_encoder=True,
            )

            if state_dict is not None:
                model.load_state_dict(state_dict)
                print(f"Loaded checkpoint: {checkpoint_path if checkpoint_path.exists() else pt_path}")

            model.to(self.device)
            model.eval()

            self.models[model_name] = model
            self.processors[model_name] = processor

            # The first successfully loaded model becomes the default.
            if self.current_model is None:
                self.current_model = model_name

            print(f"{model_name} loaded successfully")

        except Exception as e:
            print(f"Error loading {model_name}: {e}")
            traceback.print_exc()
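
    # src.models.MultiTaskSpeakerModel is expected to take a batch of input
    # features and return a dict with "gender_logits" and "dialect_logits";
    # predict() below relies on that contract.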

    def predict(self, audio_path: str, model_name: str):
        """Predict gender and dialect from an audio file."""
        if model_name not in self.models:
            # Fall back to the first loaded model if the requested one is missing.
            available = list(self.models.keys())
            if not available:
                return "No models available", "No models available"
            model_name = available[0]

        try:
            model = self.models[model_name]
            processor = self.processors[model_name]
            is_whisper = MODELS_CONFIG[model_name]["is_whisper"]

            # Whisper consumes fixed 30-second windows; Wav2Vec2 keeps the
            # shorter limit set in __init__.
            max_duration = 30 if is_whisper else self.max_duration

            # Load as mono at the model's sampling rate, truncating long clips.
            waveform, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)

            max_samples = int(max_duration * self.sampling_rate)
            if len(waveform) > max_samples:
                waveform = waveform[:max_samples]

            if is_whisper:
                # WhisperFeatureExtractor expects exactly 30 s of audio, so
                # pad short clips with silence before computing log-mels.
                whisper_length = self.sampling_rate * 30
                if len(waveform) < whisper_length:
                    waveform = np.pad(waveform, (0, whisper_length - len(waveform)))

                inputs = processor(
                    waveform,
                    sampling_rate=self.sampling_rate,
                    return_tensors="pt",
                )
                input_tensor = inputs.input_features.to(self.device)
            else:
                inputs = processor(
                    waveform,
                    sampling_rate=self.sampling_rate,
                    return_tensors="pt",
                    padding=True,
                )
                input_tensor = inputs.input_values.to(self.device)

            # Single forward pass; each head returns per-class logits.
            with torch.no_grad():
                outputs = model(input_tensor)
                gender_logits = outputs["gender_logits"]
                dialect_logits = outputs["dialect_logits"]

            # Softmax over the class dimension gives per-class probabilities.
            gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
            dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]

            gender_idx = int(np.argmax(gender_probs))
            dialect_idx = int(np.argmax(dialect_probs))

            gender_conf = float(gender_probs[gender_idx]) * 100
            dialect_conf = float(dialect_probs[dialect_idx]) * 100

            gender_result = f"{GENDER_LABELS[gender_idx]} ({gender_conf:.1f}%)"
            dialect_result = f"{DIALECT_LABELS[dialect_idx]} ({dialect_conf:.1f}%)"

            return gender_result, dialect_result

        except Exception as e:
            traceback.print_exc()
            return f"Error: {str(e)}", f"Error: {str(e)}"

    def get_available_models(self):
        """Return the names of successfully loaded models."""
        return list(self.models.keys())
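

# Minimal programmatic usage (a sketch; "sample.wav" is a hypothetical path):
#   profiler = MultiModelProfiler()
#   gender, dialect = profiler.predict("sample.wav", "PhoWhisper")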


def create_interface():
    """Create the Gradio interface with model selection."""
    profiler = MultiModelProfiler()
    available_models = profiler.get_available_models()

    if not available_models:
        available_models = ["No models available"]

    def predict_wrapper(audio, model_name):
        if audio is None:
            return "Please upload audio", "Please upload audio"
        return profiler.predict(audio, model_name)

    # One status line per configured model, shown in the UI; blank lines
    # keep each entry on its own line when rendered as Markdown.
    model_info = ""
    for name, config in MODELS_CONFIG.items():
        status = "[OK]" if name in profiler.models else "[X]"
        model_info += f"{status} **{name}**: {config['description']}\n\n"

    with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
        gr.Markdown(
            """
            # Vietnamese Speaker Profiling

            Analyze Vietnamese speech to predict **Gender** and **Dialect Region**.

            Supports multiple AI models - choose the one that works best for you!
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Input")
                audio_input = gr.Audio(
                    label="Upload or Record Audio",
                    type="filepath",
                )

                model_dropdown = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="Select Model",
                    info="Choose the AI model for analysis",
                )

                submit_btn = gr.Button("Analyze", variant="primary", size="lg")

                gr.Markdown("### Available Models")
                gr.Markdown(model_info)

            with gr.Column(scale=1):
                gr.Markdown("### Results")
                gender_output = gr.Textbox(label="Gender", interactive=False)
                dialect_output = gr.Textbox(label="Dialect Region", interactive=False)

                gr.Markdown(
                    """
                    ### Dialect Regions
                    - **North**: Hanoi and surrounding areas
                    - **Central**: Hue, Da Nang, and Central Vietnam
                    - **South**: Ho Chi Minh City and Southern Vietnam
                    """
                )

        # Wire the button to the prediction function.
        submit_btn.click(
            fn=predict_wrapper,
            inputs=[audio_input, model_dropdown],
            outputs=[gender_output, dialect_output],
        )

        gr.Markdown(
            """
            ---
            *Vietnamese Speech Processing Research*
            """
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    # 0.0.0.0 binds all interfaces so the app is reachable beyond localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)