# StyleTTS2_vi / check_style_encoder.py
# Uploaded by hieuducle ("Upload folder using huggingface_hub"), commit 84f3a60 (verified).
# check_style_encoder.py
import torch
import torchaudio
import librosa
import numpy as np
from scipy.spatial.distance import cosine
import yaml
import os
from models import *
from utils import *
# Setup: compute device shared by load_model() and extract_style().
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model(checkpoint_path, config_path):
    """Load the model the same way as the inference code.

    Reads the YAML config, builds the text aligner (ASR), the pitch extractor
    (F0) and PL-BERT that it references, assembles the full model with
    ``build_model`` and loads every module's weights from the checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'net'`` entry maps each
            module name of the built model to that module's state dict.
        config_path: Path to the YAML config. It must provide ``ASR_path``,
            ``ASR_config``, ``F0_path``, ``PLBERT_dir`` and ``model_params``.

    Returns:
        ``(model, config)``. ``model`` is the keyed collection of modules
        returned by ``build_model``, each switched to eval mode and moved to
        the module-level ``device``. ``config`` is the parsed YAML.

    Raises:
        KeyError: if the checkpoint lacks ``'net'`` or lacks an entry for one
            of the model's modules.
        RuntimeError: from ``load_state_dict`` when parameter names or shapes
            do not match (loading is strict).
    """
    print(f"Loading config from: {config_path}")
    # Context manager so the config file handle is closed right after parsing.
    with open(config_path, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    print("Building model...")
    text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
    pitch_extractor = load_F0_models(config['F0_path'])
    # PL-BERT loader is imported lazily, only once the config has been read.
    from Utils.PLBERT.util import load_plbert
    plbert = load_plbert(config['PLBERT_dir'])
    model = build_model(recursive_munch(config['model_params']),
                        text_aligner, pitch_extractor, plbert)

    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    params = torch.load(checkpoint_path, map_location='cpu')['net']
    prefix = "module."
    for key in model:
        # Strip a leading "module." from parameter names if present (the
        # prefix is typically added when a checkpoint is saved from a
        # DataParallel/DDP-wrapped module) so names match the bare module.
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in params[key].items()
        }
        model[key].load_state_dict(state_dict, strict=True)
        model[key].eval().to(device)
        print(f" ✓ Loaded {key}")
    return model, config
def preprocess_audio(audio_path):
    """Turn an audio file into a normalized log-mel tensor, as the inference code does.

    Steps:
    1. Load the audio resampled to 24 kHz.
    2. Trim leading/trailing silence (``top_db=30``).
    3. Compute an 80-band mel spectrogram (n_fft=2048, win=1200, hop=300).
    4. Take ``log(1e-5 + mel)``, add a leading dimension, and normalize
       with mean -4 and std 4.

    The result is presumably shaped (1, 80, frames). TODO: confirm.
    """
    mel_mean, mel_std = -4, 4
    waveform, _sample_rate = librosa.load(audio_path, sr=24000)
    trimmed, _ = librosa.effects.trim(waveform, top_db=30)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        n_mels=80, n_fft=2048, win_length=1200, hop_length=300
    )
    wave_tensor = torch.from_numpy(trimmed).float()
    log_mel = torch.log(1e-5 + mel_transform(wave_tensor).unsqueeze(0))
    return (log_mel - mel_mean) / mel_std
def extract_style(audio_path, model):
    """Compute both reference embeddings for one audio file.

    Args:
        audio_path: Path to the reference audio.
        model: Keyed collection of modules (as returned by ``load_model``)
            providing ``'style_encoder'`` and ``'predictor_encoder'``.

    Returns:
        ``(ref_s, ref_p)`` as CPU NumPy arrays:
        - ``ref_s``: the style-encoder output (timbre).
        - ``ref_p``: the predictor-encoder output (prosody).
    """
    mel = preprocess_audio(audio_path).to(device)
    encoders = (model['style_encoder'], model['predictor_encoder'])
    with torch.no_grad():
        # Both encoders receive the mel with an extra axis inserted at dim 1.
        outputs = [encoder(mel.unsqueeze(1)) for encoder in encoders]
    return tuple(out.cpu().numpy() for out in outputs)
def compute_similarity_matrix(styles_list1, styles_list2):
    """Return cosine similarities for every pair across two lists of vectors.

    Each array is flattened, and similarity is ``1 - scipy cosine distance``.
    Despite the name, the result is a flat list in row-major order: all pairs
    for ``styles_list1[0]``, then all pairs for ``styles_list1[1]``, and so on.
    Its length is ``len(styles_list1) * len(styles_list2)``; it is empty if
    either list is empty.
    """
    return [
        1 - cosine(left.flatten(), right.flatten())
        for left in styles_list1
        for right in styles_list2
    ]
