StyleTTS_dolly / extend_model_datnt.py
hieuducle's picture
Initial upload from script
53ff274 verified
import torch
import torch.nn as nn
import copy
# --- CẤU HÌNH ---
INPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth" # File gốc
OUTPUT_PATH = "/workspace/StyleTTS2/yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020_extended_191.pth" # File mới
OLD_DIM = 178 # Số cần tìm
NEW_DIM = 191 # Số mới thay thế
# ----------------
def extend_tensor_logic(tensor):
"""
Hàm thực hiện logic mở rộng cho một tensor cụ thể.
Hỗ trợ mở rộng trên bất kỳ trục (axis) nào có kích thước = OLD_DIM.
"""
shape = list(tensor.shape)
# Kiểm tra xem tensor này có chiều nào bằng OLD_DIM không
if OLD_DIM not in shape:
return tensor, False
# Xác định trục cần mở rộng (thường là trục 0)
# Nếu có nhiều trục bằng 178, nó sẽ mở rộng tất cả (hiếm gặp nhưng an toàn)
new_shape = [NEW_DIM if s == OLD_DIM else s for s in shape]
# 1. Khởi tạo Tensor mới
if len(shape) == 1:
# Nếu là 1D (Bias) -> Điền thêm số 0
new_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
else:
# Nếu là 2D+ (Weight) -> Điền thêm noise nhỏ (std=0.01) để tránh lỗi gradient = 0
new_tensor = torch.randn(new_shape, dtype=tensor.dtype, device=tensor.device) * 0.01
# 2. Copy dữ liệu cũ sang
# Tạo các chỉ số slice để copy đúng vùng dữ liệu
# Ví dụ: tensor cũ [178, 256], tensor mới [189, 256] -> copy vùng [0:178, :]
slices = [slice(0, s) for s in shape]
new_tensor[tuple(slices)] = tensor
return new_tensor, True
def recursive_scan_and_fix(data, path=""):
"""
Hàm đệ quy quét qua mọi kiểu dữ liệu (Dict, List, Tensor)
"""
if isinstance(data, dict):
new_dict = {}
for k, v in data.items():
# Gọi đệ quy cho từng phần tử con
new_dict[k] = recursive_scan_and_fix(v, path + f".{k}")
return new_dict
elif isinstance(data, list):
# Nếu là list (ví dụ danh sách layer), duyệt qua từng cái
return [recursive_scan_and_fix(item, path + f"[{i}]") for i, item in enumerate(data)]
elif isinstance(data, torch.Tensor) or isinstance(data, nn.Parameter):
# Nếu là Tensor hoặc Parameter -> Kiểm tra và sửa
is_param = isinstance(data, nn.Parameter)
tensor_data = data.data if is_param else data
new_tensor, modified = extend_tensor_logic(tensor_data)
if modified:
print(f"✅ Đã sửa: {path}")
print(f" Shape cũ: {tuple(data.shape)} -> Mới: {tuple(new_tensor.shape)}")
# Nếu gốc là Parameter thì bọc lại thành Parameter
if is_param:
return nn.Parameter(new_tensor)
return new_tensor
else:
return data
else:
# Các kiểu dữ liệu khác (int, string, float...) giữ nguyên
return data
# --- MAIN EXECUTION ---
try:
print(f"⏳ Đang load checkpoint: {INPUT_PATH}")
checkpoint = torch.load(INPUT_PATH, map_location='cpu')
print(f"\n🚀 Đang quét toàn bộ cấu trúc để tìm dimension {OLD_DIM}...")
# Bắt đầu quét từ gốc
new_checkpoint = recursive_scan_and_fix(checkpoint, "root")
print("\n💾 Đang lưu file mới...")
torch.save(new_checkpoint, OUTPUT_PATH)
print(f"🎉 HOÀN TẤT! File đã lưu tại: {OUTPUT_PATH}")
print(f"👉 Code này có thể áp dụng cho MỌI MODEL có tensor kích thước {OLD_DIM}.")
except Exception as e:
print(f"❌ LỖI: {e}")