StaticFace commited on
Commit
a628211
·
verified ·
1 Parent(s): 0a03f17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -17,7 +17,8 @@ clf = pipeline(
17
  task="zero-shot-classification",
18
  model=model,
19
  tokenizer=tokenizer,
20
- device=-1
 
21
  )
22
 
23
  def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
@@ -27,6 +28,7 @@ def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
27
 
28
  if not text:
29
  return {"error": "Enter some text."}
 
30
  candidate_labels = [x.strip() for x in labels.split(",") if x.strip()]
31
  if not candidate_labels:
32
  return {"error": "Enter at least 1 label (comma-separated)."}
@@ -36,7 +38,7 @@ def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
36
  sequences=text,
37
  candidate_labels=candidate_labels,
38
  hypothesis_template=hypothesis_template,
39
- multi_label=bool(multi_label)
40
  )
41
 
42
  pairs = list(zip(out["labels"], out["scores"]))
@@ -46,7 +48,7 @@ def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
46
  return {
47
  "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
48
  "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
49
- "raw": out
50
  }
51
 
52
  demo = gr.Interface(
@@ -60,7 +62,7 @@ demo = gr.Interface(
60
  ],
61
  outputs=gr.JSON(label="Output"),
62
  title="Zero-Shot Classification (DeBERTa v3 Large, MoritzLaurer)",
63
- allow_flagging="never"
64
  )
65
 
66
  if __name__ == "__main__":
 
17
  task="zero-shot-classification",
18
  model=model,
19
  tokenizer=tokenizer,
20
+ device=-1,
21
+ framework="pt",
22
  )
23
 
24
  def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
 
28
 
29
  if not text:
30
  return {"error": "Enter some text."}
31
+
32
  candidate_labels = [x.strip() for x in labels.split(",") if x.strip()]
33
  if not candidate_labels:
34
  return {"error": "Enter at least 1 label (comma-separated)."}
 
38
  sequences=text,
39
  candidate_labels=candidate_labels,
40
  hypothesis_template=hypothesis_template,
41
+ multi_label=bool(multi_label),
42
  )
43
 
44
  pairs = list(zip(out["labels"], out["scores"]))
 
48
  return {
49
  "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
50
  "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
51
+ "raw": out,
52
  }
53
 
54
  demo = gr.Interface(
 
62
  ],
63
  outputs=gr.JSON(label="Output"),
64
  title="Zero-Shot Classification (DeBERTa v3 Large, MoritzLaurer)",
65
+ flagging_mode="never",
66
  )
67
 
68
  if __name__ == "__main__":