# check_style_encoder.py
"""Sanity-check a trained StyleTTS2 checkpoint's style/predictor encoders.

Extracts style vectors from reference audios of two speakers, reports
within-speaker and cross-speaker cosine similarity, and prints a diagnosis
of whether the style encoder discriminates speakers (timbre) at all.
"""
import torch
import torchaudio
import librosa
import numpy as np
from scipy.spatial.distance import cosine
import yaml
import os
from models import *
from utils import *

# Setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model(checkpoint_path, config_path):
    """Load the StyleTTS2 model the same way the inference code does.

    Returns (model, config) where model is the dict of sub-modules built by
    build_model, each with checkpoint weights loaded, in eval mode on `device`.
    """
    print(f"Loading config from: {config_path}")
    # fix: context manager so the config file handle is actually closed
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Build sub-modules only after the config is available
    print("Building model...")
    text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
    pitch_extractor = load_F0_models(config['F0_path'])

    from Utils.PLBERT.util import load_plbert
    plbert = load_plbert(config['PLBERT_dir'])

    model = build_model(recursive_munch(config['model_params']),
                        text_aligner, pitch_extractor, plbert)

    print(f"Loading checkpoint from: {checkpoint_path}")
    params = torch.load(checkpoint_path, map_location='cpu')['net']

    for key in model:
        state_dict = params[key]
        # Remove the DataParallel "module." prefix if it exists
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_state_dict[k[len("module."):]] = v
            else:
                new_state_dict[k] = v
        model[key].load_state_dict(new_state_dict, strict=True)
        model[key].eval().to(device)
        print(f" ✓ Loaded {key}")

    return model, config


