File size: 1,649 Bytes
de5768e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import pipeline

# Zero-shot classification pipeline, built once at module import and reused
# by every call to classify().
classifier = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli")

# Fallback routing labels, used when the caller supplies no usable labels.
# Edit freely.
DEFAULT_LABELS = [
    "chat", "search", "image_generation", "code",
    "research", "study", "project", "action",
]

def classify(text, labels_text):
    """Classify ``text`` against a comma-separated list of candidate labels.

    Args:
        text: The user input to route. ``None``, empty, or whitespace-only
            input short-circuits and the model is not called.
        labels_text: Comma-separated candidate labels. Blank entries are
            dropped and repeated labels are collapsed (first occurrence
            wins). ``None`` or a string with no usable labels falls back to
            ``DEFAULT_LABELS``.

    Returns:
        ``{}`` when there is no text to classify. Otherwise a dict with:
        ``"text"`` (the input exactly as given), ``"top_intent"`` (the first
        label in the classifier's result) and ``"scores"`` (a mapping of
        each returned label to its score).
    """
    # Programmatic callers can send null for a textbox; treat it like an
    # empty box instead of failing on ``None.strip()``.
    if not text or not text.strip():
        return {}

    # dict.fromkeys de-duplicates while keeping the user's order. A repeated
    # label would otherwise be scored as a separate candidate and then be
    # silently collapsed when the scores dict is built below.
    labels = list(dict.fromkeys(
        part.strip()
        for part in (labels_text or "").split(",")
        if part.strip()
    ))
    if not labels:
        labels = DEFAULT_LABELS

    result = classifier(text, labels, multi_label=False)

    return {
        "text": text,
        "top_intent": result["labels"][0],
        "scores": dict(zip(result["labels"], result["scores"])),
    }

# -------- Gradio UI --------
# Layout: a text box for the user input, a text box for the candidate labels,
# a JSON viewer for the result, and a button wired to classify().
with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
    gr.Markdown("## 🧠 Zero-Shot Intent Router")
    gr.Markdown(
        "Classifies **any input** into your routing labels.\n"
        "Used for **system-prompt injection + MPC routing**."
    )

    text_input = gr.Textbox(
        label="User Input",
        placeholder="e.g. Generate a cyberpunk city wallpaper in 4k"
    )

    # Prefill from DEFAULT_LABELS rather than a hand-copied string, so that
    # editing the constant also changes what the UI offers by default.
    labels_input = gr.Textbox(
        label="Labels (comma-separated)",
        value=", ".join(DEFAULT_LABELS)
    )

    output = gr.JSON(label="Classification Result")

    classify_btn = gr.Button("Classify")

    # classify() receives the two textbox values in this order and its
    # return value is rendered in the JSON component.
    classify_btn.click(
        fn=classify,
        inputs=[text_input, labels_input],
        outputs=output
    )

# Expose BOTH UI + API: keep the Blocks app importable under the name `app`.
app = demo

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    app.launch()