import torch import torch.nn as nn import copy # --- CẤU HÌNH --- INPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth" # File gốc OUTPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020_extended_191.pth" # File mới OLD_DIM = 178 # Số cần tìm NEW_DIM = 191 # Số mới thay thế # ---------------- def extend_tensor_logic(tensor): """ Hàm thực hiện logic mở rộng cho một tensor cụ thể. Hỗ trợ mở rộng trên bất kỳ trục (axis) nào có kích thước = OLD_DIM. """ shape = list(tensor.shape) # Kiểm tra xem tensor này có chiều nào bằng OLD_DIM không if OLD_DIM not in shape: return tensor, False # Xác định trục cần mở rộng (thường là trục 0) # Nếu có nhiều trục bằng 178, nó sẽ mở rộng tất cả (hiếm gặp nhưng an toàn) new_shape = [NEW_DIM if s == OLD_DIM else s for s in shape] # 1. Khởi tạo Tensor mới if len(shape) == 1: # Nếu là 1D (Bias) -> Điền thêm số 0 new_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) else: # Nếu là 2D+ (Weight) -> Điền thêm noise nhỏ (std=0.01) để tránh lỗi gradient = 0 new_tensor = torch.randn(new_shape, dtype=tensor.dtype, device=tensor.device) * 0.01 # 2. Copy dữ liệu cũ sang # Tạo các chỉ số slice để copy đúng vùng dữ liệu # Ví dụ: tensor cũ [178, 256], tensor mới [189, 256] -> copy vùng [0:178, :] slices = [slice(0, s) for s in shape] new_tensor[tuple(slices)] = tensor return new_tensor, True def recursive_scan_and_fix(data, path=""): """ Hàm đệ quy quét qua mọi kiểu dữ liệu (Dict, List, Tensor) """ if isinstance(data, dict): new_dict = {} for k, v in data.items(): # Gọi đệ quy cho từng phần tử con new_dict[k] = recursive_scan_and_fix(v, path + f".{k}") return new_dict elif isinstance(data, list): # Nếu là list (ví dụ danh sách layer), duyệt qua từng cái return [recursive_scan_and_fix(item, path + f"[{i}]") for i, item in enumerate(data)] elif isinstance(data, torch.Tensor) or isinstance(data, nn.Parameter): # Nếu là Tensor hoặc Parameter -> Kiểm tra và sửa is_param = isinstance(data, nn.Parameter) tensor_data = data.data if is_param else data new_tensor, modified = extend_tensor_logic(tensor_data) if modified: print(f"✅ Đã sửa: {path}") print(f" Shape cũ: {tuple(data.shape)} -> Mới: {tuple(new_tensor.shape)}") # Nếu gốc là Parameter thì bọc lại thành Parameter if is_param: return nn.Parameter(new_tensor) return new_tensor else: return data else: # Các kiểu dữ liệu khác (int, string, float...) giữ nguyên return data # --- MAIN EXECUTION --- try: print(f"⏳ Đang load checkpoint: {INPUT_PATH}") checkpoint = torch.load(INPUT_PATH, map_location='cpu') print(f"\n🚀 Đang quét toàn bộ cấu trúc để tìm dimension {OLD_DIM}...") # Bắt đầu quét từ gốc new_checkpoint = recursive_scan_and_fix(checkpoint, "root") print("\n💾 Đang lưu file mới...") torch.save(new_checkpoint, OUTPUT_PATH) print(f"🎉 HOÀN TẤT! File đã lưu tại: {OUTPUT_PATH}") print(f"👉 Code này có thể áp dụng cho MỌI MODEL có tensor kích thước {OLD_DIM}.") except Exception as e: print(f"❌ LỖI: {e}")