# app.py — by Aadhavan12344, commit de5768e ("Create app.py", verified), 1.65 kB.
# (Header text originally copied from the hosting site's file viewer: raw / history / blame.)
import gradio as gr
from transformers import pipeline
# Zero-shot classification pipeline backed by the facebook/bart-large-mnli
# checkpoint. It is built once, when this module is imported, and reused by
# every classify() call below.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Default routing labels (edit freely). classify() falls back to this list
# whenever the comma-separated labels field yields no non-blank entries.
DEFAULT_LABELS = (
    "chat search image_generation code research study project action".split()
)
def classify(text, labels_text, *, multi_label=False):
    """Route ``text`` to the best-matching label with the zero-shot classifier.

    Args:
        text: User input to classify. ``None`` or whitespace-only input
            returns ``{}`` without calling the model.
        labels_text: Comma-separated candidate labels. Entries are stripped,
            blank entries are dropped, and duplicates are removed (first
            occurrence wins). If nothing remains (or ``labels_text`` is
            ``None``), ``DEFAULT_LABELS`` is used.
        multi_label: Forwarded unchanged to the pipeline. The default
            ``False`` matches the previous hard-coded behaviour.

    Returns:
        ``{}`` for empty input. Otherwise a dict with:
        ``"text"`` (the input), ``"top_intent"`` (the first label the
        pipeline returns, i.e. its highest-scoring one), and ``"scores"``
        (label -> score, in the pipeline's output order).
    """
    if text is None or not text.strip():
        return {}

    # dict.fromkeys de-duplicates while keeping order. Repeated labels would
    # otherwise be scored twice, then silently collapse in the "scores" dict.
    labels = list(
        dict.fromkeys(
            label.strip()
            for label in (labels_text or "").split(",")
            if label.strip()
        )
    )
    if not labels:
        # Copy so the pipeline never receives the shared module-level list.
        labels = list(DEFAULT_LABELS)

    result = classifier(text, labels, multi_label=multi_label)
    return {
        "text": text,
        "top_intent": result["labels"][0],
        "scores": dict(zip(result["labels"], result["scores"])),
    }
# -------- Gradio UI --------
with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
    gr.Markdown("## 🧠 Zero-Shot Intent Router")
    gr.Markdown(
        "Classifies **any input** into your routing labels.\n"
        "Used for **system-prompt injection + MPC routing**."
    )

    text_input = gr.Textbox(
        label="User Input",
        placeholder="e.g. Generate a cyberpunk city wallpaper in 4k",
    )
    # Pre-fill from DEFAULT_LABELS rather than a hand-typed copy, so the UI
    # default and classify()'s fallback list cannot drift apart.
    labels_input = gr.Textbox(
        label="Labels (comma-separated)",
        value=", ".join(DEFAULT_LABELS),
    )
    output = gr.JSON(label="Classification Result")
    classify_btn = gr.Button("Classify")

    # Clicking the button calls classify(text, labels) with the two textbox
    # values and renders the returned dict in the JSON panel.
    classify_btn.click(
        fn=classify,
        inputs=[text_input, labels_input],
        outputs=output,
    )
# Expose BOTH UI + API. ``app`` is simply another name for the ``demo``
# Blocks object. It is presumably kept for hosts or loaders that import a
# module-level ``app`` — confirm against the deployment setup.
app = demo

if __name__ == "__main__":
    # Same object as ``demo``, so this is equivalent to ``demo.launch()``.
    app.launch()