| import torch | |
| # 1) 원래 학습된 checkpoint 불러오기 | |
| ckpt = torch.load("./CSAT_v2_ImageNet.pth.tar", map_location="cpu", weights_only=False) | |
| # 2) state_dict 꺼내기 (포맷에 따라 분기) | |
| if "state_dict" in ckpt: | |
| state_dict = ckpt["state_dict"] | |
| elif "model" in ckpt: | |
| state_dict = ckpt["model"] | |
| else: | |
| state_dict = ckpt | |
| # 3) Distributed / DataParallel인 경우 "module." 제거 | |
| new_sd = {} | |
| for k, v in state_dict.items(): | |
| if k.startswith("module."): | |
| k = k[len("module."):] | |
| new_sd[k] = v | |
| # 4) HF 래퍼용으로 "backbone." prefix 붙이기 | |
| wrapped_sd = {f"backbone.{k}": v for k, v in new_sd.items()} | |
| # 5) HF에서 사용할 파일로 저장 | |
| torch.save(wrapped_sd, "pytorch_model.bin") | |
| print("saved HF-style weights to pytorch_model.bin") |