timagonch committed on
Commit
8cf9caf
·
1 Parent(s): 20a4b09

Log predictions to HF dataset repo via CommitScheduler

Browse files
Files changed (1) hide show
  1. app.py +43 -1
app.py CHANGED
@@ -3,6 +3,7 @@ app.py — Algospeak Classifier demo
3
 
4
  Streamlit UI for the dual BERTweet model.
5
  Type a social media post and see the predicted class + confidence scores.
 
6
  """
7
 
8
  import sys
@@ -10,18 +11,24 @@ from pathlib import Path
10
 
11
  sys.path.insert(0, str(Path(__file__).parent / "poc" / "src"))
12
 
 
13
  import yaml
14
  import torch
15
  import numpy as np
16
  import emoji
17
  import streamlit as st
 
18
  from transformers import AutoTokenizer
19
- from huggingface_hub import hf_hub_download
20
 
21
  from inference import load_unsupervised_encoder, classify_text
22
 
23
  BASE_DIR = Path(__file__).parent
24
  MODEL_REPO = "timagonch/algospeak-classifier-model"
 
 
 
 
25
 
26
  CLASS_COLORS = {
27
  "Allowed": "green",
@@ -46,6 +53,39 @@ def load_model():
46
  return encoder, prototypes, tokenizer, cfg, device
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # ─────────────────────────────────────────────────────────────────────
50
  # UI
51
  # ─────────────────────────────────────────────────────────────────────
@@ -68,3 +108,5 @@ if st.button("Classify", type="primary") and text.strip():
68
  st.write("**Confidence scores:**")
69
  for name, score in sorted(result["scores"].items(), key=lambda x: -x[1]):
70
  st.progress(float(score), text=f"{name}: {score:.1%}")
 
 
 
3
 
4
  Streamlit UI for the dual BERTweet model.
5
  Type a social media post and see the predicted class + confidence scores.
6
+ Predictions are logged to a private HF dataset repo via CommitScheduler.
7
  """
8
 
9
  import sys
 
11
 
12
  sys.path.insert(0, str(Path(__file__).parent / "poc" / "src"))
13
 
14
+ import csv
15
  import yaml
16
  import torch
17
  import numpy as np
18
  import emoji
19
  import streamlit as st
20
+ from datetime import datetime
21
  from transformers import AutoTokenizer
22
+ from huggingface_hub import hf_hub_download, CommitScheduler
23
 
24
  from inference import load_unsupervised_encoder, classify_text
25
 
26
  BASE_DIR = Path(__file__).parent
27
  MODEL_REPO = "timagonch/algospeak-classifier-model"
28
+ LOG_REPO = "timagonch/algospeak-logs"
29
+ LOG_DIR = BASE_DIR / "logs"
30
+ LOG_FILE = LOG_DIR / "predictions.csv"
31
+ LOG_COLS = ["text", "predicted_label", "score_allowed", "score_offensive", "score_mature", "score_algospeak", "timestamp"]
32
 
33
  CLASS_COLORS = {
34
  "Allowed": "green",
 
53
  return encoder, prototypes, tokenizer, cfg, device
54
 
55
 
56
@st.cache_resource
def get_scheduler():
    """Build the scheduler that pushes the local log folder to the Hub.

    Creates ``LOG_DIR`` if needed and returns a ``CommitScheduler`` that
    uploads its contents to the ``logs`` path of the ``LOG_REPO`` dataset
    repo. The function is decorated with ``st.cache_resource`` so the
    scheduler object is meant to be built once and reused.

    Returns:
        The ``CommitScheduler`` instance; callers use its ``lock`` while
        writing into ``LOG_DIR``.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    return CommitScheduler(
        repo_id=LOG_REPO,
        repo_type="dataset",
        folder_path=LOG_DIR,
        path_in_repo="logs",
        every=5,
        # The logs hold user-submitted text and the app describes the repo
        # as private. Without this flag a not-yet-existing repo would be
        # created public. It has no effect when the repo already exists.
        private=True,
    )
66
+
67
+
68
def log_prediction(text, result):
    """Append one classification result as a row of the CSV prediction log.

    Args:
        text: The text that was classified; stored verbatim in the ``text``
            column.
        result: Mapping with a ``"predicted_label"`` entry and a ``"scores"``
            mapping keyed by ``"Allowed"``, ``"Offensive Language"``,
            ``"Mature Content"`` and ``"Algospeak"``.

    Raises:
        KeyError: If ``result`` lacks the label or any of the four scores.
    """
    # Function-scope import: the module itself only imports ``datetime``.
    from datetime import timezone

    scheduler = get_scheduler()
    scores = result["scores"]

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an aware
    # UTC time and drop the tzinfo so the stored ISO string keeps the exact
    # same (offset-less) format as before.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

    row = {
        "text": text,
        "predicted_label": result["predicted_label"],
        "score_allowed": round(scores["Allowed"], 4),
        "score_offensive": round(scores["Offensive Language"], 4),
        "score_mature": round(scores["Mature Content"], 4),
        "score_algospeak": round(scores["Algospeak"], 4),
        "timestamp": now_utc.isoformat(),
    }

    # Hold the scheduler's lock so a background upload never sees a
    # half-written row.
    with scheduler.lock:
        with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=LOG_COLS)
            # Append mode starts at end-of-file, so offset 0 means the file
            # is new *or* exists but is empty; both cases need the header.
            if f.tell() == 0:
                writer.writeheader()
            writer.writerow(row)
87
+
88
+
89
  # ─────────────────────────────────────────────────────────────────────
90
  # UI
91
  # ─────────────────────────────────────────────────────────────────────
 
108
  st.write("**Confidence scores:**")
109
  for name, score in sorted(result["scores"].items(), key=lambda x: -x[1]):
110
  st.progress(float(score), text=f"{name}: {score:.1%}")
111
+
112
+ log_prediction(text, result)