import gradio as gr
from transformers import pipeline

# Load the zero-shot classifier once at import time so every request reuses it.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Default routing labels (edit freely). The UI's label textbox is pre-filled
# from this list, and it is also the fallback when that textbox is left blank.
DEFAULT_LABELS = [
    "chat",
    "search",
    "image_generation",
    "code",
    "research",
    "study",
    "project",
    "action",
]


def _parse_labels(labels_text):
    """Split a comma-separated string into unique, non-empty, stripped labels.

    Labels keep their order of first appearance. Returns a copy of
    DEFAULT_LABELS when *labels_text* is None/blank or contains no labels.
    """
    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys de-duplicates while preserving order. Duplicate candidate
    # labels would otherwise split the probability mass between them and then
    # silently collapse into a single key of the scores dict.
    unique = list(dict.fromkeys(label for label in stripped if label))
    return unique or list(DEFAULT_LABELS)


def classify(text, labels_text):
    """Classify *text* against a comma-separated list of routing labels.

    Args:
        text: User input to route. None or whitespace-only input yields ``{}``.
        labels_text: Comma-separated candidate labels. Blank entries and
            duplicates are dropped; if nothing remains, DEFAULT_LABELS is used.

    Returns:
        ``{}`` for empty input. Otherwise a dict with the original ``text``,
        ``top_intent`` (the first label returned by the pipeline) and
        ``scores`` mapping each label to its score, in the order the pipeline
        returns them. That order is highest score first per the transformers
        docs, which ``top_intent`` relies on.
    """
    if not text or not text.strip():
        return {}

    labels = _parse_labels(labels_text)
    # multi_label=False: scores are normalized across the candidate labels.
    result = classifier(text, labels, multi_label=False)
    return {
        "text": text,
        "top_intent": result["labels"][0],
        "scores": dict(zip(result["labels"], result["scores"])),
    }


# -------- Gradio UI --------
with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
    gr.Markdown("## 🧠 Zero-Shot Intent Router")
    gr.Markdown(
        "Classifies **any input** into your routing labels.\n"
        "Used for **system-prompt injection + MPC routing**."
    )

    text_input = gr.Textbox(
        label="User Input",
        placeholder="e.g. Generate a cyberpunk city wallpaper in 4k",
    )
    labels_input = gr.Textbox(
        label="Labels (comma-separated)",
        # Derived from DEFAULT_LABELS (rather than a duplicated literal) so
        # editing that list also updates what the UI shows by default.
        value=", ".join(DEFAULT_LABELS),
    )

    output = gr.JSON(label="Classification Result")

    classify_btn = gr.Button("Classify")
    classify_btn.click(
        fn=classify,
        inputs=[text_input, labels_input],
        outputs=output,
    )

# Expose BOTH UI + API. `app` is a module-level alias of `demo`, presumably
# for a host/runner that imports `app` by name. NOTE(review): confirm.
app = demo

if __name__ == "__main__":
    demo.launch()