File size: 781 Bytes
9cf6c45
 
9ec0c69
 
9cf6c45
9ec0c69
 
 
 
 
 
 
9cf6c45
9ec0c69
 
9cf6c45
 
9ec0c69
 
9cf6c45
9ec0c69
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch

# Checkpoint produced by the original training run, and the file name the
# Hugging Face wrapper loads its weights from.
CKPT_PATH = "./CSAT_v2_ImageNet.pth.tar"
OUTPUT_PATH = "pytorch_model.bin"


def _extract_state_dict(ckpt):
    """Return the parameter mapping held by a loaded checkpoint.

    Supported layouts: a dict with a ``"state_dict"`` entry, a dict with a
    ``"model"`` entry, a bare state dict, or a whole pickled ``nn.Module``
    (either at the top level or stored under one of those keys).
    """
    if isinstance(ckpt, torch.nn.Module):
        # `"state_dict" in <Module>` raises TypeError; ask the module instead.
        return ckpt.state_dict()
    if "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    elif "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt
    # Some training scripts store the module object itself under the key.
    if isinstance(state_dict, torch.nn.Module):
        state_dict = state_dict.state_dict()
    return state_dict


def _strip_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with one leading *prefix* removed from each key.

    Keys that do not start with *prefix* are kept unchanged; key order is kept.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _add_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with *prefix* prepended to every key."""
    return {f"{prefix}{key}": value for key, value in state_dict.items()}


def main(src=CKPT_PATH, dst=OUTPUT_PATH):
    """Convert the training checkpoint at *src* into HF-style weights at *dst*.

    Steps:
      1. Load the checkpoint onto the CPU.
      2. Pull out the state dict, whatever the checkpoint layout.
      3. Drop the ``module.`` prefix added by DataParallel /
         DistributedDataParallel.
      4. Prepend ``backbone.`` so the keys match the HF wrapper. This
         presumably holds the original model under a ``backbone`` attribute;
         confirm against the wrapper class.
      5. Save the result with ``torch.save`` and report the output path.
    """
    # SECURITY NOTE: weights_only=False fully unpickles the file and can run
    # arbitrary code. Only convert checkpoints from a trusted source.
    ckpt = torch.load(src, map_location="cpu", weights_only=False)
    state_dict = _extract_state_dict(ckpt)
    new_sd = _strip_prefix(state_dict, "module.")
    wrapped_sd = _add_prefix(new_sd, "backbone.")
    torch.save(wrapped_sd, dst)
    print(f"saved HF-style weights to {dst}")


if __name__ == "__main__":
    main()