sosigikiller commited on
Commit
9ec0c69
·
1 Parent(s): 36340c3

initial push

Browse files
Files changed (3) hide show
  1. convert_and_push.py +0 -0
  2. pytorch_model.bin +2 -2
  3. tar2bin.py +19 -16
convert_and_push.py ADDED
File without changes
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:104253a41517faf75704853c207082c20b31264a7e9566a9a1ca22a9e088a729
3
- size 44535575
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3456c4d1c386b6fc57670f63ae1b8cfaccc3157528c4db99ed8fb8bf048959d4
3
+ size 44539159
tar2bin.py CHANGED
@@ -1,23 +1,26 @@
1
  import torch
2
- from collections import OrderedDict
3
 
 
 
4
 
5
- ckpt_path = "./CSAT_RCKD.pth.tar"
6
- ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
 
 
 
 
 
7
 
8
- # 1) state_dict 꺼내기
9
- # - 보통 {'state_dict': ...} 형태니까 먼저 이걸 시도하고,
10
- # - 아니면 그냥 ckpt 전체가 state_dict인 경우도 있어서 fallback
11
- state_dict = ckpt.get("state_dict", ckpt)
12
- # 2) DataParallel 썼으면 key 앞에 'module.' 붙어있을 수 있어서 제거
13
- new_state_dict = OrderedDict()
14
  for k, v in state_dict.items():
15
  if k.startswith("module."):
16
- new_k = k[len("module."):]
17
- else:
18
- new_k = k
19
- new_state_dict[new_k] = v
20
 
21
- # 3) HuggingFace 관례대파일명 저장
22
- torch.save(new_state_dict, "CSAT_RCKD.bin")
23
- print("saved to pytorch_model.bin")
 
 
 
 
1
  import torch
 
2
 
3
+ # 1) 원래 학습된 checkpoint 불러오기
4
+ ckpt = torch.load("./CSAT_v2_ImageNet.pth.tar", map_location="cpu", weights_only=False)
5
 
6
+ # 2) state_dict 꺼내기 (포맷에 따라 분기)
7
+ if "state_dict" in ckpt:
8
+ state_dict = ckpt["state_dict"]
9
+ elif "model" in ckpt:
10
+ state_dict = ckpt["model"]
11
+ else:
12
+ state_dict = ckpt
13
 
14
+ # 3) Distributed / DataParallel인 경우 "module." 제거
15
+ new_sd = {}
 
 
 
 
16
  for k, v in state_dict.items():
17
  if k.startswith("module."):
18
+ k = k[len("module."):]
19
+ new_sd[k] = v
 
 
20
 
21
+ # 4) HF 래퍼용으"backbone." prefix 붙이기
22
+ wrapped_sd = {f"backbone.{k}": v for k, v in new_sd.items()}
23
+
24
+ # 5) HF에서 사용할 파일로 저장
25
+ torch.save(wrapped_sd, "pytorch_model.bin")
26
+ print("saved HF-style weights to pytorch_model.bin")