DelaliScratchwerk commited on
Commit
65ba28a
·
verified ·
1 Parent(s): 6d07233

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
 
3
 
4
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
5
-
6
- with open("labels.json") as f:
7
- LABELS = json.load(f)
8
 
9
  model = SetFitModel.from_pretrained(MODEL_ID)
10
 
@@ -13,8 +12,7 @@ def predict(txt: str):
13
  return "—", {"error": "Please paste some text."}
14
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
15
  if probs.size != len(LABELS):
16
- # hard stop so you notice a mismatch
17
- return "—", {"error": f'label mismatch: model has {probs.size} classes, labels.json has {len(LABELS)}'}
18
  order = np.argsort(probs)[::-1]
19
  top = LABELS[int(order[0])]
20
  return top, {LABELS[int(i)]: float(probs[int(i)]) for i in order}
 
1
  import json, numpy as np, gradio as gr
2
  from setfit import SetFitModel
3
+ from huggingface_hub import hf_hub_download
4
 
5
  MODEL_ID = "DelaliScratchwerk/text-period-setfit"
6
+ LABELS = json.load(open(hf_hub_download(MODEL_ID, "labels.json"))) # pulls from model repo
 
 
7
 
8
  model = SetFitModel.from_pretrained(MODEL_ID)
9
 
 
12
  return "—", {"error": "Please paste some text."}
13
  probs = np.asarray(model.predict_proba([txt])[0], dtype=float).ravel()
14
  if probs.size != len(LABELS):
15
+ return "—", {"error": f"label mismatch: model has {probs.size}, labels.json has {len(LABELS)}"}
 
16
  order = np.argsort(probs)[::-1]
17
  top = LABELS[int(order[0])]
18
  return top, {LABELS[int(i)]: float(probs[int(i)]) for i in order}