File size: 3,449 Bytes
ac77253
6a0dda2
3133f26
e8485c8
d6c5a54
0a06046
d6c5a54
 
 
 
 
a3ce5e8
d5706a4
f327d31
d6c5a54
980f96f
d6c5a54
 
980f96f
d6c5a54
f327d31
 
 
3133f26
f327d31
1c53ee5
 
 
d6c5a54
e8485c8
 
 
 
d6c5a54
 
 
e8485c8
 
1c53ee5
d6c5a54
 
e8485c8
d6c5a54
d5706a4
d6c5a54
 
e8485c8
d5706a4
e8485c8
 
d5706a4
6cbec58
 
e8485c8
 
 
 
6cbec58
 
d6c5a54
e8485c8
 
 
 
 
6cbec58
e8485c8
 
 
 
 
 
 
 
 
 
 
 
8ac19c7
 
 
 
 
 
 
e8485c8
 
 
 
 
 
 
 
 
 
3133f26
a3ce5e8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json, numpy as np, gradio as gr
from setfit import SetFitModel
from huggingface_hub import hf_hub_download
from evidence import extract_evidence
import os, shutil, pathlib

# Optional: only clear cache if you set CLEAR_HF_CACHE=1 in the Space secrets
if os.getenv("CLEAR_HF_CACHE") == "1":
    CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
    pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Hub repository that provides the SetFit classifier (and its labels.json).
MODEL_ID: str = "DelaliScratchwerk/text-period-setfit"

# Tuned decision thresholds used by predict():
#   - a result is reported as "uncertain" when the top probability is below
#     UNCERTAINTY_THRESHOLD, or when it leads the runner-up by less than
#     MARGIN_THRESHOLD;
#   - in that case the TOP_K most probable labels are listed as candidates.
UNCERTAINTY_THRESHOLD: float = 0.516
MARGIN_THRESHOLD: float = 0.387
TOP_K: int = 3

def _load_json(path):
    """Parse the UTF-8 JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Class labels: prefer the labels.json published with the model on the Hub and
# fall back to a local copy if downloading or parsing it fails for any reason.
# predict() checks len(LABELS) against the model's class count and indexes it
# by class position, so it is assumed to be a list in the model's class order.
try:
    labels_path = hf_hub_download(MODEL_ID, "labels.json")
    LABELS = _load_json(labels_path)
except Exception:
    LABELS = _load_json("labels.json")

model = SetFitModel.from_pretrained(MODEL_ID)

def format_evidence(ev: dict) -> str:
    """Render an evidence dict as Markdown paragraphs.

    Recognised keys:
        ``years``        -- iterable of strings, shown on one "Years found" line.
        ``keyword_hits`` -- mapping of bucket name -> iterable of strings; one
                            line per non-empty bucket.

    Paragraphs are separated by a blank line. If nothing is present, an
    italic "no clues" placeholder is returned instead.
    """
    sections = []

    years = ev.get("years")
    if years:
        sections.append(f"**Years found:** {', '.join(years)}")

    # Falsy / missing keyword_hits contributes nothing.
    hits = ev.get("keyword_hits") or {}
    sections.extend(
        f"**{bucket}:** " + ", ".join(keys)
        for bucket, keys in hits.items()
        if keys
    )

    if not sections:
        return "_No explicit time clues found._"
    return "\n\n".join(sections)

def predict(txt: str):
    """Classify *txt* into a time period and explain the result.

    Args:
        txt: Text entered in the UI. ``None`` and whitespace-only input are
            treated as empty.

    Returns:
        A ``(label, markdown, scores)`` tuple:

        * ``label`` -- the most probable label. It is ``"uncertain"`` when the
          top probability is below ``UNCERTAINTY_THRESHOLD`` or leads the
          runner-up by less than ``MARGIN_THRESHOLD``. It is ``"—"`` when no
          prediction could be made.
        * ``markdown`` -- evidence from ``extract_evidence`` (preceded by the
          ``TOP_K`` candidates in uncertain mode), or an error message.
        * ``scores`` -- ``{label: probability}`` for every class, highest
          first; empty on error.
    """
    txt = (txt or "").strip()
    if not txt:
        return "—", "Paste some text.", {}

    probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
    if probs.size == 0:
        return "—", "Model returned no probabilities.", {}

    # LABELS is indexed by class position below, so the sizes must agree.
    if probs.size != len(LABELS):
        return "—", f"Label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}", {}

    # Class indices sorted by descending probability.
    order = np.argsort(probs)[::-1]
    top1 = probs[order[0]]
    # With a single class there is no runner-up; 0.0 makes the margin == top1.
    top2 = probs[order[1]] if probs.size > 1 else 0.0
    ev = extract_evidence(txt)
    scores = {LABELS[i]: float(probs[i]) for i in order}

    # Uncertain mode: low top probability or too small a lead over #2.
    if top1 < UNCERTAINTY_THRESHOLD or (top1 - top2) < MARGIN_THRESHOLD:
        candidates = "\n".join(
            f"- **{LABELS[i]}**: {float(probs[i]):.3f}" for i in order[:TOP_K]
        )
        md = "**Uncertain** — top candidates:\n" + candidates
        return "uncertain", md + "\n\n" + format_evidence(ev), scores

    # Confident mode.
    md = "**Reasoning hints**\n\n" + format_evidence(ev)
    return LABELS[order[0]], md, scores

# UI: a text box feeding predict(), with the label, Markdown evidence and the
# full score table as outputs.
with gr.Blocks(title="Text → Time Period (SetFit)") as demo:
    gr.Markdown("# Text → Time Period (SetFit)")
    with gr.Row():
        text = gr.Textbox(lines=8, label="Paste text")
        with gr.Column():
            pred = gr.Label(label="Predicted")
            reason = gr.Markdown(label="Evidence")
    scores = gr.JSON(label="Scores")

    btn = gr.Button("Submit", variant="primary")
    # Register predict under the stable API name "predict" so programmatic
    # clients can call it by name. NOTE(review): the concrete HTTP path for
    # this endpoint differs between Gradio versions — confirm against the
    # installed version before documenting a URL.
    btn.click(
        predict,
        inputs=text,
        outputs=[pred, reason, scores],
        api_name="predict",
    )

    # Clicking an example only fills the text box; it does not run predict.
    gr.Examples(
        examples=[
            "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
            "Sputnik launched and kicked off the space race.",
            "MySpace was the most popular social network for a while.",
            "Creators blew up on TikTok; companies rolled out ChatGPT-powered tools.",
        ],
        inputs=text,
    )

if __name__ == "__main__":
    demo.launch()