sosigikiller committed on
Commit
9ec0c69
·
1 Parent(s): 36340c3

initial push

Browse files
Files changed (3) hide show
  1. convert_and_push.py +0 -0
  2. pytorch_model.bin +2 -2
  3. tar2bin.py +19 -16
convert_and_push.py ADDED
File without changes
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:104253a41517faf75704853c207082c20b31264a7e9566a9a1ca22a9e088a729
3
- size 44535575
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3456c4d1c386b6fc57670f63ae1b8cfaccc3157528c4db99ed8fb8bf048959d4
3
+ size 44539159
tar2bin.py CHANGED
@@ -1,23 +1,26 @@
1
  import torch
2
- from collections import OrderedDict
3
 
 
 
4
 
5
- ckpt_path = "./CSAT_RCKD.pth.tar"
6
- ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
 
 
 
 
 
7
 
8
- # 1) state_dict 꺼내기
9
- # - 보통 {'state_dict': ...} 형태니까 먼저 이걸 시도하고,
10
- # - 아니면 그냥 ckpt 전체가 state_dict인 경우도 있어서 fallback
11
- state_dict = ckpt.get("state_dict", ckpt)
12
- # 2) DataParallel이면 key 앞에 'module.' 붙어있을 수 있어서 제거
13
- new_state_dict = OrderedDict()
14
  for k, v in state_dict.items():
15
  if k.startswith("module."):
16
- new_k = k[len("module."):]
17
- else:
18
- new_k = k
19
- new_state_dict[new_k] = v
20
 
21
- # 3) HuggingFace 관례대로 파일명 저장
22
- torch.save(new_state_dict, "CSAT_RCKD.bin")
23
- print("saved to pytorch_model.bin")
 
 
 
 
1
  import torch
 
2
 
3
+ # 1) 원래 학습된 checkpoint 불러오기
4
+ ckpt = torch.load("./CSAT_v2_ImageNet.pth.tar", map_location="cpu", weights_only=False)
5
 
6
+ # 2) state_dict 꺼내기 (포맷에 따라 분기)
7
+ if "state_dict" in ckpt:
8
+ state_dict = ckpt["state_dict"]
9
+ elif "model" in ckpt:
10
+ state_dict = ckpt["model"]
11
+ else:
12
+ state_dict = ckpt
13
 
14
+ # 3) Distributed / DataParallel인 경우 "module." 제거
15
+ new_sd = {}
 
 
 
 
16
  for k, v in state_dict.items():
17
  if k.startswith("module."):
18
+ k = k[len("module."):]
19
+ new_sd[k] = v
 
 
20
 
21
+ # 4) HF 래퍼용으로 "backbone." prefix 붙이기
22
+ wrapped_sd = {f"backbone.{k}": v for k, v in new_sd.items()}
23
+
24
+ # 5) HF์—์„œ ์‚ฌ์šฉํ•  ํŒŒ์ผ๋กœ ์ €์žฅ
25
+ torch.save(wrapped_sd, "pytorch_model.bin")
26
+ print("saved HF-style weights to pytorch_model.bin")