def preprocess_audio(audio_path):
    """Load + trim an audio file; return its normalized log-mel, shape (1, 80, T)."""
    wave, sr = librosa.load(audio_path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)

    to_mel = torchaudio.transforms.MelSpectrogram(
        n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
    mel = to_mel(torch.from_numpy(audio).float())
    # Same normalization as StyleTTS2 inference: (log(1e-5 + mel) - mean) / std
    # with mean = -4, std = 4
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
    return mel


def extract_style(audio_path, model):
    """Extract (style, prosody) vectors for one audio file as numpy arrays."""
    mel = preprocess_audio(audio_path).to(device)
    with torch.no_grad():
        # Extract from both encoders
        ref_s = model['style_encoder'](mel.unsqueeze(1))
        ref_p = model['predictor_encoder'](mel.unsqueeze(1))
    return ref_s.cpu().numpy(), ref_p.cpu().numpy()


def compute_similarity_matrix(styles_list1, styles_list2):
    """Return all pairwise cosine similarities between the two style lists."""
    return [1 - cosine(s1.flatten(), s2.flatten())
            for s1 in styles_list1
            for s2 in styles_list2]


def _extract_speaker_styles(audios, model, speaker_label):
    """Run extract_style over every audio of one speaker; failures are logged and skipped."""
    print(f"\n{speaker_label} ({len(audios)} audios):")
    style_vecs = []
    prosody_vecs = []
    for i, audio in enumerate(audios):
        try:
            ref_s, ref_p = extract_style(audio, model)
            style_vecs.append(ref_s)
            prosody_vecs.append(ref_p)
            print(f" ✓ Audio {i+1}: {os.path.basename(audio)}")
            print(f" - Style encoder shape: {ref_s.shape}")
            print(f" - Predictor encoder shape: {ref_p.shape}")
        except Exception as e:
            print(f" ✗ Error: {e}")
    return style_vecs, prosody_vecs


def _report_within_speaker(styles, speaker_name, target_note):
    """Print consecutive-pair cosine similarity for one speaker's style vectors."""
    if len(styles) > 1:
        print(f"\n📊 Within-speaker similarity ({speaker_name}):")
        print(f" Target: {target_note}")
        for i in range(len(styles) - 1):
            sim = 1 - cosine(styles[i].flatten(), styles[i + 1].flatten())
            status = "✓" if sim > 0.90 else "⚠️" if sim > 0.80 else "✗"
            print(f" {status} Audio{i+1} vs Audio{i+2}: {sim:.4f}")
    else:
        print(f"\n⚠️ Need 2+ audios from {speaker_name} to check within-speaker similarity")


def main():
    # ==================================================
    # CONFIGURATION - EDIT THESE PATHS FOR YOUR SETUP
    # ==================================================
    checkpoint_path = "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/styletts2-ver2-model-bestmodel/best_model_ver2.pth"
    config_path = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/Configs/config_ft.yml"

    # Test audios - add MORE audios per speaker for a stronger check!
    speaker1_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav",
        # Add more audios of the same speaker if available:
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_2.wav",
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_3.wav",
    ]
    speaker2_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav",
        # Add more audios of speaker 2 if available:
        # "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/nu_thoi_su_2.wav",
    ]

    # ==================================================
    # LOAD MODEL
    # ==================================================
    print(f"\n{'='*60}")
    print("LOADING MODEL")
    print(f"{'='*60}")
    print(f"Device: {device}")

    try:
        model, config = load_model(checkpoint_path, config_path)
        print(f"\n✓ Model loaded successfully!")
        print(f" Style dim: {config['model_params']['style_dim']}")
    except Exception as e:
        print(f"\n✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # ==================================================
    # EXTRACT STYLES
    # ==================================================
    print(f"\n{'='*60}")
    print("EXTRACTING STYLES")
    print(f"{'='*60}")

    spk1_style_encoder, spk1_predictor_encoder = _extract_speaker_styles(
        speaker1_audios, model, "Speaker 1")
    spk2_style_encoder, spk2_predictor_encoder = _extract_speaker_styles(
        speaker2_audios, model, "Speaker 2")

    # ==================================================
    # ANALYZE STYLE ENCODER (TIMBRE)
    # ==================================================
    print(f"\n{'='*60}")
    print("STYLE ENCODER ANALYSIS (TIMBRE/MÀU GIỌNG)")
    print(f"{'='*60}")

    _report_within_speaker(spk1_style_encoder, "Speaker 1",
                           "> 0.90 (same speaker should be very similar)")
    _report_within_speaker(spk2_style_encoder, "Speaker 2", "> 0.90")

    # Cross-speaker similarity (the MOST IMPORTANT check!)
    print("\n📊 Cross-speaker similarity (Speaker 1 vs Speaker 2):")
    print(" Target: < 0.70 (different speakers should be dissimilar)")
    style_similarities = compute_similarity_matrix(spk1_style_encoder, spk2_style_encoder)
    for i, s1 in enumerate(spk1_style_encoder):
        for j, s2 in enumerate(spk2_style_encoder):
            sim = 1 - cosine(s1.flatten(), s2.flatten())
            status = "✓" if sim < 0.70 else "⚠️" if sim < 0.80 else "✗"
            print(f" {status} Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")

    avg_style_sim = np.mean(style_similarities) if style_similarities else 0
    print(f"\n 📈 Average cross-speaker similarity: {avg_style_sim:.4f}")

    # ==================================================
    # ANALYZE PREDICTOR ENCODER (PROSODY)
    # ==================================================
    print(f"\n{'='*60}")
    print("PREDICTOR ENCODER ANALYSIS (PROSODY/NGỮ ĐIỆU)")
    print(f"{'='*60}")

    print("\n📊 Cross-speaker similarity (Predictor Encoder):")
    print(" Note: Predictor encoder cho prosody, ít ảnh hưởng timbre")
    pred_similarities = compute_similarity_matrix(spk1_predictor_encoder, spk2_predictor_encoder)
    for i, s1 in enumerate(spk1_predictor_encoder):
        for j, s2 in enumerate(spk2_predictor_encoder):
            sim = 1 - cosine(s1.flatten(), s2.flatten())
            print(f" - Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")

    avg_pred_sim = np.mean(pred_similarities) if pred_similarities else 0
    print(f"\n 📈 Average: {avg_pred_sim:.4f}")

    # ==================================================
    # DIAGNOSIS
    # ==================================================
    print(f"\n{'='*60}")
    print("🔍 DIAGNOSIS")
    print(f"{'='*60}")

    print(f"\nModel info:")
    print(f" - Style dim: {config['model_params']['style_dim']}")
    print(f" - Checkpoint: {os.path.basename(checkpoint_path)}")

    print(f"\n📊 Results:")
    print(f" - Style Encoder cross-speaker sim: {avg_style_sim:.4f}")
    print(f" - Predictor Encoder cross-speaker sim: {avg_pred_sim:.4f}")

    # Diagnose the style encoder (timbre)
    print(f"\n{'='*60}")
    if avg_style_sim > 0.85:
        print("❌ CRITICAL ISSUE: Style Encoder COLLAPSED!")
        print(f"{'='*60}")
        print("\n🔴 Problem:")
        print(" Style encoder similarity = {:.4f} (TOO HIGH!)".format(avg_style_sim))
        print(" → Model học 'average/generic voice' thay vì specific timbre")
        print(" → Đây là lý do màu giọng không giống!")
        print("\n💡 Solutions:")
        print(" 1. RETRAIN with:")
        print(" - style_dim: 256 (hoặc 512) - hiện tại: {}".format(
            config['model_params']['style_dim']))
        print(" - lambda_sty: 5.0")
        print(" - diff_epoch: 20")
        print(" - joint_epoch: 40")
        print("\n 2. Hoặc Fine-tune chỉ style_encoder với contrastive loss")
        print(" (freeze tất cả modules khác)")
    elif avg_style_sim > 0.75:
        print("⚠️ WARNING: Style Encoder có vấn đề!")
        print(f"{'='*60}")
        print("\n🟡 Problem:")
        print(" Style encoder similarity = {:.4f} (HIGH)".format(avg_style_sim))
        print(" → Weak speaker discrimination")
        print("\n💡 Quick fixes to try:")
        print(" 1. Tăng lambda_sty: 5.0 và train thêm 10-20 epochs")
        print(" 2. Use multi-reference (3-5 clips) và average styles")
        print(" 3. Reference audio dài hơn (8-12s)")
    else:
        print("✅ Style Encoder OK!")
        print(f"{'='*60}")
        print("\n🟢 Style encoder có thể phân biệt speakers")
        print(" Cross-speaker similarity = {:.4f} (ACCEPTABLE)".format(avg_style_sim))
        print("\n💡 Nếu vẫn clone không giống, check:")
        print(" 1. Reference audio trong inference:")
        print(" - Duration: 5-10s (càng dài càng tốt)")
        print(" - Quality: clean, no noise")
        print(" - Representative: có nhiều đặc trưng của speaker")
        print("\n 2. Diffusion trong inference:")
        print(" - Thử giảm num_steps từ 5 → 3")
        print(" - Hoặc tăng weight của ref_style:")
        print(" s = 0.3 * s_pred + 0.7 * ref_style (thay vì 0.7 + 0.3)")
        print("\n 3. Multi-reference averaging:")
        print(" - Dùng 3-5 reference clips và average styles")

    # Additional info
    print(f"\n{'='*60}")
    print("📝 Additional Info:")
    print(f"{'='*60}")
    print("\nTimbre characteristics được encode trong Style Encoder:")
    print(" - Formant frequencies (F1, F2, F3)")
    print(" - Harmonic structure")
    print(" - Breathiness/hoarseness")
    print(" - Vocal tract characteristics")
    print(" - Nasality")
    print(f" → Cần style_dim >= 256 để encode đầy đủ")
    print(f" → Hiện tại: style_dim = {config['model_params']['style_dim']}")


if __name__ == "__main__":
    main()