"""Convert a trained CSAT checkpoint into a Hugging Face-style ``pytorch_model.bin``.

Steps:
  1) load the original training checkpoint,
  2) pull the state dict out of it (the container format varies),
  3) strip the ``module.`` prefix added by DistributedDataParallel / DataParallel,
  4) prepend ``backbone.`` so the keys match the HF wrapper module,
  5) save the result for use from Hugging Face.
"""
from collections.abc import Mapping

import torch

SRC_CHECKPOINT = "./CSAT_v2_ImageNet.pth.tar"
DST_WEIGHTS = "pytorch_model.bin"


def extract_state_dict(ckpt):
    """Return the parameter mapping stored in a loaded checkpoint.

    Lookup order is ``ckpt["state_dict"]``, then ``ckpt["model"]``, then ``ckpt``
    itself. A whole ``nn.Module`` (either as the checkpoint or under one of
    those keys) is converted via ``.state_dict()``.

    Raises:
        TypeError: if the checkpoint is neither a mapping nor an ``nn.Module``.
    """
    if isinstance(ckpt, torch.nn.Module):
        return ckpt.state_dict()
    if not isinstance(ckpt, Mapping):
        raise TypeError(f"unsupported checkpoint type: {type(ckpt).__name__}")
    for key in ("state_dict", "model"):
        if key in ckpt:
            inner = ckpt[key]
            # Some training scripts pickle the full model rather than its weights.
            return inner.state_dict() if isinstance(inner, torch.nn.Module) else inner
    return ckpt


def strip_prefix(state_dict, prefix="module."):
    """Return a new dict with ``prefix`` removed from every key that starts with it."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def add_prefix(state_dict, prefix="backbone."):
    """Return a new dict with ``prefix`` prepended to every key."""
    return {f"{prefix}{k}": v for k, v in state_dict.items()}


def convert_checkpoint(src=SRC_CHECKPOINT, dst=DST_WEIGHTS, *, prefix="backbone."):
    """Load ``src``, normalise its keys, and save an HF-style state dict to ``dst``.

    Returns the saved state dict.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Kept because the training checkpoint may contain non-tensor
    # objects; only run this on checkpoints from a trusted source.
    ckpt = torch.load(src, map_location="cpu", weights_only=False)
    state_dict = strip_prefix(extract_state_dict(ckpt), "module.")
    wrapped_sd = add_prefix(state_dict, prefix)
    torch.save(wrapped_sd, dst)
    return wrapped_sd


def main():
    """Run the conversion with the default paths and report where it wrote."""
    convert_checkpoint(SRC_CHECKPOINT, DST_WEIGHTS)
    print(f"saved HF-style weights to {DST_WEIGHTS}")


if __name__ == "__main__":
    main()