Nit
prune.py CHANGED
@@ -4,6 +4,8 @@ import re
 import torch
 from modeling_jamba import JambaForCausalLM
 
+output_dir = "/home/user/jamba-small"
+
 model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)
 
 def prune_and_copy_additional_layers(original_state_dict):
@@ -37,5 +39,5 @@ def prune_and_copy_additional_layers(original_state_dict):
 pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
 
 print("Saving weights...")
-torch.save(pruned_state_dict,
+torch.save(pruned_state_dict, output_dir)
 print("Done!")
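The hunks above elide the body of prune_and_copy_additional_layers, so for context here is a minimal sketch of what a state-dict pruning helper of this shape might look like. The layer-index cutoff, the key pattern, and the pass-through logic are all illustrative assumptions, not the actual implementation in prune.py:

import re

def prune_and_copy_additional_layers(original_state_dict):
    # Hypothetical sketch: keep only the first few decoder layers and copy
    # everything else (embeddings, final norm, lm_head) through unchanged.
    # The real prune.py body is not shown in this commit.
    layers_to_keep = 4  # assumed cutoff, not taken from the diff
    pruned_state_dict = {}
    for key, tensor in original_state_dict.items():
        match = re.search(r"layers\.(\d+)\.", key)
        if match is None or int(match.group(1)) < layers_to_keep:
            pruned_state_dict[key] = tensor
    return pruned_state_dict

One note on the save step itself: torch.save writes a single pickle file, so despite the variable name output_dir, the path /home/user/jamba-small ends up as an ordinary file, and the pruned weights can later be restored with torch.load("/home/user/jamba-small", map_location="cpu").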