DelaliScratchwerk commited on
Commit
1c53ee5
·
verified ·
1 Parent(s): 903aed0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -1,21 +1,36 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
 
4
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
5
- pipe = pipeline("text-classification", model=MODEL_ID, return_all_scores=True)
6
 
7
- def predict(txt):
8
- scores = pipe(txt)[0]
9
- scores = sorted(scores, key=lambda d: d["score"], reverse=True)
10
- top = scores[0]["label"]
11
- table = {s["label"]: round(float(s["score"]), 4) for s in scores}
12
- return top, table
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  demo = gr.Interface(
15
  fn=predict,
16
  inputs=gr.Textbox(lines=8, label="Paste text"),
17
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
18
- title="Text → Time Period"
 
19
  )
20
 
21
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from setfit import SetFitModel
3
+ import numpy as np
4
 
5
+ # 👇 use your exact model repo
6
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
 
7
 
8
+ # The label order must match the order you used during training
9
+ LABELS = ["pre-1900","1900–1945","1946–1990","1991–2008","2009–2015","2016–2018","2019–2022","2023–present"]
10
+
11
+ model = SetFitModel.from_pretrained(MODEL_ID)
12
+
13
+ def predict(txt: str):
14
+ # SetFit expects a list; returns probabilities per label
15
+ probs = model.predict_proba([txt])[0] # shape: (num_labels,)
16
+ order = np.argsort(probs)[::-1] # descending
17
+ top_label = LABELS[order[0]]
18
+ table = {LABELS[i]: float(probs[i]) for i in order}
19
+ return top_label, table
20
+
21
+ examples = [
22
+ "Schools went remote during the pandemic; everyone wore N95s and used Zoom.",
23
+ "Sputnik launched and kicked off the space race.",
24
+ "MySpace was the most popular social network for a while.",
25
+ "TikTok creators exploded in popularity.",
26
+ ]
27
 
28
  demo = gr.Interface(
29
  fn=predict,
30
  inputs=gr.Textbox(lines=8, label="Paste text"),
31
  outputs=[gr.Label(label="Predicted Period"), gr.JSON(label="Scores")],
32
+ title="Text → Time Period (SetFit)",
33
+ examples=examples,
34
  )
35
 
36
  if __name__ == "__main__":