def _print_section(title):
    """Print ``title`` framed by two 60-character '=' rules, after a blank line."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(title)
    print(rule)


def _extract_speaker_styles(audio_paths, model):
    """Extract style/predictor embeddings for every audio of one speaker.

    Prints a progress line per file. A file that fails is reported and
    skipped (best-effort), so the returned lists may be shorter than
    ``audio_paths``.

    Returns:
        ``(style_vectors, predictor_vectors)``: two parallel lists of the
        arrays returned by ``extract_style``.
    """
    style_vectors = []
    predictor_vectors = []
    for i, audio in enumerate(audio_paths):
        try:
            ref_s, ref_p = extract_style(audio, model)
        except Exception as e:  # report and continue with the remaining files
            print(f" ✗ Error: {e}")
            continue
        style_vectors.append(ref_s)
        predictor_vectors.append(ref_p)
        print(f" ✓ Audio {i+1}: {os.path.basename(audio)}")
        print(f" - Style encoder shape: {ref_s.shape}")
        print(f" - Predictor encoder shape: {ref_p.shape}")
    return style_vectors, predictor_vectors


def _print_within_speaker(speaker_label, style_vectors, target_line):
    """Print cosine similarity between consecutive style vectors of one speaker.

    Only neighbouring pairs are compared (audio i vs audio i+1), not all pairs.
    Status: ✓ above 0.90, ⚠️ above 0.80, ✗ otherwise.
    """
    if len(style_vectors) < 2:
        print(f"\n⚠️ Need 2+ audios from {speaker_label} to check within-speaker similarity")
        return
    print(f"\n📊 Within-speaker similarity ({speaker_label}):")
    print(target_line)
    for i, (cur, nxt) in enumerate(zip(style_vectors, style_vectors[1:])):
        sim = 1 - cosine(cur.flatten(), nxt.flatten())
        status = "✓" if sim > 0.90 else "⚠️" if sim > 0.80 else "✗"
        print(f" {status} Audio{i+1} vs Audio{i+2}: {sim:.4f}")


def _print_style_diagnosis(avg_style_sim, style_dim):
    """Print the style-encoder (timbre) verdict and suggested next steps.

    The verdict is based on the average cross-speaker cosine similarity:
    - above 0.85: collapsed;
    - above 0.75: weak discrimination;
    - otherwise: acceptable.

    A NaN average is reported as inconclusive instead of falling through to
    "OK". NaN means no speaker pairs were measured, or a similarity was
    undefined, e.g. for an all-zero vector.
    """
    # NOTE(review): the per-pair target printed by main() is < 0.70, but the
    # "OK" verdict here applies up to 0.75. Confirm which threshold is intended.
    if np.isnan(avg_style_sim):
        _print_section("❓ INCONCLUSIVE: No valid cross-speaker style similarity")
        print("\n🔴 Problem:")
        print(" Need at least one successfully extracted audio per speaker")
        print(" (and non-degenerate style vectors) to compare speakers.")
    elif avg_style_sim > 0.85:
        _print_section("❌ CRITICAL ISSUE: Style Encoder COLLAPSED!")
        print("\n🔴 Problem:")
        print(" Style encoder similarity = {:.4f} (TOO HIGH!)".format(avg_style_sim))
        print(" → Model học 'average/generic voice' thay vì specific timbre")
        print(" → Đây là lý do màu giọng không giống!")
        print("\n💡 Solutions:")
        print(" 1. RETRAIN with:")
        print(" - style_dim: 256 (hoặc 512) - hiện tại: {}".format(style_dim))
        print(" - lambda_sty: 5.0")
        print(" - diff_epoch: 20")
        print(" - joint_epoch: 40")
        print("\n 2. Hoặc Fine-tune chỉ style_encoder với contrastive loss")
        print(" (freeze tất cả modules khác)")
    elif avg_style_sim > 0.75:
        _print_section("⚠️ WARNING: Style Encoder có vấn đề!")
        print("\n🟡 Problem:")
        print(" Style encoder similarity = {:.4f} (HIGH)".format(avg_style_sim))
        print(" → Weak speaker discrimination")
        print("\n💡 Quick fixes to try:")
        print(" 1. Tăng lambda_sty: 5.0 và train thêm 10-20 epochs")
        print(" 2. Use multi-reference (3-5 clips) và average styles")
        print(" 3. Reference audio dài hơn (8-12s)")
    else:
        _print_section("✅ Style Encoder OK!")
        print("\n🟢 Style encoder có thể phân biệt speakers")
        print(" Cross-speaker similarity = {:.4f} (ACCEPTABLE)".format(avg_style_sim))
        print("\n💡 Nếu vẫn clone không giống, check:")
        print(" 1. Reference audio trong inference:")
        print(" - Duration: 5-10s (càng dài càng tốt)")
        print(" - Quality: clean, no noise")
        print(" - Representative: có nhiều đặc trưng của speaker")
        print("\n 2. Diffusion trong inference:")
        print(" - Thử giảm num_steps từ 5 → 3")
        print(" - Hoặc tăng weight của ref_style:")
        print(" s = 0.3 * s_pred + 0.7 * ref_style (thay vì 0.7 + 0.3)")
        print("\n 3. Multi-reference averaging:")
        print(" - Dùng 3-5 reference clips và average styles")


def main():
    """Check whether the style encoder separates two speakers' timbre.

    Steps:
    1. Load the model.
    2. Extract style-encoder (timbre) and predictor-encoder (prosody)
       embeddings for each reference audio of two speakers.
    3. Print within-speaker and cross-speaker cosine similarities.
    4. Print a diagnosis based on the average cross-speaker style-encoder
       similarity.
    """
    # ==================================================
    # CONFIGURATION - EDIT THESE PATHS FOR YOUR SETUP
    # ==================================================
    checkpoint_path = "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/styletts2-ver2-model-bestmodel/best_model_ver2.pth"
    config_path = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/Configs/config_ft.yml"
    # Test audios - ADD MORE AUDIOS PER SPEAKER (2+ enables the within-speaker check)!
    speaker1_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav",
        # Add other audios of the same speaker if available
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_2.wav",
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_3.wav",
    ]
    speaker2_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav",
        # Add other audios of speaker 2 if available
        # "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/nu_thoi_su_2.wav",
    ]

    # ==================================================
    # LOAD MODEL
    # ==================================================
    _print_section("LOADING MODEL")
    print(f"Device: {device}")
    try:
        model, config = load_model(checkpoint_path, config_path)
        print("\n✓ Model loaded successfully!")
        style_dim = config['model_params']['style_dim']
        print(f" Style dim: {style_dim}")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and stop.
        print(f"\n✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # ==================================================
    # EXTRACT STYLES
    # ==================================================
    _print_section("EXTRACTING STYLES")
    print(f"\nSpeaker 1 ({len(speaker1_audios)} audios):")
    spk1_style_encoder, spk1_predictor_encoder = _extract_speaker_styles(speaker1_audios, model)
    print(f"\nSpeaker 2 ({len(speaker2_audios)} audios):")
    spk2_style_encoder, spk2_predictor_encoder = _extract_speaker_styles(speaker2_audios, model)

    # ==================================================
    # ANALYZE STYLE ENCODER (TIMBRE)
    # ==================================================
    _print_section("STYLE ENCODER ANALYSIS (TIMBRE/MÀU GIỌNG)")
    _print_within_speaker("Speaker 1", spk1_style_encoder,
                          " Target: > 0.90 (same speaker should be very similar)")
    _print_within_speaker("Speaker 2", spk2_style_encoder, " Target: > 0.90")

    # Cross-speaker similarity (MOST IMPORTANT!)
    print("\n📊 Cross-speaker similarity (Speaker 1 vs Speaker 2):")
    print(" Target: < 0.70 (different speakers should be dissimilar)")
    style_similarities = compute_similarity_matrix(spk1_style_encoder,
                                                   spk2_style_encoder)
    # The result is a flat row-major list, so (i, j) is recovered with divmod
    # instead of recomputing every cosine.
    n_spk2_style = len(spk2_style_encoder)
    for idx, sim in enumerate(style_similarities):
        i, j = divmod(idx, n_spk2_style)
        status = "✓" if sim < 0.70 else "⚠️" if sim < 0.80 else "✗"
        print(f" {status} Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")
    # Use NaN (not 0) when nothing was measured, so the diagnosis cannot
    # mistake "no data" for "speakers are well separated".
    avg_style_sim = float(np.mean(style_similarities)) if style_similarities else float("nan")
    print(f"\n 📈 Average cross-speaker similarity: {avg_style_sim:.4f}")

    # ==================================================
    # ANALYZE PREDICTOR ENCODER (PROSODY)
    # ==================================================
    _print_section("PREDICTOR ENCODER ANALYSIS (PROSODY/NGỮ ĐIỆU)")
    print("\n📊 Cross-speaker similarity (Predictor Encoder):")
    print(" Note: Predictor encoder cho prosody, ít ảnh hưởng timbre")
    pred_similarities = compute_similarity_matrix(spk1_predictor_encoder,
                                                  spk2_predictor_encoder)
    n_spk2_pred = len(spk2_predictor_encoder)
    for idx, sim in enumerate(pred_similarities):
        i, j = divmod(idx, n_spk2_pred)
        print(f" - Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")
    avg_pred_sim = float(np.mean(pred_similarities)) if pred_similarities else float("nan")
    print(f"\n 📈 Average: {avg_pred_sim:.4f}")

    # ==================================================
    # DIAGNOSIS
    # ==================================================
    _print_section("🔍 DIAGNOSIS")
    print("\nModel info:")
    print(f" - Style dim: {style_dim}")
    print(f" - Checkpoint: {os.path.basename(checkpoint_path)}")
    print("\n📊 Results:")
    print(f" - Style Encoder cross-speaker sim: {avg_style_sim:.4f}")
    print(f" - Predictor Encoder cross-speaker sim: {avg_pred_sim:.4f}")
    _print_style_diagnosis(avg_style_sim, style_dim)

    # Additional info
    _print_section("📝 Additional Info:")
    print("\nTimbre characteristics được encode trong Style Encoder:")
    print(" - Formant frequencies (F1, F2, F3)")
    print(" - Harmonic structure")
    print(" - Breathiness/hoarseness")
    print(" - Vocal tract characteristics")
    print(" - Nasality")
    print(" → Cần style_dim >= 256 để encode đầy đủ")
    print(f" → Hiện tại: style_dim = {style_dim}")


if __name__ == "__main__":
    main()