# Aadhavan12344's picture
# Update app.py
# b496166 verified
# ============================================================
# ⚡ Semantic Intent Router (MiniLM)
# Zero-shot • No training • Sub-second • HF Free CPU
# ============================================================
import json
import math
import time
from typing import Any, Dict, List, Optional

import torch
import gradio as gr
from sentence_transformers import SentenceTransformer, util
# ============================================================
# CONFIG
# ============================================================
# Hugging Face hub id of the sentence-embedding model loaded below.
MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
# classify() drops any label whose best cosine similarity is under this floor.
MIN_SCORE: float = 0.05
# parse_labels() keeps at most this many examples per label.
MAX_EXAMPLES: int = 20
# ============================================================
# LOAD MODEL
# ============================================================
# Loaded once at import time and pinned to CPU; every classify() call
# reuses this single module-level instance.
print("Loading MiniLM model...")
model = SentenceTransformer(model_name_or_path=MODEL_NAME, device="cpu")
print("Model loaded")
# ============================================================
# HELPERS
# ============================================================
def softmax(scores: Dict[str, float], temperature: float = 1.0) -> Dict[str, float]:
    """Convert a ``label -> raw score`` mapping into ``label -> probability``.

    Args:
        scores: Mapping of label to raw score (in this file, cosine
            similarities produced by ``classify``).
        temperature: Every score is divided by this before exponentiation.
            ``1.0`` (the default) is a plain softmax. Values below ``1.0``
            sharpen the distribution. This is useful because cosine similarities
            lie in [-1, 1], so a plain softmax over them is close to uniform.

    Returns:
        A dict with the same keys whose values are non-negative and sum to 1.
        An empty input yields an empty dict.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (including NaN).
    """
    # ``not t > 0`` also rejects NaN, which ``t <= 0`` would let through.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    if not scores:
        return {}
    # Shift by the maximum so the largest exponent is exp(0): no overflow.
    peak = max(scores.values())
    exp_scores = {k: math.exp((v - peak) / temperature) for k, v in scores.items()}
    total = sum(exp_scores.values())
    return {k: v / total for k, v in exp_scores.items()}
