rahul7star commited on
Commit
4c22dd6
·
verified ·
1 Parent(s): 5ee9a29

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +1 -1
app_flash1.py CHANGED
@@ -63,7 +63,7 @@ def build_encoder(model_name="gpt2", max_length=128):
63
  def push_flashpack_model_to_hf(model, hf_repo):
64
  with tempfile.TemporaryDirectory() as tmp_dir:
65
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
66
- model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
67
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
68
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
69
  repo.push_to_hub()
 
63
  def push_flashpack_model_to_hf(model, hf_repo):
64
  with tempfile.TemporaryDirectory() as tmp_dir:
65
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
66
+ model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"),target_dtype=torch.float32)
67
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
68
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
69
  repo.push_to_hub()