# styletts2-ver2 / model_test.py
# Uploaded by hieuducle ("Upload full StyleTTS2_custom folder"), commit 1b242be.
"""Inspect (and optionally prune) a StyleTTS2 training checkpoint.

Usage::

    python model_test.py [CHECKPOINT_PATH]

Loads the checkpoint onto the CPU and prints the names stored under its
``'net'`` entry. ``strip_modules`` removes selected sub-modules from
``ckpt['net']`` in place; follow it with ``torch.save(ckpt, new_path)`` to
write a smaller checkpoint (an earlier experiment saved to
``/workspace/trainTTS/epoch_4_layer.pth``).

All work happens in ``main()`` behind a ``__main__`` guard. The file name
matches pytest's default ``*_test.py`` discovery pattern, so without the
guard, collection would try to load a hard-coded checkpoint path.
"""
import sys

import torch

DEFAULT_PATH = "/workspace/trainTTS/model/epochs_2nd_00020.pth"
# Alternative checkpoint used during experiments:
# "/workspace/trainTTS/dangtr0408/StyleTTS2-lite/Models/base_model.pth"

# Sub-modules the pruning experiment removed from ckpt['net'], in that order.
PRUNABLE_MODULES = (
    "decoder",
    "predictor",
    "text_encoder",
    "style_encoder",
    "text_aligner",
    "pitch_extractor",
    "mpd",
    "msd",
)


def load_checkpoint(path):
    """Load the checkpoint at *path*, mapping every tensor onto the CPU.

    NOTE(review): since torch 2.6, ``torch.load`` defaults to
    ``weights_only=True``. If a trusted checkpoint contains non-tensor
    objects that fail to load, pass ``weights_only=False`` explicitly.
    """
    return torch.load(path, map_location="cpu")


def net_keys(ckpt):
    """Return the names stored under ``ckpt['net']`` as a list.

    Raises:
        KeyError: if *ckpt* has no ``'net'`` entry. The message lists the
            top-level keys that are present, to help diagnose a checkpoint
            saved in a different layout.
    """
    try:
        net = ckpt["net"]
    except KeyError:
        raise KeyError(
            f"checkpoint has no 'net' entry; top-level keys: {list(ckpt)}"
        ) from None
    return list(net.keys())


def strip_modules(ckpt, names=PRUNABLE_MODULES):
    """Delete each of *names* from ``ckpt['net']`` in place, if present.

    Missing names are skipped silently.

    Returns:
        The names that were actually removed, in *names* order.
    """
    net = ckpt["net"]
    removed = []
    for name in names:
        if name in net:
            del net[name]
            print(f"✓ Xóa: {name}")
            removed.append(name)
    return removed


def main(argv=None):
    """Load a checkpoint and print the module names under its 'net' entry.

    Args:
        argv: Command-line arguments excluding the program name. Defaults to
            ``sys.argv[1:]``. The first item, if any, overrides DEFAULT_PATH.

    Returns:
        The loaded checkpoint, so callers can inspect or prune it further.
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else DEFAULT_PATH
    ckpt = load_checkpoint(path)
    print("=== Keys trước khi xóa ===")
    print(net_keys(ckpt))
    return ckpt


if __name__ == "__main__":
    main()