Upload folder using huggingface_hub
Browse files
- checkpoints/checkpoint-100.pt +2 -2
- checkpoints/checkpoint-25.pt +2 -2
- checkpoints/checkpoint-50.pt +2 -2
- checkpoints/checkpoint-75.pt +2 -2
- convert_checkpoints.py +21 -10
checkpoints/checkpoint-100.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3fafa99fb361dae2d224a1307eb27fd7362cfb29144c5b369bf9dae370563080
|
| 3 |
+
size 2478085
|
checkpoints/checkpoint-25.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:200f2670bbfd05c752c6af8e5e3835b47ff11e4c7286203f0bbfaa0c6fbead11
|
| 3 |
+
size 2478049
|
checkpoints/checkpoint-50.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:beb0e997530cf6327992063dab60027b1b74c57fcb2cfaefdfdbabcd7d1068af
|
| 3 |
+
size 2478049
|
checkpoints/checkpoint-75.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4a79008dd26635ed3276e069048391ff2962fa834d33e3649c9d2653fe5893a
|
| 3 |
+
size 2478049
|
convert_checkpoints.py
CHANGED
|
@@ -10,16 +10,27 @@ def convert_checkpoint_format(checkpoint_path):
|
|
| 10 |
# Load the checkpoint
|
| 11 |
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def convert_all_checkpoints():
|
| 25 |
"""Convert all checkpoint files in the current directory."""
|
|
|
|
| 10 |
# Load the checkpoint
|
| 11 |
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 12 |
|
| 13 |
+
|
| 14 |
+
model_state_dict = checkpoint
|
| 15 |
+
print(f"Converting {checkpoint_path}: nested -> direct format")
|
| 16 |
+
|
| 17 |
+
# Create new state dict with flattened keys
|
| 18 |
+
new_state_dict = {}
|
| 19 |
+
for key, value in model_state_dict.items():
|
| 20 |
+
if key.startswith('model.'):
|
| 21 |
+
# Remove 'model.' prefix
|
| 22 |
+
new_key = key[6:] # Remove 'model.' (6 characters)
|
| 23 |
+
print(f"Converting key: '{key}' -> '{new_key}'")
|
| 24 |
+
new_state_dict[new_key] = value
|
| 25 |
+
else:
|
| 26 |
+
# Keep keys that don't start with 'model.'
|
| 27 |
+
print(f"Keeping key: '{key}'")
|
| 28 |
+
new_state_dict[key] = value
|
| 29 |
+
|
| 30 |
+
# Save back in direct format
|
| 31 |
+
torch.save(new_state_dict, checkpoint_path)
|
| 32 |
+
print(f"Updated: {checkpoint_path}")
|
| 33 |
+
|
| 34 |
|
| 35 |
def convert_all_checkpoints():
|
| 36 |
"""Convert all checkpoint files in the current directory."""
|