styletts2-ver2 / merge_models_v2.py
hieuducle's picture
Upload full StyleTTS2_custom folder
1b242be verified
import torch
import os
# --- CẤU HÌNH ĐƯỜNG DẪN (GIỮ NGUYÊN NHƯ CŨ) ---
PATH_FINETUNE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/ver2/best_model.pth"
PATH_UNIVERSAL = "/workspace/trainTTS/epochs_2nd_00020_universal.pth"
PATH_MERGE = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/merge_model.pth"
def remove_prefix(text, prefix="module."):
if text.startswith(prefix):
return text[len(prefix):]
return text
def merge_models():
print(f"1. Đang load Finetune Model (Base): {PATH_FINETUNE}")
ft_ckpt = torch.load(PATH_FINETUNE, map_location='cpu')
ft_net = ft_ckpt['net']
print(f"2. Đang load Universal Model (New): {PATH_UNIVERSAL}")
uni_ckpt = torch.load(PATH_UNIVERSAL, map_location='cpu')
uni_net = uni_ckpt['net']
# Tạo bản đồ key chuẩn hóa cho model đích (Finetune)
# Mục đích: Dù key đích là "decoder" hay "module.decoder" ta đều tìm được
ft_keys_map = {remove_prefix(k): k for k in ft_net.keys()}
merged_net = ft_net.copy()
print("\n--- BẮT ĐẦU PHẪU THUẬT (AUTO MATCHING) ---")
modules_to_replace = ['style_encoder', 'diffusion']
count_replaced = 0
# Duyệt qua từng key của model Universal
for key_uni, val_uni in uni_net.items():
# 1. Chuẩn hóa key (bỏ module.)
clean_key = remove_prefix(key_uni)
# 2. Kiểm tra xem key này có thuộc bộ phận cần thay thế không
is_target = False
for module_name in modules_to_replace:
if clean_key.startswith(module_name + "."):
is_target = True
break
if is_target:
# 3. Tìm key tương ứng bên Model Finetune
if clean_key in ft_keys_map:
real_key_ft = ft_keys_map[clean_key]
# 4. Kiểm tra kích thước tensor có khớp không
if merged_net[real_key_ft].shape == val_uni.shape:
merged_net[real_key_ft] = val_uni
count_replaced += 1
else:
print(
f"⚠️ Lệch Size (Bỏ qua): {clean_key} | Gốc: {merged_net[real_key_ft].shape} != Mới: {val_uni.shape}")
else:
# Trường hợp hiếm: Finetune không có key này (có thể do version code khác nhau)
pass
if count_replaced == 0:
print("❌ VẪN LỖI: Không tìm thấy tham số nào khớp. Hãy kiểm tra lại tên module trong config!")
# Debug: In thử 5 key đầu tiên để xem tên nó là gì
print("Ví dụ key trong Universal:", list(uni_net.keys())[:5])
else:
print(f"✅ Đã cấy ghép thành công {count_replaced} tham số!")
print(f" (Đã tự động xử lý lệch 'module.' prefix)")
# Đóng gói
final_state = {
'net': merged_net,
'optimizer': None,
'epoch': 0,
'iters': 0,
'config': ft_ckpt.get('config', {})
}
print(f"\n💾 Đang lưu model ghép tại: {PATH_MERGE}")
torch.save(final_state, PATH_MERGE)
print("🎉 HOÀN TẤT!")
if __name__ == "__main__":
merge_models()