Update app_flash1.py
Browse files- app_flash1.py +1 -1
app_flash1.py
CHANGED
|
@@ -63,7 +63,7 @@ def build_encoder(model_name="gpt2", max_length=128):
|
|
| 63 |
def push_flashpack_model_to_hf(model, hf_repo):
|
| 64 |
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 65 |
repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
|
| 66 |
-
model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
|
| 67 |
with open(os.path.join(tmp_dir, "README.md"), "w") as f:
|
| 68 |
f.write("# FlashPack Model\nTrained locally and pushed to HF.")
|
| 69 |
repo.push_to_hub()
|
|
|
|
| 63 |
def push_flashpack_model_to_hf(model, hf_repo):
|
| 64 |
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 65 |
repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
|
| 66 |
+
model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"),target_dtype=torch.float32)
|
| 67 |
with open(os.path.join(tmp_dir, "README.md"), "w") as f:
|
| 68 |
f.write("# FlashPack Model\nTrained locally and pushed to HF.")
|
| 69 |
repo.push_to_hub()
|