def parse_labels(raw: Any, max_examples: Optional[int] = None) -> Dict[str, List[str]]:
    """
    Normalise user-supplied labels into a ``label -> examples`` mapping.

    Accepts dict (Gradio JSON) or JSON string.

    Args:
        raw: An already-decoded dict (what ``gr.JSON`` delivers) or JSON text.
        max_examples: Maximum number of examples kept per label (expected to be
            a non-negative int). ``None`` (the default) uses the module-level
            ``MAX_EXAMPLES``.

    Returns:
        Clean label -> examples mapping. Each example is a stripped, non-empty
        string, in original order, truncated to the cap. The following labels
        are dropped:
          - labels that are not strings;
          - labels named ``"__error__"``;
          - labels whose value is not a list;
          - labels left with no usable examples.
        On failure, returns ``{"__error__": <message str>}`` instead. Callers
        detect failure by looking for that key.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # pathologically nested input.
        except (ValueError, RecursionError) as e:
            return {"__error__": f"Invalid JSON: {e}"}
    if not isinstance(raw, dict):
        return {"__error__": "Labels must be a JSON object"}

    limit = MAX_EXAMPLES if max_examples is None else max_examples
    cleaned: Dict[str, List[str]] = {}
    for label, examples in raw.items():
        # "__error__" is the failure sentinel callers test for. A user label
        # with that name would be misread as a parse error, so skip it.
        if not isinstance(label, str) or label == "__error__":
            continue
        if not isinstance(examples, list):
            continue
        ex = [
            str(x).strip()
            for x in examples
            # bool subclasses int; exclude it so JSON true/false do not
            # become the examples "True"/"False".
            if isinstance(x, (str, int, float))
            and not isinstance(x, bool)
            and str(x).strip()
        ][:limit]
        if ex:
            cleaned[label] = ex
    if not cleaned:
        return {"__error__": "No valid labels found"}
    return cleaned
# ============================================================
# CLASSIFIER CORE
# ============================================================
def classify(text: str, raw_labels: Any) -> Dict[str, Any]:
    """Route ``text`` to the best-matching label using sentence embeddings.

    Each label's score is the highest cosine similarity between ``text`` and
    any of the label's examples. Labels scoring below ``MIN_SCORE`` are
    dropped, and the survivors are passed through ``softmax``.

    Args:
        text: The user utterance to classify.
        raw_labels: Label definitions in any form ``parse_labels`` accepts
            (a dict or a JSON string mapping label -> list of examples).

    Returns:
        On success, a dict with these keys:
          - ``text``: the input text.
          - ``top_intent``: the highest-probability label, or ``None`` when no
            label reaches ``MIN_SCORE``.
          - ``scores``: label -> probability, sorted highest first; empty when
            ``top_intent`` is ``None``.
          - ``latency_ms``: elapsed time in milliseconds.
        On empty text or unusable labels: ``{"error": <message>}``.

    NOTE: probabilities are normalised only over labels that pass
    ``MIN_SCORE``, so a single surviving label always receives 1.0.
    """
    # perf_counter is monotonic; time.time() can jump with wall-clock changes.
    start = time.perf_counter()

    def _elapsed_ms() -> float:
        """Milliseconds since ``start``, rounded to 2 decimals."""
        return round((time.perf_counter() - start) * 1000, 2)

    if not text or not text.strip():
        return {"error": "Empty input"}

    labels = parse_labels(raw_labels)
    # parse_labels signals failure with a *string* under "__error__". Checking
    # the type keeps a user label literally named "__error__" (whose value is
    # a list of examples) from being mistaken for a failure.
    error = labels.get("__error__")
    if isinstance(error, str):
        return {"error": error}

    # Embed all examples in one batched encode call instead of one call per
    # label. Remember which slice of the batch belongs to which label.
    flat_examples: List[str] = []
    spans = []
    for label, examples in labels.items():
        if not examples:
            continue  # torch.max on an empty slice would raise
        begin = len(flat_examples)
        flat_examples.extend(examples)
        spans.append((label, begin, len(flat_examples)))

    scores: Dict[str, float] = {}
    if flat_examples:
        text_emb = model.encode(text, convert_to_tensor=True)
        example_embs = model.encode(flat_examples, convert_to_tensor=True)
        sims = util.cos_sim(text_emb, example_embs)[0]
        for label, begin, end in spans:
            score = float(torch.max(sims[begin:end]).item())
            if score >= MIN_SCORE:
                scores[label] = score

    if not scores:
        return {
            "text": text,
            "top_intent": None,
            "scores": {},
            "latency_ms": _elapsed_ms(),
        }

    probs = softmax(scores)
    top_intent = max(probs, key=probs.get)
    return {
        "text": text,
        "top_intent": top_intent,
        "scores": dict(sorted(probs.items(), key=lambda kv: -kv[1])),
        "latency_ms": _elapsed_ms(),
    }
# ============================================================
# DEFAULT LABELS
# ============================================================
# Starter routing table pre-loaded into the UI's editable JSON box:
# intent label -> example phrases. classify() scores an input against each
# label by its single closest example.
DEFAULT_LABELS = dict(
    chat=["say hello", "casual talk", "how are you"],
    image_generation=["generate an image", "draw a picture", "create artwork"],
    action=["set a timer", "create a reminder"],
    code=["write code", "debug program"],
    search=["search online", "find information"],
)
# ============================================================
# GRADIO UI
# ============================================================
# Components are created inside the Blocks context, and Gradio lays them out
# in creation order. The statement order below is therefore the page order.
with gr.Blocks(title="⚡ Semantic Intent Router") as demo:
    # Page header / short description.
    gr.Markdown(
        "# ⚡ Semantic Intent Router\n"
        "MiniLM semantic classifier · No training · Sub-second\n\n"
        "• Edit labels freely\n"
        "• Add examples per label\n"
        "• Used for MPC / system-prompt routing\n"
    )
    # Free-text utterance; passed to classify() as `text`.
    user_input = gr.Textbox(
        label="User Input",
        placeholder="Type anything…",
        lines=2
    )
    # Editable label -> examples mapping, pre-filled with DEFAULT_LABELS;
    # its current value is passed to classify() as `raw_labels`.
    labels_input = gr.JSON(
        label="Labels & Examples (editable)",
        value=DEFAULT_LABELS
    )
    # Shows the dict returned by classify() (routing result or {"error": ...}).
    output = gr.JSON(label="Routing Result")
    classify_btn = gr.Button("Classify", variant="primary")
    # Clicking the button runs classify(user_input, labels_input) -> output.
    classify_btn.click(
        fn=classify,
        inputs=[user_input, labels_input],
        outputs=output
    )
    # NOTE(review): the endpoint path and payload shape for programmatic calls
    # depend on the installed Gradio version. Confirm against the Space's
    # "Use via API" page before relying on this snippet.
    gr.Markdown(
        "### API Usage\n"
        "POST to this Space endpoint with:\n\n"
        "`{\"data\": [\"your text\", {\"label\": [\"example\"]}]}`\n"
    )
# ============================================================
# LAUNCH
# ============================================================
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    # share=True additionally asks Gradio for a temporary public share link.
    demo.launch(share=True)