styletts2-ver2 / merge_models.py
hieuducle's picture
Upload full StyleTTS2_custom folder
1b242be verified
import torch
import os
PATH_BASE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/epoch_2nd_00003.pth"
PATH_NEW = "/workspace/trainTTS/epochs_2nd_00020_universal.pth"
PATH_MERGE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/epoch_2nd_001000.pth"
print("PATH_BASE", PATH_BASE)
def merge_models():
print("Đang load models...")
# Load model vào CPU để xử lý
base_ckpt = torch.load(PATH_BASE, map_location='cpu')
new_ckpt = torch.load(PATH_NEW, map_location='cpu')
print("base_ckpt key", base_ckpt.keys())
# dasdasdsad
base_net = base_ckpt['net']
new_net = new_ckpt['net']
# Tạo dict cho model mới dựa trên khung của Base (để đảm bảo cấu trúc chuẩn)
merged_net = base_net.copy()
print("\n--- BẮT ĐẦU PHẪU THUẬT ---")
# DANH SÁCH CÁC BỘ PHẬN
keep_from_base = ['text_encoder', 'decoder', 'predictor', 'predictor_encoder', 'text_aligner',
'pitch_extractor', 'mpd', 'msd', 'wd', 'bert', 'bert_encoder']
for key in keep_from_base:
print(f"✅ Giữ nguyên '{key}' từ Base Model (120k) -> Để đọc chuẩn.")
merged_net[key] = base_net[key]
take_from_new = ['diffusion', 'style_encoder']
for key in take_from_new:
if key in new_net:
print(f"🔥 Cấy ghép '{key}' từ Model Mới -> Để lấy chất giọng/S2S.")
merged_net[key] = new_net[key]
else:
print(f"⚠️ Cảnh báo: Không tìm thấy '{key}' trong model mới. Giữ nguyên của Base.")
# 3. Đóng gói lại
final_state = {
'net': merged_net,
'optimizer': base_ckpt['optimizer'],
'iters': base_ckpt['iters'],
'val_loss': base_ckpt['val_loss'],
'epoch': base_ckpt['epoch'],
}
print(f"\nĐang lưu model ghép tại: {PATH_MERGE}")
torch.save(final_state, PATH_MERGE)
print("🎉 HOÀN TẤT! Hãy dùng model này để Inference.")
if __name__ == "__main__":
merge_models()