| import torch |
| import os |
|
|
| PATH_BASE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/epoch_2nd_00003.pth" |
| PATH_NEW = "/workspace/trainTTS/epochs_2nd_00020_universal.pth" |
|
|
| PATH_MERGE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/epoch_2nd_001000.pth" |
|
|
| print("PATH_BASE", PATH_BASE) |
|
|
|
|
| def merge_models(): |
| print("Đang load models...") |
| |
| base_ckpt = torch.load(PATH_BASE, map_location='cpu') |
| new_ckpt = torch.load(PATH_NEW, map_location='cpu') |
| print("base_ckpt key", base_ckpt.keys()) |
| |
| base_net = base_ckpt['net'] |
| new_net = new_ckpt['net'] |
|
|
| |
| merged_net = base_net.copy() |
|
|
| print("\n--- BẮT ĐẦU PHẪU THUẬT ---") |
|
|
| |
|
|
| keep_from_base = ['text_encoder', 'decoder', 'predictor', 'predictor_encoder', 'text_aligner', |
| 'pitch_extractor', 'mpd', 'msd', 'wd', 'bert', 'bert_encoder'] |
|
|
| for key in keep_from_base: |
| print(f"✅ Giữ nguyên '{key}' từ Base Model (120k) -> Để đọc chuẩn.") |
| merged_net[key] = base_net[key] |
|
|
| take_from_new = ['diffusion', 'style_encoder'] |
|
|
| for key in take_from_new: |
| if key in new_net: |
| print(f"🔥 Cấy ghép '{key}' từ Model Mới -> Để lấy chất giọng/S2S.") |
| merged_net[key] = new_net[key] |
| else: |
| print(f"⚠️ Cảnh báo: Không tìm thấy '{key}' trong model mới. Giữ nguyên của Base.") |
|
|
| |
| final_state = { |
| 'net': merged_net, |
| 'optimizer': base_ckpt['optimizer'], |
| 'iters': base_ckpt['iters'], |
| 'val_loss': base_ckpt['val_loss'], |
| 'epoch': base_ckpt['epoch'], |
|
|
| } |
|
|
| print(f"\nĐang lưu model ghép tại: {PATH_MERGE}") |
| torch.save(final_state, PATH_MERGE) |
| print("🎉 HOÀN TẤT! Hãy dùng model này để Inference.") |
|
|
|
|
| if __name__ == "__main__": |
| merge_models() |
|
|