DelaliScratchwerk commited on
Commit
e8485c8
Β·
verified Β·
1 Parent(s): 5e572ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -20
app.py CHANGED
@@ -1,42 +1,79 @@
1
  import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
3
  from huggingface_hub import hf_hub_download
 
4
 
5
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
6
 
7
- # Try to load labels from the model repo; if missing, use local labels.json
8
  try:
9
  labels_path = hf_hub_download(MODEL_ID, "labels.json")
10
  LABELS = json.load(open(labels_path))
11
  except Exception:
12
  LABELS = json.load(open("labels.json"))
13
 
 
 
 
 
 
14
  model = SetFitModel.from_pretrained(MODEL_ID)
15
 
 
 
 
 
 
 
 
 
 
16
  def predict(txt: str):
17
  if not txt.strip():
18
- return "β€”", {"error": "Please paste some text."}
19
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
 
20
  if probs.size != len(LABELS):
21
- return "β€”", {"error": f"label mismatch: model has {probs.size}, labels.json has {len(LABELS)}"}
 
22
  order = np.argsort(probs)[::-1]
23
- top = LABELS[int(order[0])]
24
- return top, {LABELS[int(i)]: float(probs[int(i)]) for i in order}
25
-
26
- demo = gr.Interface(
27
- fn=predict,
28
- inputs=gr.Textbox(lines=8, label="Paste text"),
29
- outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
30
- title="Text β†’ Time Period (SetFit)",
31
- examples=[
32
- "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
33
- "Sputnik launched and kicked off the space race.",
34
- "MySpace was the most popular social network for a while.",
35
- "TikTok creators exploded in popularity.",
36
- ],
37
- cache_examples=False,
38
- allow_flagging="never",
39
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  if __name__ == "__main__":
42
  demo.launch()
 
1
  import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
3
  from huggingface_hub import hf_hub_download
4
+ from evidence import extract_evidence
5
 
6
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
7
 
8
+ # Load labels: try from model repo, else local labels.json in the Space
9
  try:
10
  labels_path = hf_hub_download(MODEL_ID, "labels.json")
11
  LABELS = json.load(open(labels_path))
12
  except Exception:
13
  LABELS = json.load(open("labels.json"))
14
 
15
+ # thresholds – tweak later with validation
16
+ TOP_K = 3
17
+ UNCERTAINTY_THRESHOLD = 0.42 # if top1 prob below this β†’ "uncertain"
18
+ MARGIN_THRESHOLD = 0.08 # or if (top1 - top2) < this β†’ "uncertain"
19
+
20
  model = SetFitModel.from_pretrained(MODEL_ID)
21
 
22
+ def format_evidence(ev):
23
+ parts = []
24
+ if ev.get("years"):
25
+ parts.append("**Years found:** " + ", ".join(ev["years"]))
26
+ if ev.get("keyword_hits"):
27
+ for b, ks in ev["keyword_hits"].items():
28
+ parts.append(f"**{b}:** " + ", ".join(ks))
29
+ return "\n\n".join(parts) if parts else "_No explicit time clues found._"
30
+
31
  def predict(txt: str):
32
  if not txt.strip():
33
+ return "β€”", "Paste some text.", {}
34
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
35
+
36
  if probs.size != len(LABELS):
37
+ return "β€”", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}
38
+
39
  order = np.argsort(probs)[::-1]
40
+ top1, top2 = probs[order[0]], probs[order[1]] if probs.size > 1 else 0.0
41
+ ev = extract_evidence(txt)
42
+
43
+ # uncertain mode
44
+ if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
45
+ topk = [{ "label": LABELS[i], "score": float(probs[i]) } for i in order[:TOP_K]]
46
+ md = "**Uncertain** β€” here are the top candidates:\n" + "\n".join(
47
+ [f"- **{d['label']}**: {d['score']:.3f}" for d in topk]
48
+ )
49
+ return "uncertain", md + "\n\n" + format_evidence(ev), {LABELS[i]: float(probs[i]) for i in order}
50
+
51
+ # confident
52
+ best = LABELS[order[0]]
53
+ md = f"**Reasoning hints**\n\n" + format_evidence(ev)
54
+ return best, md, {LABELS[i]: float(probs[i]) for i in order}
55
+
56
+ with gr.Blocks(title="Text β†’ Time Period (SetFit)") as demo:
57
+ gr.Markdown("# Text β†’ Time Period (SetFit)")
58
+ with gr.Row():
59
+ text = gr.Textbox(lines=8, label="Paste text")
60
+ with gr.Column():
61
+ pred = gr.Label(label="Predicted")
62
+ reason = gr.Markdown(label="Evidence")
63
+ scores = gr.JSON(label="Scores")
64
+
65
+ btn = gr.Button("Submit", variant="primary")
66
+ btn.click(predict, inputs=text, outputs=[pred, reason, scores])
67
+
68
+ gr.Examples(
69
+ examples=[
70
+ "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
71
+ "Sputnik launched and kicked off the space race.",
72
+ "MySpace was the most popular social network for a while.",
73
+ "Creators blew up on TikTok; companies rolled out ChatGPT-powered tools.",
74
+ ],
75
+ inputs=text
76
+ )
77
 
78
  if __name__ == "__main__":
79
  demo.launch()