# CSATv2 / tar2bin.py
# sosigikiller's picture
# initial push
# 9cf6c45
# raw
# history blame
# 776 Bytes
import torch
from collections import OrderedDict
# Input checkpoint (relative to the current working directory).
ckpt_path = "./CSAT_RCKD.pth.tar"
# Output file. The name is defined once so the file that is written and the
# file that is reported in the final message can never disagree.
out_path = "CSAT_RCKD.bin"


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Checkpoints saved from a wrapped model (e.g. ``DataParallel``) may carry
    a ``"module."`` prefix on every parameter name; this removes it so the
    weights can be loaded into the bare model.

    Args:
        state_dict: Mapping of parameter names to values. It is not modified.
        prefix: Key prefix to strip. Only one leading occurrence is removed;
            keys that do not start with it are kept unchanged.

    Returns:
        A new ``OrderedDict`` with the same values in the same order.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def main():
    """Convert the checkpoint at ``ckpt_path`` into a plain state dict file."""
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. Only run this on checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # 1) Extract the state_dict.
    #    - Try the usual {'state_dict': ...} layout first;
    #    - otherwise fall back to treating the whole checkpoint as the
    #      state_dict.
    state_dict = ckpt.get("state_dict", ckpt)

    # 2) Drop the 'module.' key prefix that may be present if the model was
    #    wrapped in DataParallel when it was saved.
    new_state_dict = strip_module_prefix(state_dict)

    # 3) Save the cleaned weights and report the path that was actually
    #    written (the old message named a different file).
    torch.save(new_state_dict, out_path)
    print(f"saved to {out_path}")


if __name__ == "__main__":
    main()