# NOTE: the three lines that were here ("Spaces:", "Running", "Running") were
# web-UI status residue from the page this file was copied from, not code;
# they are preserved as this comment so the file remains valid Python.
| # check_style_encoder.py | |
| import torch | |
| import torchaudio | |
| import librosa | |
| import numpy as np | |
| from scipy.spatial.distance import cosine | |
| import yaml | |
| import os | |
| from models import * | |
| from utils import * | |
# Setup
# Select GPU when available; every loaded sub-module and input tensor below
# is moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model(checkpoint_path, config_path):
    """Load the StyleTTS2 model bundle the same way the inference code does.

    Args:
        checkpoint_path: Path to a .pth checkpoint whose top-level dict
            contains the per-module state dicts under the 'net' key.
        config_path: Path to the YAML config with 'model_params' and the
            auxiliary model paths (ASR_path/ASR_config, F0_path, PLBERT_dir).

    Returns:
        (model, config): dict of sub-modules, each in eval mode on `device`,
        and the parsed config dict.
    """
    print(f"Loading config from: {config_path}")
    # Context manager closes the config file promptly; the original
    # `yaml.safe_load(open(...))` left the handle to the GC.
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Build auxiliary models after the config is available
    print("Building model...")
    text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
    pitch_extractor = load_F0_models(config['F0_path'])

    # Imported here (as in the original) so the config is parsed first.
    from Utils.PLBERT.util import load_plbert
    plbert = load_plbert(config['PLBERT_dir'])

    model = build_model(recursive_munch(config['model_params']),
                        text_aligner, pitch_extractor, plbert)

    print(f"Loading checkpoint from: {checkpoint_path}")
    params = torch.load(checkpoint_path, map_location='cpu')['net']
    for key in model:
        state_dict = params[key]
        # Strip DataParallel's "module." prefix (if present) so keys match
        # the bare sub-module's parameter names.
        new_state_dict = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }
        model[key].load_state_dict(new_state_dict, strict=True)
        model[key].eval().to(device)
        print(f" ✓ Loaded {key}")
    return model, config
