DelaliScratchwerk commited on
Commit
d5706a4
Β·
verified Β·
1 Parent(s): d220598

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -40
app.py CHANGED
@@ -1,58 +1,37 @@
1
- import gradio as gr
2
- import numpy as np
3
  from setfit import SetFitModel
4
 
5
- MODEL_ID = "DelaliScratchwerk/text-period-setfit" # <- your model repo
6
 
7
- # keep this order the same as in training
8
- LABELS = [
9
- "pre-1900","1900–1945","1946–1990","1991–2008",
10
- "2009–2015","2016–2018","2019–2022","2023–present"
11
- ]
12
 
13
  model = SetFitModel.from_pretrained(MODEL_ID)
14
 
15
  def predict(txt: str):
16
- # Guard empty input
17
- if not txt or not txt.strip():
18
  return "β€”", {"error": "Please paste some text."}
19
-
20
- # Get probabilities and force to 1D float array
21
- out = model.predict_proba([txt])[0] # could be list/ndarray/scalar
22
- probs = np.asarray(out, dtype=float).ravel()
23
-
24
- # Guard weird shapes
25
- if probs.size == 0:
26
- return "β€”", {"error": "Model returned no scores."}
27
-
28
- # If LABELS length doesn't match, truncate/pad just to avoid crashes
29
- if len(LABELS) != probs.size:
30
- lbls = (LABELS + [f"class_{i}" for i in range(probs.size)])[:probs.size]
31
- else:
32
- lbls = LABELS
33
-
34
- order = np.argsort(probs)[::-1] # descending
35
- top_label = lbls[int(order[0])]
36
- scores = {lbls[int(i)]: float(probs[int(i)]) for i in order}
37
-
38
- return top_label, scores
39
-
40
- examples = [
41
- "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
42
- "Sputnik launched and kicked off the space race.",
43
- "MySpace was the most popular social network for a while.",
44
- "TikTok creators exploded in popularity.",
45
- ]
46
 
47
  demo = gr.Interface(
48
  fn=predict,
49
  inputs=gr.Textbox(lines=8, label="Paste text"),
50
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
51
  title="Text β†’ Time Period (SetFit)",
52
- examples=examples,
53
- cache_examples=False, # <β€” disable startup caching to avoid 500
 
 
 
 
 
54
  allow_flagging="never",
55
  )
56
-
57
  if __name__ == "__main__":
58
  demo.launch()
 
1
+ import json, numpy as np, gradio as gr
 
2
  from setfit import SetFitModel
3
 
4
+ MODEL_ID = "DelaliScratchwerk/text-period-setfit"
5
 
6
+ with open("labels.json") as f:
7
+ LABELS = json.load(f)
 
 
 
8
 
9
  model = SetFitModel.from_pretrained(MODEL_ID)
10
 
11
  def predict(txt: str):
12
+ if not txt.strip():
 
13
  return "β€”", {"error": "Please paste some text."}
14
+ probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
15
+ if probs.size != len(LABELS):
16
+ # hard stop so you notice a mismatch
17
+ return "β€”", {"error": f'label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}'}
18
+ order = np.argsort(probs)[::-1]
19
+ top = LABELS[int(order[0])]
20
+ return top, {LABELS[int(i)]: float(probs[int(i)]) for i in order}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  demo = gr.Interface(
23
  fn=predict,
24
  inputs=gr.Textbox(lines=8, label="Paste text"),
25
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
26
  title="Text β†’ Time Period (SetFit)",
27
+ examples=[
28
+ "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
29
+ "Sputnik launched and kicked off the space race.",
30
+ "MySpace was the most popular social network for a while.",
31
+ "TikTok creators exploded in popularity.",
32
+ ],
33
+ cache_examples=False,
34
  allow_flagging="never",
35
  )
 
36
  if __name__ == "__main__":
37
  demo.launch()