Aadhavan12344 commited on
Commit
de5768e
·
verified ·
1 Parent(s): 0850ab0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Load zero-shot classifier
5
+ classifier = pipeline(
6
+ "zero-shot-classification",
7
+ model="facebook/bart-large-mnli"
8
+ )
9
+
10
+ # Default routing labels (edit freely)
11
+ DEFAULT_LABELS = [
12
+ "chat",
13
+ "search",
14
+ "image_generation",
15
+ "code",
16
+ "research",
17
+ "study",
18
+ "project",
19
+ "action"
20
+ ]
21
+
22
+ def classify(text, labels_text):
23
+ if not text.strip():
24
+ return {}
25
+
26
+ labels = [l.strip() for l in labels_text.split(",") if l.strip()]
27
+ if not labels:
28
+ labels = DEFAULT_LABELS
29
+
30
+ result = classifier(text, labels, multi_label=False)
31
+
32
+ return {
33
+ "text": text,
34
+ "top_intent": result["labels"][0],
35
+ "scores": dict(zip(result["labels"], result["scores"]))
36
+ }
37
+
38
+ # -------- Gradio UI --------
39
+ with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
40
+ gr.Markdown("## 🧠 Zero-Shot Intent Router")
41
+ gr.Markdown(
42
+ "Classifies **any input** into your routing labels.\n"
43
+ "Used for **system-prompt injection + MPC routing**."
44
+ )
45
+
46
+ text_input = gr.Textbox(
47
+ label="User Input",
48
+ placeholder="e.g. Generate a cyberpunk city wallpaper in 4k"
49
+ )
50
+
51
+ labels_input = gr.Textbox(
52
+ label="Labels (comma-separated)",
53
+ value="chat, search, image_generation, code, research, study, project, action"
54
+ )
55
+
56
+ output = gr.JSON(label="Classification Result")
57
+
58
+ classify_btn = gr.Button("Classify")
59
+
60
+ classify_btn.click(
61
+ fn=classify,
62
+ inputs=[text_input, labels_input],
63
+ outputs=output
64
+ )
65
+
66
+ # Expose BOTH UI + API
67
+ app = demo
68
+
69
+ if __name__ == "__main__":
70
+ demo.launch()