def preprocess_audio(audio_path, sr=24000, top_db=30):
    """Load an audio file and convert it to the normalized log-mel used by
    StyleTTS2, mirroring the inference-time preprocessing.

    Args:
        audio_path: Path to the audio file.
        sr: Target sample rate; default 24000 matches the original
            hard-coded value.
        top_db: Silence-trim threshold in dB for librosa.effects.trim;
            default 30 matches the original hard-coded value.

    Returns:
        torch.Tensor of shape (1, 80, T): normalized log-mel spectrogram.
    """
    wave, _ = librosa.load(audio_path, sr=sr)
    audio, _ = librosa.effects.trim(wave, top_db=top_db)
    # STFT/mel parameters must match the ones the model was trained with.
    to_mel = torchaudio.transforms.MelSpectrogram(
        n_mels=80, n_fft=2048, win_length=1200, hop_length=300
    )
    mel = to_mel(torch.from_numpy(audio).float())
    # log(1e-5 + mel) avoids log(0); (x - (-4)) / 4 normalizes with
    # mean -4 / std 4 — presumably the StyleTTS2 convention (TODO confirm
    # against the training config).
    return (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
def extract_style(audio_path, model):
    """Compute reference vectors for one audio file from both encoders.

    Returns a (style, predictor) pair of numpy arrays — the outputs of
    model['style_encoder'] and model['predictor_encoder'] on the
    preprocessed mel spectrogram.
    """
    mel_input = preprocess_audio(audio_path).to(device)
    with torch.no_grad():
        # Both encoders receive the same (batch, 1, n_mels, T) input.
        batch = mel_input.unsqueeze(1)
        style_vec = model['style_encoder'](batch)
        predictor_vec = model['predictor_encoder'](batch)
    return style_vec.cpu().numpy(), predictor_vec.cpu().numpy()
def compute_similarity_matrix(styles_list1, styles_list2):
    """Return all pairwise cosine similarities between two vector lists.

    Each entry is 1 - cosine_distance(a, b) on the flattened vectors,
    ordered by (index in styles_list1, index in styles_list2), as a flat
    list of length len(styles_list1) * len(styles_list2).
    """
    return [
        1 - cosine(first.flatten(), second.flatten())
        for first in styles_list1
        for second in styles_list2
    ]
def main():
    """Diagnostic entry point: load the model, extract style vectors from
    reference audios of two speakers, then report within-speaker and
    cross-speaker cosine similarities to diagnose style-encoder collapse.
    """
    # ==================================================
    # CONFIGURATION - ADJUST THESE PATHS FOR YOUR SETUP
    # ==================================================
    checkpoint_path = "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/styletts2-ver2-model-bestmodel/best_model_ver2.pth"
    config_path = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/Configs/config_ft.yml"
    # Test audios - ADD MORE AUDIO CLIPS PER SPEAKER!
    speaker1_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav",
        # Add more clips from the same speaker if available
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_2.wav",
        # "/workspace/trainTTS/StyleTTS2_custom/sangnq_3.wav",
    ]
    speaker2_audios = [
        "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav",
        # Add more clips from speaker 2 if available
        # "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/nu_thoi_su_2.wav",
    ]
    # ==================================================
    # LOAD MODEL
    # ==================================================
    print(f"\n{'='*60}")
    print("LOADING MODEL")
    print(f"{'='*60}")
    print(f"Device: {device}")
    try:
        model, config = load_model(checkpoint_path, config_path)
        print(f"\n✓ Model loaded successfully!")
        print(f" Style dim: {config['model_params']['style_dim']}")
    except Exception as e:
        # Broad catch is intentional here: this is the top-level boundary of
        # a diagnostic script, and it logs the full traceback before exiting.
        print(f"\n✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return
    # ==================================================
    # EXTRACT STYLES
    # ==================================================
    print(f"\n{'='*60}")
    print("EXTRACTING STYLES")
    print(f"{'='*60}")
    print(f"\nSpeaker 1 ({len(speaker1_audios)} audios):")
    spk1_style_encoder = []      # one style-encoder vector per speaker-1 clip
    spk1_predictor_encoder = []  # one predictor-encoder vector per speaker-1 clip
    for i, audio in enumerate(speaker1_audios):
        try:
            ref_s, ref_p = extract_style(audio, model)
            spk1_style_encoder.append(ref_s)
            spk1_predictor_encoder.append(ref_p)
            print(f" ✓ Audio {i+1}: {os.path.basename(audio)}")
            print(f" - Style encoder shape: {ref_s.shape}")
            print(f" - Predictor encoder shape: {ref_p.shape}")
        except Exception as e:
            # Per-clip failures are reported but don't stop the other clips.
            print(f" ✗ Error: {e}")
    print(f"\nSpeaker 2 ({len(speaker2_audios)} audios):")
    spk2_style_encoder = []
    spk2_predictor_encoder = []
    for i, audio in enumerate(speaker2_audios):
        try:
            ref_s, ref_p = extract_style(audio, model)
            spk2_style_encoder.append(ref_s)
            spk2_predictor_encoder.append(ref_p)
            print(f" ✓ Audio {i+1}: {os.path.basename(audio)}")
            print(f" - Style encoder shape: {ref_s.shape}")
            print(f" - Predictor encoder shape: {ref_p.shape}")
        except Exception as e:
            print(f" ✗ Error: {e}")
    # ==================================================
    # ANALYZE STYLE ENCODER (TIMBRE)
    # ==================================================
    print(f"\n{'='*60}")
    print("STYLE ENCODER ANALYSIS (TIMBRE/MÀU GIỌNG)")
    print(f"{'='*60}")
    # Within-speaker similarity: adjacent clip pairs of the same speaker
    # should be very similar if the encoder captures speaker identity.
    if len(spk1_style_encoder) > 1:
        print("\n📊 Within-speaker similarity (Speaker 1):")
        print(" Target: > 0.90 (same speaker should be very similar)")
        for i in range(len(spk1_style_encoder)-1):
            sim = 1 - cosine(spk1_style_encoder[i].flatten(),
                             spk1_style_encoder[i+1].flatten())
            status = "✓" if sim > 0.90 else "⚠️" if sim > 0.80 else "✗"
            print(f" {status} Audio{i+1} vs Audio{i+2}: {sim:.4f}")
    else:
        print("\n⚠️ Need 2+ audios from Speaker 1 to check within-speaker similarity")
    if len(spk2_style_encoder) > 1:
        print("\n📊 Within-speaker similarity (Speaker 2):")
        print(" Target: > 0.90")
        for i in range(len(spk2_style_encoder)-1):
            sim = 1 - cosine(spk2_style_encoder[i].flatten(),
                             spk2_style_encoder[i+1].flatten())
            status = "✓" if sim > 0.90 else "⚠️" if sim > 0.80 else "✗"
            print(f" {status} Audio{i+1} vs Audio{i+2}: {sim:.4f}")
    else:
        print("\n⚠️ Need 2+ audios from Speaker 2 to check within-speaker similarity")
    # Cross-speaker similarity (THE MOST IMPORTANT CHECK!)
    print("\n📊 Cross-speaker similarity (Speaker 1 vs Speaker 2):")
    print(" Target: < 0.70 (different speakers should be dissimilar)")
    style_similarities = compute_similarity_matrix(spk1_style_encoder,
                                                   spk2_style_encoder)
    # NOTE(review): the loop below recomputes each pair's similarity that
    # compute_similarity_matrix already produced — duplicated work, kept
    # as-is for the per-pair status display.
    for i, s1 in enumerate(spk1_style_encoder):
        for j, s2 in enumerate(spk2_style_encoder):
            sim = 1 - cosine(s1.flatten(), s2.flatten())
            status = "✓" if sim < 0.70 else "⚠️" if sim < 0.80 else "✗"
            print(f" {status} Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")
    # Falls back to 0 when either speaker produced no vectors.
    avg_style_sim = np.mean(style_similarities) if style_similarities else 0
    print(f"\n 📈 Average cross-speaker similarity: {avg_style_sim:.4f}")
    # ==================================================
    # ANALYZE PREDICTOR ENCODER (PROSODY)
    # ==================================================
    print(f"\n{'='*60}")
    print("PREDICTOR ENCODER ANALYSIS (PROSODY/NGỮ ĐIỆU)")
    print(f"{'='*60}")
    print("\n📊 Cross-speaker similarity (Predictor Encoder):")
    print(" Note: Predictor encoder cho prosody, ít ảnh hưởng timbre")
    pred_similarities = compute_similarity_matrix(spk1_predictor_encoder,
                                                  spk2_predictor_encoder)
    for i, s1 in enumerate(spk1_predictor_encoder):
        for j, s2 in enumerate(spk2_predictor_encoder):
            sim = 1 - cosine(s1.flatten(), s2.flatten())
            print(f" - Spk1-audio{i+1} vs Spk2-audio{j+1}: {sim:.4f}")
    avg_pred_sim = np.mean(pred_similarities) if pred_similarities else 0
    print(f"\n 📈 Average: {avg_pred_sim:.4f}")
    # ==================================================
    # DIAGNOSIS
    # ==================================================
    print(f"\n{'='*60}")
    print("🔍 DIAGNOSIS")
    print(f"{'='*60}")
    print(f"\nModel info:")
    print(f" - Style dim: {config['model_params']['style_dim']}")
    print(f" - Checkpoint: {os.path.basename(checkpoint_path)}")
    print(f"\n📊 Results:")
    print(f" - Style Encoder cross-speaker sim: {avg_style_sim:.4f}")
    print(f" - Predictor Encoder cross-speaker sim: {avg_pred_sim:.4f}")
    # Style-encoder (timbre) diagnosis, bucketed by cross-speaker
    # similarity: > 0.85 collapsed, > 0.75 weak, otherwise OK.
    print(f"\n{'='*60}")
    if avg_style_sim > 0.85:
        print("❌ CRITICAL ISSUE: Style Encoder COLLAPSED!")
        print(f"{'='*60}")
        print("\n🔴 Problem:")
        print(" Style encoder similarity = {:.4f} (TOO HIGH!)".format(avg_style_sim))
        print(" → Model học 'average/generic voice' thay vì specific timbre")
        print(" → Đây là lý do màu giọng không giống!")
        print("\n💡 Solutions:")
        print(" 1. RETRAIN with:")
        print(" - style_dim: 256 (hoặc 512) - hiện tại: {}".format(
            config['model_params']['style_dim']))
        print(" - lambda_sty: 5.0")
        print(" - diff_epoch: 20")
        print(" - joint_epoch: 40")
        print("\n 2. Hoặc Fine-tune chỉ style_encoder với contrastive loss")
        print(" (freeze tất cả modules khác)")
    elif avg_style_sim > 0.75:
        print("⚠️ WARNING: Style Encoder có vấn đề!")
        print(f"{'='*60}")
        print("\n🟡 Problem:")
        print(" Style encoder similarity = {:.4f} (HIGH)".format(avg_style_sim))
        print(" → Weak speaker discrimination")
        print("\n💡 Quick fixes to try:")
        print(" 1. Tăng lambda_sty: 5.0 và train thêm 10-20 epochs")
        print(" 2. Use multi-reference (3-5 clips) và average styles")
        print(" 3. Reference audio dài hơn (8-12s)")
    else:
        print("✅ Style Encoder OK!")
        print(f"{'='*60}")
        print("\n🟢 Style encoder có thể phân biệt speakers")
        print(" Cross-speaker similarity = {:.4f} (ACCEPTABLE)".format(avg_style_sim))
        print("\n💡 Nếu vẫn clone không giống, check:")
        print(" 1. Reference audio trong inference:")
        print(" - Duration: 5-10s (càng dài càng tốt)")
        print(" - Quality: clean, no noise")
        print(" - Representative: có nhiều đặc trưng của speaker")
        print("\n 2. Diffusion trong inference:")
        print(" - Thử giảm num_steps từ 5 → 3")
        print(" - Hoặc tăng weight của ref_style:")
        print(" s = 0.3 * s_pred + 0.7 * ref_style (thay vì 0.7 + 0.3)")
        print("\n 3. Multi-reference averaging:")
        print(" - Dùng 3-5 reference clips và average styles")
    # Additional info
    print(f"\n{'='*60}")
    print("📝 Additional Info:")
    print(f"{'='*60}")
    print("\nTimbre characteristics được encode trong Style Encoder:")
    print(" - Formant frequencies (F1, F2, F3)")
    print(" - Harmonic structure")
    print(" - Breathiness/hoarseness")
    print(" - Vocal tract characteristics")
    print(" - Nasality")
    print(f" → Cần style_dim >= 256 để encode đầy đủ")
    print(f" → Hiện tại: style_dim = {config['model_params']['style_dim']}")
| if __name__ == "__main__": | |
| main() |