| | import torch
|
| | import torch.nn as nn
|
| | import copy
|
| |
|
| |
|
| | INPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth"
|
| | OUTPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020_extended_191.pth"
|
| | OLD_DIM = 178
|
| | NEW_DIM = 191
|
| |
|
| |
|
| |
|
| |
|
| | def extend_tensor_logic(tensor):
|
| | """
|
| | Hàm thực hiện logic mở rộng cho một tensor cụ thể.
|
| | Hỗ trợ mở rộng trên bất kỳ trục (axis) nào có kích thước = OLD_DIM.
|
| | """
|
| | shape = list(tensor.shape)
|
| |
|
| |
|
| | if OLD_DIM not in shape:
|
| | return tensor, False
|
| |
|
| |
|
| |
|
| | new_shape = [NEW_DIM if s == OLD_DIM else s for s in shape]
|
| |
|
| |
|
| | if len(shape) == 1:
|
| |
|
| | new_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
| | else:
|
| |
|
| | new_tensor = torch.randn(new_shape, dtype=tensor.dtype, device=tensor.device) * 0.01
|
| |
|
| |
|
| |
|
| |
|
| | slices = [slice(0, s) for s in shape]
|
| | new_tensor[tuple(slices)] = tensor
|
| |
|
| | return new_tensor, True
|
| |
|
| |
|
| | def recursive_scan_and_fix(data, path=""):
|
| | """
|
| | Hàm đệ quy quét qua mọi kiểu dữ liệu (Dict, List, Tensor)
|
| | """
|
| | if isinstance(data, dict):
|
| | new_dict = {}
|
| | for k, v in data.items():
|
| |
|
| | new_dict[k] = recursive_scan_and_fix(v, path + f".{k}")
|
| | return new_dict
|
| |
|
| | elif isinstance(data, list):
|
| |
|
| | return [recursive_scan_and_fix(item, path + f"[{i}]") for i, item in enumerate(data)]
|
| |
|
| | elif isinstance(data, torch.Tensor) or isinstance(data, nn.Parameter):
|
| |
|
| | is_param = isinstance(data, nn.Parameter)
|
| | tensor_data = data.data if is_param else data
|
| |
|
| | new_tensor, modified = extend_tensor_logic(tensor_data)
|
| |
|
| | if modified:
|
| | print(f"✅ Đã sửa: {path}")
|
| | print(f" Shape cũ: {tuple(data.shape)} -> Mới: {tuple(new_tensor.shape)}")
|
| |
|
| |
|
| | if is_param:
|
| | return nn.Parameter(new_tensor)
|
| | return new_tensor
|
| | else:
|
| | return data
|
| |
|
| | else:
|
| |
|
| | return data
|
| |
|
| |
|
| |
|
| | try:
|
| | print(f"⏳ Đang load checkpoint: {INPUT_PATH}")
|
| | checkpoint = torch.load(INPUT_PATH, map_location='cpu')
|
| |
|
| | print(f"\n🚀 Đang quét toàn bộ cấu trúc để tìm dimension {OLD_DIM}...")
|
| |
|
| |
|
| | new_checkpoint = recursive_scan_and_fix(checkpoint, "root")
|
| |
|
| | print("\n💾 Đang lưu file mới...")
|
| | torch.save(new_checkpoint, OUTPUT_PATH)
|
| |
|
| | print(f"🎉 HOÀN TẤT! File đã lưu tại: {OUTPUT_PATH}")
|
| | print(f"👉 Code này có thể áp dụng cho MỌI MODEL có tensor kích thước {OLD_DIM}.")
|
| |
|
| | except Exception as e:
|
| | print(f"❌ LỖI: {e}") |