|
|
import json
import logging
import os
import pathlib
import shutil

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from setfit import SetFitModel

from evidence import extract_evidence
|
|
|
|
|
|
|
|
# Opt-in cache reset: only when CLEAR_HF_CACHE is exactly "1", delete
# ~/.cache/huggingface and recreate it as an empty directory.
if os.environ.get("CLEAR_HF_CACHE") == "1":
    CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
    # ignore_errors=True: a missing or partly undeletable cache must not
    # abort start-up.
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
|
# Hub repository id; used both to fetch labels.json and to load the model.
MODEL_ID: str = "DelaliScratchwerk/text-period-setfit"

# How many of the highest-probability candidates predict() lists when it
# reports an uncertain result.
TOP_K: int = 3

# predict() reports "uncertain" when the best probability is below
# UNCERTAINTY_THRESHOLD, or when its lead over the runner-up is below
# MARGIN_THRESHOLD.
UNCERTAINTY_THRESHOLD: float = 0.516
MARGIN_THRESHOLD: float = 0.387
|
|
|
|
|
|
|
|
def _read_json(path):
    """Return the parsed contents of the JSON file at *path*.

    The file is opened as UTF-8 inside a ``with`` block so the handle is
    closed as soon as parsing finishes instead of being left to the
    garbage collector.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Class names, looked up by the model's output index in predict().
# Prefer the labels.json published in the model repository; if that fails
# for any reason, fall back to a labels.json in the working directory.
try:
    labels_path = hf_hub_download(MODEL_ID, "labels.json")
    LABELS = _read_json(labels_path)
except Exception:
    # Deliberately broad (download, filesystem and parse errors all take the
    # same fallback), but the cause is logged rather than silently dropped.
    logging.getLogger(__name__).warning(
        "Could not load labels.json for %s from the Hub; using the local copy.",
        MODEL_ID,
        exc_info=True,
    )
    LABELS = _read_json("labels.json")

# Loaded once at import time; predict() reuses this instance on every call.
model = SetFitModel.from_pretrained(MODEL_ID)
|
|
|
|
|
def format_evidence(ev: dict) -> str:
    """Render extracted time clues as a Markdown snippet.

    Args:
        ev: Evidence mapping. Two optional keys are read: ``"years"`` (an
            iterable of years) and ``"keyword_hits"`` (a mapping from a
            bucket name to an iterable of matched keywords). ``None`` is
            treated as an empty mapping.

    Returns:
        One Markdown paragraph per non-empty clue group, separated by blank
        lines, or a fixed placeholder sentence when there is nothing to show.
    """
    ev = ev or {}
    parts = []

    years = ev.get("years")
    if years:
        # str() per item: str.join raises TypeError on non-string items, so
        # this keeps working if the extractor returns years as ints.
        parts.append("**Years found:** " + ", ".join(str(y) for y in years))

    # `or {}` covers both a missing key and an explicit None value.
    for bucket, keys in (ev.get("keyword_hits") or {}).items():
        if keys:  # skip buckets with no hits
            parts.append(f"**{bucket}:** " + ", ".join(str(k) for k in keys))

    return "\n\n".join(parts) if parts else "_No explicit time clues found._"
|
|
|
|
|
def predict(txt: str):
    """Classify *txt* and return ``(label, markdown, scores)``.

    * ``label``: the winning entry of ``LABELS``, the string ``"uncertain"``
      when the result fails the confidence checks, or a placeholder when no
      prediction could be made.
    * ``markdown``: an explanation (evidence hints, candidate list, or an
      error/help message).
    * ``scores``: mapping of every label to its probability, highest first;
      empty when no prediction could be made.
    """
    # NOTE(review): the "β" in the strings below looks like a mis-encoded
    # dash/arrow; it is kept byte-for-byte here. Confirm the intended glyph.
    cleaned = (txt or "").strip()
    if not cleaned:
        return "β", "Paste some text.", {}

    raw = model.predict_proba([cleaned])[0]
    probs = np.asarray(raw, dtype=float).ravel()
    n_classes = probs.size

    if n_classes == 0:
        return "β", "Model returned no probabilities.", {}
    if n_classes != len(LABELS):
        # Indexing LABELS by class position would be wrong or out of range.
        return "β", f"Label mismatch: model has {n_classes} classes, labels.json has {len(LABELS)}", {}

    # Class indices from most to least probable.
    ranking = np.argsort(probs)[::-1]
    best_p = probs[ranking[0]]
    second_p = probs[ranking[1]] if n_classes > 1 else 0.0

    ev = extract_evidence(cleaned)
    scores = {LABELS[i]: float(probs[i]) for i in ranking}

    too_weak = best_p < UNCERTAINTY_THRESHOLD
    too_close = (best_p - second_p) < MARGIN_THRESHOLD
    if too_weak or too_close:
        candidates = "\n".join(
            f"- **{LABELS[i]}**: {float(probs[i]):.3f}" for i in ranking[:TOP_K]
        )
        summary = "**Uncertain** β top candidates:\n" + candidates
        return "uncertain", summary + "\n\n" + format_evidence(ev), scores

    return LABELS[ranking[0]], "**Reasoning hints**\n\n" + format_evidence(ev), scores
|
|
|
|
|
# Sample inputs handed to gr.Examples for the text box.
EXAMPLE_TEXTS = [
    "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
    "Sputnik launched and kicked off the space race.",
    "MySpace was the most popular social network for a while.",
    "Creators blew up on TikTok; companies rolled out ChatGPT-powered tools.",
]

# NOTE(review): the reviewed copy had lost its indentation; the nesting below
# (button, click handler and examples directly under Blocks) is reconstructed
# from the blank-line pattern. Confirm against the deployed layout.
# NOTE(review): "β" in the titles looks like a mis-encoded arrow; kept as-is.
with gr.Blocks(title="Text β Time Period (SetFit)") as demo:
    gr.Markdown("# Text β Time Period (SetFit)")

    # Input on the left, the three outputs stacked in a column beside it.
    with gr.Row():
        text = gr.Textbox(lines=8, label="Paste text")
        with gr.Column():
            pred = gr.Label(label="Predicted")
            reason = gr.Markdown(label="Evidence")
            scores = gr.JSON(label="Scores")

    btn = gr.Button("Submit", variant="primary")

    # predict() returns (label, markdown, scores), matching the outputs order.
    btn.click(predict, inputs=text, outputs=[pred, reason, scores], api_name="predict")

    gr.Examples(examples=EXAMPLE_TEXTS, inputs=text)


# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|
|