DelaliScratchwerk commited on
Commit
6a0dda2
Β·
verified Β·
1 Parent(s): 1c53ee5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -11
app.py CHANGED
@@ -1,22 +1,41 @@
1
  import gradio as gr
2
- from setfit import SetFitModel
3
  import numpy as np
 
4
 
5
- # πŸ‘‡ use your exact model repo
6
- MODEL_ID = "DelaliScratchwerk/text-period-setfit"
7
 
8
- # The label order must match the order you used during training
9
- LABELS = ["pre-1900","1900–1945","1946–1990","1991–2008","2009–2015","2016–2018","2019–2022","2023–present"]
 
 
 
10
 
11
  model = SetFitModel.from_pretrained(MODEL_ID)
12
 
13
  def predict(txt: str):
14
- # SetFit expects a list; returns probabilities per label
15
- probs = model.predict_proba([txt])[0] # shape: (num_labels,)
16
- order = np.argsort(probs)[::-1] # descending
17
- top_label = LABELS[order[0]]
18
- table = {LABELS[i]: float(probs[i]) for i in order}
19
- return top_label, table
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  examples = [
22
  "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
@@ -31,6 +50,8 @@ demo = gr.Interface(
31
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
32
  title="Text β†’ Time Period (SetFit)",
33
  examples=examples,
 
 
34
  )
35
 
36
  if __name__ == "__main__":
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ from setfit import SetFitModel
4
 
5
+ MODEL_ID = "DelaliScratchwerk/text-period-setfit" # <- your model repo
 
6
 
7
+ # keep this order the same as in training
8
+ LABELS = [
9
+ "pre-1900","1900–1945","1946–1990","1991–2008",
10
+ "2009–2015","2016–2018","2019–2022","2023–present"
11
+ ]
12
 
13
  model = SetFitModel.from_pretrained(MODEL_ID)
14
 
15
  def predict(txt: str):
16
+ # Guard empty input
17
+ if not txt or not txt.strip():
18
+ return "β€”", {"error": "Please paste some text."}
19
+
20
+ # Get probabilities and force to 1D float array
21
+ out = model.predict_proba([txt])[0] # could be list/ndarray/scalar
22
+ probs = np.asarray(out, dtype=float).ravel()
23
+
24
+ # Guard weird shapes
25
+ if probs.size == 0:
26
+ return "β€”", {"error": "Model returned no scores."}
27
+
28
+ # If LABELS length doesn't match, truncate/pad just to avoid crashes
29
+ if len(LABELS) != probs.size:
30
+ lbls = (LABELS + [f"class_{i}" for i in range(probs.size)])[:probs.size]
31
+ else:
32
+ lbls = LABELS
33
+
34
+ order = np.argsort(probs)[::-1] # descending
35
+ top_label = lbls[int(order[0])]
36
+ scores = {lbls[int(i)]: float(probs[int(i)]) for i in order}
37
+
38
+ return top_label, scores
39
 
40
  examples = [
41
  "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
 
50
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
51
  title="Text β†’ Time Period (SetFit)",
52
  examples=examples,
53
+ cache_examples=False, # <β€” disable startup caching to avoid 500
54
+ allow_flagging="never",
55
  )
56
 
57
  if __name__ == "__main__":