TymaaHammouda commited on
Commit
3e9dbbe
·
1 Parent(s): 07296dc

Add snapshot_download

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI
2
  import torch
3
  import pickle
4
- from huggingface_hub import hf_hub_download
5
  from Nested.nn.BertSeqTagger import BertSeqTagger
6
  from transformers import AutoTokenizer, AutoModel
7
  import inspect
@@ -23,10 +23,12 @@ pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
23
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
24
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
25
 
26
- checkpoint_path = hf_hub_download(
27
- repo_id="SinaLab/Nested",
28
- filename="checkpoints/checkpoint_2.pt"
29
- )
 
 
30
  print("checkpoint_path : ", checkpoint_path)
31
 
32
  args_path = hf_hub_download(
 
1
  from fastapi import FastAPI
2
  import torch
3
  import pickle
4
+ from huggingface_hub import hf_hub_download, snapshot_download
5
  from Nested.nn.BertSeqTagger import BertSeqTagger
6
  from transformers import AutoTokenizer, AutoModel
7
  import inspect
 
23
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
24
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
25
 
26
+ # checkpoint_path = hf_hub_download(
27
+ # repo_id="SinaLab/Nested",
28
+ # filename="checkpoints/checkpoint_2.pt"
29
+ # )
30
+
31
+ checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
32
  print("checkpoint_path : ", checkpoint_path)
33
 
34
  args_path = hf_hub_download(