File size: 3,857 Bytes
53ff274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import copy

# --- CẤU HÌNH ---
INPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth"  # File gốc
OUTPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020_extended_191.pth"  # File mới
OLD_DIM = 178  # Số cần tìm
NEW_DIM = 191  # Số mới thay thế


# ----------------

def extend_tensor_logic(tensor):
    """

    Hàm thực hiện logic mở rộng cho một tensor cụ thể.

    Hỗ trợ mở rộng trên bất kỳ trục (axis) nào có kích thước = OLD_DIM.

    """
    shape = list(tensor.shape)

    # Kiểm tra xem tensor này có chiều nào bằng OLD_DIM không
    if OLD_DIM not in shape:
        return tensor, False

    # Xác định trục cần mở rộng (thường là trục 0)
    # Nếu có nhiều trục bằng 178, nó sẽ mở rộng tất cả (hiếm gặp nhưng an toàn)
    new_shape = [NEW_DIM if s == OLD_DIM else s for s in shape]

    # 1. Khởi tạo Tensor mới
    if len(shape) == 1:
        # Nếu là 1D (Bias) -> Điền thêm số 0
        new_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
    else:
        # Nếu là 2D+ (Weight) -> Điền thêm noise nhỏ (std=0.01) để tránh lỗi gradient = 0
        new_tensor = torch.randn(new_shape, dtype=tensor.dtype, device=tensor.device) * 0.01

    # 2. Copy dữ liệu cũ sang
    # Tạo các chỉ số slice để copy đúng vùng dữ liệu
    # Ví dụ: tensor cũ [178, 256], tensor mới [189, 256] -> copy vùng [0:178, :]
    slices = [slice(0, s) for s in shape]
    new_tensor[tuple(slices)] = tensor

    return new_tensor, True


def recursive_scan_and_fix(data, path=""):
    """

    Hàm đệ quy quét qua mọi kiểu dữ liệu (Dict, List, Tensor)

    """
    if isinstance(data, dict):
        new_dict = {}
        for k, v in data.items():
            # Gọi đệ quy cho từng phần tử con
            new_dict[k] = recursive_scan_and_fix(v, path + f".{k}")
        return new_dict

    elif isinstance(data, list):
        # Nếu là list (ví dụ danh sách layer), duyệt qua từng cái
        return [recursive_scan_and_fix(item, path + f"[{i}]") for i, item in enumerate(data)]

    elif isinstance(data, torch.Tensor) or isinstance(data, nn.Parameter):
        # Nếu là Tensor hoặc Parameter -> Kiểm tra và sửa
        is_param = isinstance(data, nn.Parameter)
        tensor_data = data.data if is_param else data

        new_tensor, modified = extend_tensor_logic(tensor_data)

        if modified:
            print(f"✅ Đã sửa: {path}")
            print(f"   Shape cũ: {tuple(data.shape)} -> Mới: {tuple(new_tensor.shape)}")

            # Nếu gốc là Parameter thì bọc lại thành Parameter
            if is_param:
                return nn.Parameter(new_tensor)
            return new_tensor
        else:
            return data

    else:
        # Các kiểu dữ liệu khác (int, string, float...) giữ nguyên
        return data


# --- MAIN EXECUTION ---
try:
    print(f"⏳ Đang load checkpoint: {INPUT_PATH}")
    checkpoint = torch.load(INPUT_PATH, map_location='cpu')

    print(f"\n🚀 Đang quét toàn bộ cấu trúc để tìm dimension {OLD_DIM}...")

    # Bắt đầu quét từ gốc
    new_checkpoint = recursive_scan_and_fix(checkpoint, "root")

    print("\n💾 Đang lưu file mới...")
    torch.save(new_checkpoint, OUTPUT_PATH)

    print(f"🎉 HOÀN TẤT! File đã lưu tại: {OUTPUT_PATH}")
    print(f"👉 Code này có thể áp dụng cho MỌI MODEL có tensor kích thước {OLD_DIM}.")

except Exception as e:
    print(f"❌ LỖI: {e}")