DelaliScratchwerk committed on
Commit
d6c5a54
·
verified ·
1 Parent(s): 0a06046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -1,21 +1,23 @@
1
- import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
3
  from huggingface_hub import hf_hub_download
4
  from evidence import extract_evidence
5
- import shutil, os, pathlib
6
- CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
7
- shutil.rmtree(CACHE_DIR, ignore_errors=True) # nuke old cached models
8
- pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
9
 
 
 
 
 
 
10
 
11
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
12
 
13
- # ---- thresholds (use your tuned values)
14
  TOP_K = 3
15
- UNCERTAINTY_THRESHOLD = 0.516 # from tune_thresholds.py
16
- MARGIN_THRESHOLD = 0.387 # from tune_thresholds.py
17
 
18
- # ---- load labels (Hub -> local fallback)
19
  try:
20
  labels_path = hf_hub_download(MODEL_ID, "labels.json")
21
  LABELS = json.load(open(labels_path))
@@ -24,19 +26,24 @@ except Exception:
24
 
25
  model = SetFitModel.from_pretrained(MODEL_ID)
26
 
27
- def format_evidence(ev):
28
  parts = []
29
  if ev.get("years"):
30
  parts.append("**Years found:** " + ", ".join(ev["years"]))
31
  if ev.get("keyword_hits"):
32
- for b, ks in ev["keyword_hits"].items():
33
- parts.append(f"**{b}:** " + ", ".join(ks))
 
34
  return "\n\n".join(parts) if parts else "_No explicit time clues found._"
35
 
36
  def predict(txt: str):
37
- if not txt.strip():
 
38
  return "—", "Paste some text.", {}
 
39
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
 
 
40
 
41
  if probs.size != len(LABELS):
42
  return "—", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}
@@ -50,7 +57,7 @@ def predict(txt: str):
50
  if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
51
  topk = [{"label": LABELS[i], "score": float(probs[i])} for i in order[:TOP_K]]
52
  md = "**Uncertain** — top candidates:\n" + "\n".join(
53
- [f"- **{d['label']}**: {d['score']:.3f}" for d in topk]
54
  )
55
  return "uncertain", md + "\n\n" + format_evidence(ev), {LABELS[i]: float(probs[i]) for i in order}
56
 
@@ -69,7 +76,8 @@ with gr.Blocks(title="Text → Time Period (SetFit)") as demo:
69
  scores = gr.JSON(label="Scores")
70
 
71
  btn = gr.Button("Submit", variant="primary")
72
- btn.click(predict, inputs=text, outputs=[pred, reason, scores])
 
73
 
74
  gr.Examples(
75
  examples=[
 
1
+ import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
3
  from huggingface_hub import hf_hub_download
4
  from evidence import extract_evidence
5
+ import os, shutil, pathlib
 
 
 
6
 
7
+ # Optional: only clear cache if you set CLEAR_HF_CACHE=1 in the Space secrets
8
+ if os.getenv("CLEAR_HF_CACHE") == "1":
9
+ CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
10
+ shutil.rmtree(CACHE_DIR, ignore_errors=True)
11
+ pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
12
 
13
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
14
 
15
+ # thresholds (your tuned values)
16
  TOP_K = 3
17
+ UNCERTAINTY_THRESHOLD = 0.516
18
+ MARGIN_THRESHOLD = 0.387
19
 
20
+ # labels (Hub -> local fallback)
21
  try:
22
  labels_path = hf_hub_download(MODEL_ID, "labels.json")
23
  LABELS = json.load(open(labels_path))
 
26
 
27
  model = SetFitModel.from_pretrained(MODEL_ID)
28
 
29
+ def format_evidence(ev: dict) -> str:
30
  parts = []
31
  if ev.get("years"):
32
  parts.append("**Years found:** " + ", ".join(ev["years"]))
33
  if ev.get("keyword_hits"):
34
+ for bucket, keys in ev["keyword_hits"].items():
35
+ if keys:
36
+ parts.append(f"**{bucket}:** " + ", ".join(keys))
37
  return "\n\n".join(parts) if parts else "_No explicit time clues found._"
38
 
39
  def predict(txt: str):
40
+ txt = (txt or "").strip()
41
+ if not txt:
42
  return "—", "Paste some text.", {}
43
+
44
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
45
+ if probs.size == 0:
46
+ return "—", "Model returned no probabilities.", {}
47
 
48
  if probs.size != len(LABELS):
49
  return "—", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}
 
57
  if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
58
  topk = [{"label": LABELS[i], "score": float(probs[i])} for i in order[:TOP_K]]
59
  md = "**Uncertain** — top candidates:\n" + "\n".join(
60
+ f"- **{d['label']}**: {d['score']:.3f}" for d in topk
61
  )
62
  return "uncertain", md + "\n\n" + format_evidence(ev), {LABELS[i]: float(probs[i]) for i in order}
63
 
 
76
  scores = gr.JSON(label="Scores")
77
 
78
  btn = gr.Button("Submit", variant="primary")
79
+ # 👇 Explicit, stable API route (your Space docs will show /api/predict)
80
+ btn.click(predict, inputs=text, outputs=[pred, reason, scores], api_name="/predict")
81
 
82
  gr.Examples(
83
  examples=[