| import torch | |
| from collections import OrderedDict | |
| ckpt_path = "./CSAT_RCKD.pth.tar" | |
| ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) | |
| # 1) state_dict ๊บผ๋ด๊ธฐ | |
| # - ๋ณดํต {'state_dict': ...} ํํ๋๊น ๋จผ์ ์ด๊ฑธ ์๋ํ๊ณ , | |
| # - ์๋๋ฉด ๊ทธ๋ฅ ckpt ์ ์ฒด๊ฐ state_dict์ธ ๊ฒฝ์ฐ๋ ์์ด์ fallback | |
| state_dict = ckpt.get("state_dict", ckpt) | |
| # 2) DataParallel ์ผ์ผ๋ฉด key ์์ 'module.' ๋ถ์ด์์ ์ ์์ด์ ์ ๊ฑฐ | |
| new_state_dict = OrderedDict() | |
| for k, v in state_dict.items(): | |
| if k.startswith("module."): | |
| new_k = k[len("module."):] | |
| else: | |
| new_k = k | |
| new_state_dict[new_k] = v | |
| # 3) HuggingFace ๊ด๋ก๋๋ก ํ์ผ๋ช ์ ์ฅ | |
| torch.save(new_state_dict, "CSAT_RCKD.bin") | |
| print("saved to pytorch_model.bin") |