StaticFace committed on
Commit
a6178d9
·
verified ·
1 Parent(s): ae1afc1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
5
+
6
# Zero-shot NLI model used for classification (runs on CPU below).
MODEL_ID = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"

# Silence the tokenizers fork-parallelism warning and cap CPU threads so
# inference stays well-behaved on a small shared host.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
torch.set_num_interop_threads(1)

# Load tokenizer and model once at startup (downloads weights on first run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout etc.

# Shared zero-shot classification pipeline; device=-1 pins it to CPU.
clf = pipeline(
    task="zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1
)
22
+
23
def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
    """Classify *text* against comma-separated candidate labels (zero-shot).

    Parameters:
        text: Input text; blank/None yields an error payload.
        labels: Comma-separated candidate labels; blank yields an error payload.
        hypothesis_template: NLI template containing "{}" as the label slot;
            blank falls back to "This text is about {}".
        multi_label: If truthy, score each label independently.
        top_k: Number of top-scoring labels to return (clamped to >= 1).

    Returns:
        dict with "top", "all" and "raw" keys on success, or {"error": msg}
        when validation fails.
    """
    text = (text or "").strip()
    labels = (labels or "").strip()
    hypothesis_template = (hypothesis_template or "").strip() or "This text is about {}"

    if not text:
        return {"error": "Enter some text."}
    candidate_labels = [x.strip() for x in labels.split(",") if x.strip()]
    if not candidate_labels:
        return {"error": "Enter at least 1 label (comma-separated)."}
    # The pipeline raises if the template has no "{}" placeholder; fail
    # gracefully with the same error-dict shape as the other validations.
    if "{}" not in hypothesis_template:
        return {"error": 'Hypothesis template must contain "{}".'}

    # Guard against a non-numeric slider value instead of raising; fall back
    # to the UI default of 5.
    try:
        top_k = max(1, int(top_k))
    except (TypeError, ValueError):
        top_k = 5

    with torch.inference_mode():  # no autograd bookkeeping during inference
        out = clf(
            sequences=text,
            candidate_labels=candidate_labels,
            hypothesis_template=hypothesis_template,
            multi_label=bool(multi_label)
        )

    # Pipeline output is already sorted by score, but re-sort defensively
    # before truncating to the requested count.
    pairs = sorted(zip(out["labels"], out["scores"]), key=lambda p: p[1], reverse=True)
    pairs = pairs[:top_k]

    return {
        "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
        "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
        "raw": out
    }
51
+
52
# Gradio UI: text + label + template inputs, JSON scores as output.
demo = gr.Interface(
    fn=run_zero_shot,
    inputs=[
        gr.Textbox(label="Text", lines=4, value="I am wahhhh"),
        gr.Textbox(label="Candidate Labels (comma-separated)", value="sad, happy, angry, neutral"),
        gr.Textbox(label="Hypothesis Template", value="This text is about {}"),
        gr.Checkbox(label="Multi-label", value=False),
        gr.Slider(label="Top-K to show", minimum=1, maximum=25, value=5, step=1),
    ],
    outputs=gr.JSON(label="Output"),
    title="Zero-Shot Classification (DeBERTa v3 Large, MoritzLaurer)",
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x and removed in
    # 5.x (replaced by flagging_mode) — confirm the pinned Gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()  # blocks and serves the app