DelaliScratchwerk's picture
Update app.py
8ac19c7 verified
import json, numpy as np, gradio as gr
from setfit import SetFitModel
from huggingface_hub import hf_hub_download
from evidence import extract_evidence
import os, shutil, pathlib
# Opt-in cache reset: the Hugging Face cache directory is wiped only when the
# CLEAR_HF_CACHE environment variable (e.g. a Space secret) is exactly "1".
if os.environ.get("CLEAR_HF_CACHE") == "1":
    CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
    # Remove the whole tree (errors ignored), then recreate it as an empty
    # directory.
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
# Hugging Face Hub repository id; passed to hf_hub_download (labels.json) and
# SetFitModel.from_pretrained further down in this file.
MODEL_ID = "DelaliScratchwerk/text-period-setfit"
# thresholds (your tuned values)
# TOP_K: how many top-ranked labels predict() lists when it answers "uncertain".
TOP_K = 3
# UNCERTAINTY_THRESHOLD: predict() answers "uncertain" when the highest
# probability is below this value.
UNCERTAINTY_THRESHOLD = 0.516
# MARGIN_THRESHOLD: predict() also answers "uncertain" when the gap between the
# highest and second-highest probability is below this value.
MARGIN_THRESHOLD = 0.387
def _read_labels(path):
    """Parse the JSON labels file at *path*, closing the file when done."""
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# labels (Hub -> local fallback)
try:
    labels_path = hf_hub_download(MODEL_ID, "labels.json")
    LABELS = _read_labels(labels_path)
except Exception:
    # Deliberate best-effort: any failure to fetch or parse the Hub copy falls
    # back to "labels.json" resolved against the current working directory.
    LABELS = _read_labels("labels.json")

# Classifier is loaded once at import time and reused by every predict() call.
model = SetFitModel.from_pretrained(MODEL_ID)
def format_evidence(ev: dict) -> str:
    """Render extracted time clues as a Markdown snippet.

    Args:
        ev: Evidence mapping. Two optional keys are read: ``"years"`` (an
            iterable of years) and ``"keyword_hits"`` (a mapping of bucket
            name to an iterable of matched keywords). ``None`` or an empty
            mapping is treated as "no evidence".

    Returns:
        One Markdown paragraph per non-empty evidence group, separated by
        blank lines, or an italic placeholder when nothing was found.
    """
    placeholder = "_No explicit time clues found._"
    if not ev:
        return placeholder

    parts = []
    years = ev.get("years")
    if years:
        # str() every item: str.join raises TypeError on non-string entries
        # such as integer years.
        parts.append("**Years found:** " + ", ".join(map(str, years)))

    # ``or {}`` tolerates a missing key as well as an explicit None value.
    keyword_hits = ev.get("keyword_hits") or {}
    parts.extend(
        f"**{bucket}:** " + ", ".join(map(str, keys))
        for bucket, keys in keyword_hits.items()
        if keys  # skip buckets with no matches
    )
    return "\n\n".join(parts) if parts else placeholder
def predict(txt: str):
    """Classify *txt* into a time-period label and explain the decision.

    Args:
        txt: Raw user text. ``None`` and whitespace-only input count as empty.

    Returns:
        A 3-tuple ``(label, markdown, scores)``:

        * ``label``: the top label, ``"uncertain"`` when the thresholds are
          not met, or ``"—"`` when no prediction could be made.
        * ``markdown``: a user-facing message, including the formatted
          evidence whenever a prediction was made.
        * ``scores``: dict mapping each label to its probability, inserted
          from highest to lowest; empty when no prediction could be made.
    """
    txt = (txt or "").strip()
    if not txt:
        return "—", "Paste some text.", {}

    probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
    if probs.size == 0:
        return "—", "Model returned no probabilities.", {}
    if probs.size != len(LABELS):
        return "—", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}

    # Class indices sorted by descending probability.
    order = np.argsort(probs)[::-1]
    top1 = probs[order[0]]
    # With a single class there is no runner-up; treat it as probability 0.
    top2 = probs[order[1]] if probs.size > 1 else 0.0

    # Built once and shared by both return paths below.
    scores = {LABELS[i]: float(probs[i]) for i in order}
    evidence_md = format_evidence(extract_evidence(txt))

    # uncertain mode: low absolute confidence, or too close to the runner-up
    if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
        candidates = "\n".join(
            f"- **{LABELS[i]}**: {float(probs[i]):.3f}" for i in order[:TOP_K]
        )
        md = "**Uncertain** — top candidates:\n" + candidates
        return "uncertain", md + "\n\n" + evidence_md, scores

    # confident
    return LABELS[order[0]], "**Reasoning hints**\n\n" + evidence_md, scores
# Gradio UI: a text box in, and a predicted label, evidence Markdown and raw
# scores out. All three outputs are filled by predict().
# NOTE(review): the original indentation was lost; the Row/Column nesting below
# is a reconstruction. Confirm it matches the intended layout.
with gr.Blocks(title="Text → Time Period (SetFit)") as demo:
    gr.Markdown("# Text → Time Period (SetFit)")
    with gr.Row():
        text = gr.Textbox(lines=8, label="Paste text")
        with gr.Column():
            pred = gr.Label(label="Predicted")
            reason = gr.Markdown(label="Evidence")
            scores = gr.JSON(label="Scores")
    btn = gr.Button("Submit", variant="primary")
    # Stable API route name; HTTP endpoint will be /api/predict
    btn.click(
        predict,
        inputs=text,
        outputs=[pred, reason, scores],
        api_name="predict",
    )
    # Clickable sample inputs that populate the text box.
    gr.Examples(
        examples=[
            "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
            "Sputnik launched and kicked off the space race.",
            "MySpace was the most popular social network for a while.",
            "Creators blew up on TikTok; companies rolled out ChatGPT-powered tools.",
        ],
        inputs=text,
    )

# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()