Aadhavan12344 commited on
Commit
0d60a86
·
verified ·
1 Parent(s): 8c918ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -52
app.py CHANGED
@@ -1,70 +1,143 @@
 
 
 
 
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- # Load zero-shot classifier
5
- classifier = pipeline(
6
- "zero-shot-classification",
7
- model="valhalla/distilbart-mnli-12-3"
8
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Default routing labels (edit freely)
11
- DEFAULT_LABELS = [
12
- "chat",
13
- "search",
14
- "image_generation",
15
- "code",
16
- "research",
17
- "study",
18
- "project",
19
- "action"
20
- ]
21
 
22
- def classify(text, labels_text):
23
- if not text.strip():
24
- return {}
25
 
26
- labels = [l.strip() for l in labels_text.split(",") if l.strip()]
27
- if not labels:
28
- labels = DEFAULT_LABELS
29
 
30
- result = classifier(text, labels, multi_label=False)
 
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return {
33
  "text": text,
34
- "top_intent": result["labels"][0],
35
- "scores": dict(zip(result["labels"], result["scores"]))
 
36
  }
37
 
38
- # -------- Gradio UI --------
39
- with gr.Blocks(title="Zero-Shot Router (BART-MNLI)") as demo:
40
- gr.Markdown("## 🧠 Zero-Shot Intent Router")
41
- gr.Markdown(
42
- "Classifies **any input** into your routing labels.\n"
43
- "Used for **system-prompt injection + MPC routing**."
44
- )
45
 
46
- text_input = gr.Textbox(
47
- label="User Input",
48
- placeholder="e.g. Generate a cyberpunk city wallpaper in 4k"
49
- )
50
 
51
- labels_input = gr.Textbox(
52
- label="Labels (comma-separated)",
53
- value="chat, search, image_generation, code, research, study, project, action"
54
- )
55
 
56
- output = gr.JSON(label="Classification Result")
 
 
57
 
58
- classify_btn = gr.Button("Classify")
 
 
59
 
60
- classify_btn.click(
61
- fn=classify,
62
- inputs=[text_input, labels_input],
63
- outputs=output
64
- )
 
 
65
 
66
- # Expose BOTH UI + API
67
- app = demo
 
68
 
69
- if __name__ == "__main__":
70
- demo.launch()
 
1
+ import re
2
+ import numpy as np
3
+ from typing import Dict
4
+
5
  import gradio as gr
6
+ from fastapi import FastAPI
7
+ from sentence_transformers import SentenceTransformer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
 
10
+ # -------------------------
11
+ # CONFIG
12
+ # -------------------------
13
+
14
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
+ CONFIDENCE_THRESHOLD = 0.35
16
+
17
+ # -------------------------
18
+ # LOAD MODEL
19
+ # -------------------------
20
+
21
+ embedder = SentenceTransformer(MODEL_NAME)
22
+
23
+ # -------------------------
24
+ # RULE-BASED ROUTER
25
+ # -------------------------
26
+
27
+ GREETINGS = {
28
+ "hi", "hello", "hey", "yo", "sup", "hola", "hii", "hai"
29
+ }
30
+
31
+ IMAGE_KEYWORDS = {
32
+ "draw", "image", "picture", "photo", "generate image", "create image"
33
+ }
34
+
35
+ CODE_KEYWORDS = {
36
+ "code", "python", "javascript", "bug", "error", "compile", "program"
37
+ }
38
+
39
+ def rule_based_route(text: str):
40
+ t = text.lower().strip()
41
+
42
+ if t in GREETINGS:
43
+ return "chat"
44
 
45
+ if any(k in t for k in IMAGE_KEYWORDS):
46
+ return "image_generation"
 
 
 
 
 
 
 
 
 
47
 
48
+ if any(k in t for k in CODE_KEYWORDS):
49
+ return "code"
 
50
 
51
+ return None # fallback to semantic router
 
 
52
 
53
+ # -------------------------
54
+ # SEMANTIC INTENTS
55
+ # -------------------------
56
 
57
+ INTENTS: Dict[str, str] = {
58
+ "chat": "casual conversation, greetings, talking",
59
+ "search": "asking for information or facts",
60
+ "image_generation": "requesting image creation or visual generation",
61
+ "code": "programming, software development, debugging",
62
+ "research": "deep technical or academic research",
63
+ "study": "learning, studying, explanations, tutorials",
64
+ "project": "building, planning, or creating a project",
65
+ "action": "asking the system to perform an action"
66
+ }
67
+
68
+ INTENT_NAMES = list(INTENTS.keys())
69
+ INTENT_EMBEDDINGS = embedder.encode(list(INTENTS.values()), normalize_embeddings=True)
70
+
71
+ # -------------------------
72
+ # SEMANTIC ROUTER
73
+ # -------------------------
74
+
75
+ def semantic_route(text: str):
76
+ text_emb = embedder.encode([text], normalize_embeddings=True)
77
+ sims = cosine_similarity(text_emb, INTENT_EMBEDDINGS)[0]
78
+
79
+ scores = dict(zip(INTENT_NAMES, sims))
80
+ top_intent = max(scores, key=scores.get)
81
+ confidence = scores[top_intent]
82
+
83
+ if confidence < CONFIDENCE_THRESHOLD:
84
+ return "chat", scores # safe fallback
85
+
86
+ return top_intent, scores
87
+
88
+ # -------------------------
89
+ # MAIN ROUTER
90
+ # -------------------------
91
+
92
+ def route_intent(text: str):
93
+ # 1️⃣ Rule-based (instant)
94
+ rule = rule_based_route(text)
95
+ if rule:
96
+ return {
97
+ "text": text,
98
+ "top_intent": rule,
99
+ "method": "rule",
100
+ "scores": {rule: 1.0}
101
+ }
102
+
103
+ # 2️⃣ Semantic
104
+ intent, scores = semantic_route(text)
105
  return {
106
  "text": text,
107
+ "top_intent": intent,
108
+ "method": "semantic",
109
+ "scores": scores
110
  }
111
 
112
+ # -------------------------
113
+ # FASTAPI
114
+ # -------------------------
 
 
 
 
115
 
116
+ api = FastAPI()
 
 
 
117
 
118
+ @api.post("/classify")
119
+ def classify(payload: Dict):
120
+ text = payload.get("text", "")
121
+ return route_intent(text)
122
 
123
+ # -------------------------
124
+ # GRADIO UI
125
+ # -------------------------
126
 
127
+ def gradio_classify(text):
128
+ result = route_intent(text)
129
+ return result
130
 
131
+ gradio_ui = gr.Interface(
132
+ fn=gradio_classify,
133
+ inputs=gr.Textbox(label="User Input"),
134
+ outputs=gr.JSON(label="Classification Result"),
135
+ title="🧠 Hybrid Intent Router",
136
+ description="Rule-based + Semantic intent classification for prompt routing & MPC selection"
137
+ )
138
 
139
+ # -------------------------
140
+ # MOUNT GRADIO INTO FASTAPI
141
+ # -------------------------
142
 
143
+ app = gr.mount_gradio_app(api, gradio_ui, path="/")