Spaces:
Sleeping
Sleeping
File size: 1,649 Bytes
de5768e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
from transformers import pipeline
# Zero-shot classifier, built once when the module is imported and shared by
# every call to classify() below.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Routing labels that classify() falls back to when the caller supplies none.
# Edit freely.
DEFAULT_LABELS = [
    "chat",
    "search",
    "image_generation",
    "code",
    "research",
    "study",
    "project",
    "action",
]
def classify(text, labels_text):
if not text.strip():
return {}
labels = [l.strip() for l in labels_text.split(",") if l.strip()]
if not labels:
labels = DEFAULT_LABELS
result = classifier(text, labels, multi_label=False)
return {
"text": text,
"top_intent": result["labels"][0],
"scores": dict(zip(result["labels"], result["scores"]))
}
# -------- Gradio UI --------
# One text box for the input, one for the comma-separated labels, and a
# button that runs classify() and renders its dict as JSON.
with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
    gr.Markdown("## 🧠 Zero-Shot Intent Router")
    gr.Markdown(
        "Classifies **any input** into your routing labels.\n"
        "Used for **system-prompt injection + MPC routing**."
    )

    text_input = gr.Textbox(
        label="User Input",
        placeholder="e.g. Generate a cyberpunk city wallpaper in 4k",
    )
    # Derived from DEFAULT_LABELS (instead of a hand-typed copy) so editing
    # that list updates both this pre-filled value and classify()'s fallback.
    labels_input = gr.Textbox(
        label="Labels (comma-separated)",
        value=", ".join(DEFAULT_LABELS),
    )
    output = gr.JSON(label="Classification Result")
    classify_btn = gr.Button("Classify")

    classify_btn.click(
        fn=classify,
        inputs=[text_input, labels_input],
        outputs=output,
    )
# Module-level alias for the Blocks app, meant to expose both the UI and the
# API. Presumably some host imports `app` by name — confirm before removing.
app = demo

if __name__ == "__main__":
    # `app` and `demo` are the same object; run the server when executed directly.
    app.launch()
|