| import torch |
| import os |
|
|
| |
| PATH_FINETUNE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/best_model.pth" |
| PATH_UNIVERSAL = "/workspace/trainTTS/epochs_2nd_00020_universal.pth" |
| PATH_MERGE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/merge_model.pth" |
|
|
|
|
| def remove_prefix(text, prefix="module."): |
| if text.startswith(prefix): |
| return text[len(prefix):] |
| return text |
|
|
|
|
| def merge_models(): |
| print(f"1. Đang load Finetune Model (Base): {PATH_FINETUNE}") |
| ft_ckpt = torch.load(PATH_FINETUNE, map_location='cpu') |
| ft_net = ft_ckpt['net'] |
|
|
| print(f"2. Đang load Universal Model (New): {PATH_UNIVERSAL}") |
| uni_ckpt = torch.load(PATH_UNIVERSAL, map_location='cpu') |
| uni_net = uni_ckpt['net'] |
|
|
| |
| |
| ft_keys_map = {remove_prefix(k): k for k in ft_net.keys()} |
|
|
| merged_net = ft_net.copy() |
|
|
| print("\n--- BẮT ĐẦU PHẪU THUẬT (AUTO MATCHING) ---") |
|
|
| modules_to_replace = ['style_encoder', 'diffusion'] |
| count_replaced = 0 |
|
|
| |
| for key_uni, val_uni in uni_net.items(): |
| |
| clean_key = remove_prefix(key_uni) |
|
|
| |
| is_target = False |
| for module_name in modules_to_replace: |
| if clean_key.startswith(module_name + "."): |
| is_target = True |
| break |
|
|
| if is_target: |
| |
| if clean_key in ft_keys_map: |
| real_key_ft = ft_keys_map[clean_key] |
|
|
| |
| if merged_net[real_key_ft].shape == val_uni.shape: |
| merged_net[real_key_ft] = val_uni |
| count_replaced += 1 |
| else: |
| print( |
| f"⚠️ Lệch Size (Bỏ qua): {clean_key} | Gốc: {merged_net[real_key_ft].shape} != Mới: {val_uni.shape}") |
| else: |
| |
| pass |
|
|
| if count_replaced == 0: |
| print("❌ VẪN LỖI: Không tìm thấy tham số nào khớp. Hãy kiểm tra lại tên module trong config!") |
| |
| print("Ví dụ key trong Universal:", list(uni_net.keys())[:5]) |
| else: |
| print(f"✅ Đã cấy ghép thành công {count_replaced} tham số!") |
| print(f" (Đã tự động xử lý lệch 'module.' prefix)") |
|
|
| |
| final_state = { |
| 'net': merged_net, |
| 'optimizer': None, |
| 'epoch': 0, |
| 'iters': 0, |
| 'config': ft_ckpt.get('config', {}) |
| } |
|
|
| print(f"\n💾 Đang lưu model ghép tại: {PATH_MERGE}") |
| torch.save(final_state, PATH_MERGE) |
| print("🎉 HOÀN TẤT!") |
|
|
|
|
| if __name__ == "__main__": |
| merge_models() |