# CSATv2 / tar2bin.py
# Uploaded by sosigikiller — original commit 9ec0c69 ("initial push"), 781 bytes.
import torch
# Convert the original CSATv2 ImageNet training checkpoint (.pth.tar) into a
# flat state dict saved as ``pytorch_model.bin``, with every key prefixed by
# ``backbone.`` so it can be loaded into the Hugging Face wrapper model.

SRC_CHECKPOINT = "./CSAT_v2_ImageNet.pth.tar"
DST_WEIGHTS = "pytorch_model.bin"
DDP_PREFIX = "module."  # added to keys by DistributedDataParallel / DataParallel
HF_PREFIX = "backbone."  # attribute name of the model inside the HF wrapper


def extract_state_dict(ckpt):
    """Return the parameter mapping stored in a loaded checkpoint object.

    The checkpoint format is detected in this order:
      1. a dict with a ``"state_dict"`` key -> that value,
      2. a dict with a ``"model"`` key      -> that value,
      3. any other dict                     -> the dict itself (already a state dict).
    If the selected object (or the checkpoint itself) is an ``nn.Module``, its
    ``state_dict()`` is returned instead.

    Raises:
        TypeError: if ``ckpt`` is neither a dict nor an ``nn.Module``.
    """
    if isinstance(ckpt, torch.nn.Module):
        return ckpt.state_dict()
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"unsupported checkpoint type {type(ckpt).__name__!r}; "
            "expected a dict or torch.nn.Module"
        )
    if "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    elif "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt
    # Some checkpoints pickle the whole model object under "model".
    if isinstance(state_dict, torch.nn.Module):
        state_dict = state_dict.state_dict()
    return state_dict


def strip_module_prefix(state_dict, prefix=DDP_PREFIX):
    """Return a new dict with a single leading ``prefix`` removed from each key.

    Keys without the prefix are kept unchanged; insertion order is preserved.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def add_prefix(state_dict, prefix=HF_PREFIX):
    """Return a new dict with ``prefix`` prepended to every key."""
    return {f"{prefix}{key}": value for key, value in state_dict.items()}


def convert(src=SRC_CHECKPOINT, dst=DST_WEIGHTS, prefix=HF_PREFIX):
    """Load the checkpoint at ``src``, re-key it for the HF wrapper, save to ``dst``.

    Returns:
        The re-keyed state dict that was written to ``dst``.
    """
    # NOTE(security): weights_only=False runs the full unpickler, which can
    # execute arbitrary code. Only run this on checkpoints from a trusted source.
    ckpt = torch.load(src, map_location="cpu", weights_only=False)
    state_dict = strip_module_prefix(extract_state_dict(ckpt))
    wrapped_sd = add_prefix(state_dict, prefix)
    torch.save(wrapped_sd, dst)
    return wrapped_sd


def main():
    """Run the conversion with the default paths and report the output file."""
    convert(SRC_CHECKPOINT, DST_WEIGHTS)
    print(f"saved HF-style weights to {DST_WEIGHTS}")


if __name__ == "__main__":
    main()