File size: 3,449 Bytes
ac77253 6a0dda2 3133f26 e8485c8 d6c5a54 0a06046 d6c5a54 a3ce5e8 d5706a4 f327d31 d6c5a54 980f96f d6c5a54 980f96f d6c5a54 f327d31 3133f26 f327d31 1c53ee5 d6c5a54 e8485c8 d6c5a54 e8485c8 1c53ee5 d6c5a54 e8485c8 d6c5a54 d5706a4 d6c5a54 e8485c8 d5706a4 e8485c8 d5706a4 6cbec58 e8485c8 6cbec58 d6c5a54 e8485c8 6cbec58 e8485c8 8ac19c7 e8485c8 3133f26 a3ce5e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import json, numpy as np, gradio as gr
from setfit import SetFitModel
from huggingface_hub import hf_hub_download
from evidence import extract_evidence
import os, shutil, pathlib
# Optional: only clear cache if you set CLEAR_HF_CACHE=1 in the Space secrets
if os.getenv("CLEAR_HF_CACHE") == "1":
    # Wipe the whole HF cache (stale/corrupt downloads), then recreate the
    # empty directory so subsequent downloads have somewhere to land.
    cache_dir = pathlib.Path("~/.cache/huggingface").expanduser()
    shutil.rmtree(cache_dir, ignore_errors=True)
    cache_dir.mkdir(parents=True, exist_ok=True)
MODEL_ID = "DelaliScratchwerk/text-period-setfit"

# thresholds (your tuned values)
TOP_K = 3
UNCERTAINTY_THRESHOLD = 0.516
MARGIN_THRESHOLD = 0.387

# labels (Hub -> local fallback): prefer the labels.json published alongside
# the model; if the download or parse fails (offline, missing file), fall back
# to the local copy. Files are opened with `with` so handles are not leaked.
try:
    labels_path = hf_hub_download(MODEL_ID, "labels.json")
    with open(labels_path, encoding="utf-8") as f:
        LABELS = json.load(f)
except Exception:
    with open("labels.json", encoding="utf-8") as f:
        LABELS = json.load(f)

model = SetFitModel.from_pretrained(MODEL_ID)
def format_evidence(ev: dict) -> str:
    """Render an evidence dict (from extract_evidence) as a Markdown snippet.

    Produces a "Years found" line when any years were detected, plus one line
    per keyword bucket that matched at least one keyword. When there is no
    evidence at all, returns a placeholder sentence instead.
    """
    lines: list[str] = []
    years = ev.get("years")
    if years:
        lines.append("**Years found:** " + ", ".join(years))
    for bucket, keywords in (ev.get("keyword_hits") or {}).items():
        if keywords:
            lines.append(f"**{bucket}:** " + ", ".join(keywords))
    if not lines:
        return "_No explicit time clues found._"
    return "\n\n".join(lines)
def predict(txt: str):
    """Classify *txt* into a time period.

    Returns a 3-tuple consumed by the Gradio outputs:
      - predicted label (str): best class name, "uncertain", or "—" on an
        early exit (empty input / model error),
      - a Markdown explanation (str),
      - a dict mapping each label to its probability (empty on early exit).
    """
    txt = (txt or "").strip()
    if not txt:
        # NOTE(fix): placeholder label was a mojibake "β"; restored an em dash.
        return "—", "Paste some text.", {}

    probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
    if probs.size == 0:
        return "—", "Model returned no probabilities.", {}
    if probs.size != len(LABELS):
        return "—", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}

    # Rank classes by probability, descending.
    order = np.argsort(probs)[::-1]
    top1 = probs[order[0]]
    top2 = probs[order[1]] if probs.size > 1 else 0.0
    ev = extract_evidence(txt)

    # Uncertain mode: low absolute confidence OR top two classes too close.
    if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
        topk = [{"label": LABELS[i], "score": float(probs[i])} for i in order[:TOP_K]]
        md = "**Uncertain** — top candidates:\n" + "\n".join(
            f"- **{d['label']}**: {d['score']:.3f}" for d in topk
        )
        return "uncertain", md + "\n\n" + format_evidence(ev), {LABELS[i]: float(probs[i]) for i in order}

    # Confident: single best label plus the extracted time-clue evidence.
    best = LABELS[order[0]]
    md = "**Reasoning hints**\n\n" + format_evidence(ev)
    return best, md, {LABELS[i]: float(probs[i]) for i in order}
# NOTE(fix): the arrow in the title/header was garbled as "β"; restored "→".
with gr.Blocks(title="Text → Time Period (SetFit)") as demo:
    gr.Markdown("# Text → Time Period (SetFit)")
    with gr.Row():
        text = gr.Textbox(lines=8, label="Paste text")
        with gr.Column():
            pred = gr.Label(label="Predicted")
            reason = gr.Markdown(label="Evidence")
            scores = gr.JSON(label="Scores")
    btn = gr.Button("Submit", variant="primary")
    # Stable API route name; HTTP endpoint will be /api/predict
    btn.click(
        predict,
        inputs=text,
        outputs=[pred, reason, scores],
        api_name="predict",
    )
    gr.Examples(
        examples=[
            "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
            "Sputnik launched and kicked off the space race.",
            "MySpace was the most popular social network for a while.",
            "Creators blew up on TikTok; companies rolled out ChatGPT-powered tools.",
        ],
        inputs=text,
    )

if __name__ == "__main__":
    demo.launch()
|