AMR-KELEG commited on
Commit
465af14
·
1 Parent(s): 24cf6c5

Store the predictions

Browse files
Files changed (2) hide show
  1. app.py +5 -1
  2. utils.py +28 -0
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
9
  from constants import DIALECTS_WITH_LABELS
10
  from inspect import getmembers, isfunction
11
  import eval_utils
 
12
  import numpy as np
13
  from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
14
 
@@ -42,7 +43,10 @@ with tab2:
42
  for sentence in tqdm(sentences)
43
  ]
44
 
45
- # TODO: Store the predictions in a private dataset
 
 
 
46
 
47
  # Evaluate the model
48
  accuracy_scores = {}
 
9
  from constants import DIALECTS_WITH_LABELS
10
  from inspect import getmembers, isfunction
11
  import eval_utils
12
+ import utils
13
  import numpy as np
14
  from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
15
 
 
43
  for sentence in tqdm(sentences)
44
  ]
45
 
46
+ # Store the predictions in a private dataset
47
+ utils.upload_predictions(
48
+ os.environ["PREDICTIONS_DATASET_NAME"], predictions, model_name
49
+ )
50
 
51
  # Evaluate the model
52
  accuracy_scores = {}
utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import time
4
+ from huggingface_hub import HfApi
5
+
6
+
7
+ def current_seconds_time():
8
+ return round(time.time())
9
+
10
+
11
+ def upload_predictions(repo_id, predictions, model_name):
12
+ api = HfApi()
13
+
14
+ predictions_filename = (
15
+ f"predictions_{current_seconds_time()}_{re.sub('/', '_', model_name)}.json"
16
+ )
17
+ predictions_object = {"model_name": model_name, "predictions": predictions}
18
+
19
+ with open(predictions_filename, "w") as f:
20
+ json.dump(predictions_object, f)
21
+
22
+ future = api.upload_file(
23
+ path_or_fileobj=predictions_filename,
24
+ path_in_repo=predictions_filename,
25
+ repo_id=repo_id,
26
+ repo_type="dataset",
27
+ run_as_future=True,
28
+ )