algo2217 committed on
Commit
f73efa3
·
verified ·
1 Parent(s): 60df24b

Upload folder using huggingface_hub

Browse files
checkpoints/checkpoint-100.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bdddf11f6c354f0fe93e9e0304366e3c753926157bc0ef8d6b726d850adcd139
3
- size 2480362
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fafa99fb361dae2d224a1307eb27fd7362cfb29144c5b369bf9dae370563080
3
+ size 2478085
checkpoints/checkpoint-25.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b47a7288186dd7e81fc0b4dff55ea628b4305471cf642159bb5fbefdfdc6c927
3
- size 2480328
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:200f2670bbfd05c752c6af8e5e3835b47ff11e4c7286203f0bbfaa0c6fbead11
3
+ size 2478049
checkpoints/checkpoint-50.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:404c5ce1e19ed45746ed0d35635e4f1081dd1917fb12cae1b1a63ca7208bd0c7
3
- size 2480328
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:beb0e997530cf6327992063dab60027b1b74c57fcb2cfaefdfdbabcd7d1068af
3
+ size 2478049
checkpoints/checkpoint-75.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3beb8688ec66697e58f601b78412a8d38a05ff7f8fbf8cb0f9e91a2e2a51989e
3
- size 2480328
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4a79008dd26635ed3276e069048391ff2962fa834d33e3649c9d2653fe5893a
3
+ size 2478049
convert_checkpoints.py CHANGED
@@ -10,16 +10,27 @@ def convert_checkpoint_format(checkpoint_path):
10
  # Load the checkpoint
11
  checkpoint = torch.load(checkpoint_path, map_location='cpu')
12
 
13
- # Extract just the model state dict
14
- if 'model' in checkpoint:
15
- model_state_dict = checkpoint['model']
16
- print(f"Converting {checkpoint_path}: nested -> direct format")
17
-
18
- # Save back in direct format
19
- torch.save(model_state_dict, checkpoint_path)
20
- print(f"Updated: {checkpoint_path}")
21
- else:
22
- print(f"Already in direct format: {checkpoint_path}")
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def convert_all_checkpoints():
25
  """Convert all checkpoint files in the current directory."""
 
10
  # Load the checkpoint
11
  checkpoint = torch.load(checkpoint_path, map_location='cpu')
12
 
13
+
14
+ model_state_dict = checkpoint
15
+ print(f"Converting {checkpoint_path}: nested -> direct format")
16
+
17
+ # Create new state dict with flattened keys
18
+ new_state_dict = {}
19
+ for key, value in model_state_dict.items():
20
+ if key.startswith('model.'):
21
+ # Remove 'model.' prefix
22
+ new_key = key[6:] # Remove 'model.' (6 characters)
23
+ print(f"Converting key: '{key}' -> '{new_key}'")
24
+ new_state_dict[new_key] = value
25
+ else:
26
+ # Keep keys that don't start with 'model.'
27
+ print(f"Keeping key: '{key}'")
28
+ new_state_dict[key] = value
29
+
30
+ # Save back in direct format
31
+ torch.save(new_state_dict, checkpoint_path)
32
+ print(f"Updated: {checkpoint_path}")
33
+
34
 
35
  def convert_all_checkpoints():
36
  """Convert all checkpoint files in the current directory."""