import gradio as gr
import torch
import logging
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure professional logging
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
class MagicSupportClassifier:
    """
    Encapsulates the customer support intent classification model.
    Engineered for dynamic label resolution and rapid inference.

    Loads the tokenizer and model once at construction, resolves the label
    set from the checkpoint config, and renders predictions as HTML plus a
    ``{label: probability}`` mapping for charting.
    """

    def __init__(self, model_id: str = "learn-abc/magicSupport-intent-classifier"):
        # Hugging Face Hub identifier of the fine-tuned sequence classifier.
        self.model_id = model_id
        # Hard truncation limit for tokenized queries; support messages are short.
        self.max_length = 128
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> None:
        """Load tokenizer + model from the Hub, move to device, set eval mode.

        Raises:
            Exception: re-raised after logging if the download/load fails, so
            application startup fails loudly instead of limping along.
        """
        logger.info("Initializing model %s on %s...", self.model_id, self.device)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()
            # Class count is read from the checkpoint config so the app keeps
            # working if the model is retrained with a different label set.
            self.num_classes = len(self.model.config.id2label)
            logger.info("Model loaded successfully with %d intent classes.", self.num_classes)
        except Exception:
            # logger.exception records the full traceback for diagnostics.
            logger.exception("Failed to load model")
            raise

    def _get_iconography(self, label: str) -> str:
        """
        Dynamically assigns UI icons based on intent keywords.
        Future-proofs the application against retrained label sets.
        """
        # NOTE(review): the original emoji literals were encoding-corrupted
        # (UTF-8 read as ISO-8859-7); reconstructed to the most plausible
        # intended glyphs based on the surviving trailing bytes.
        label_lower = label.lower()
        # Check order matters: e.g. "cancel_order" should map to the
        # order/delivery icon, since "order" is tested first.
        if any(kw in label_lower for kw in ("order", "delivery", "track")):
            return "📦"
        if any(kw in label_lower for kw in ("refund", "payment", "invoice", "fee")):
            return "💳"
        if any(kw in label_lower for kw in ("account", "password", "register", "profile")):
            return "👤"
        if any(kw in label_lower for kw in ("cancel", "delete", "problem", "issue")):
            return "⚠️"
        if any(kw in label_lower for kw in ("contact", "service", "support")):
            return "📧"
        return "🔹"

    def _format_label(self, label: str) -> str:
        """Cleans up raw dataset labels (e.g. ``track_order``) for UI presentation."""
        return label.replace("_", " ").title()

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 5):
        """
        Classify a customer query into support intents.

        Args:
            text: Raw customer message; leading/trailing whitespace is ignored.
            top_k: Number of top predictions to render (capped to num_classes).

        Returns:
            A ``(result_html, chart_data)`` tuple — formatted HTML for the
            top-k predictions plus a ``{label: probability}`` dict covering the
            full distribution — or ``(error_html, None)`` on empty input or
            inference failure.
        """
        if not text or not text.strip():
            return "<div style='color: #ef4444; padding: 10px;'>⚠️ <b>Input Required:</b> Please enter a customer query.</div>", None
        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True,
            ).to(self.device)
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1).squeeze()
            # Guard against a fully-squeezed 0-d tensor (single-class edge case).
            if probs.dim() == 0:
                probs = probs.unsqueeze(0)
            # Cap top_k to the maximum number of available classes.
            actual_top_k = min(top_k, self.num_classes)
            # BUGFIX: the original called torch.topk twice with identical
            # arguments (once for .indices, once for .values); compute it once.
            topk = torch.topk(probs, k=actual_top_k)
            top_indices = topk.indices.tolist()
            top_probs = topk.values.tolist()
            id2label = self.model.config.id2label
            # Primary prediction formatting.
            top_intent_raw = id2label[top_indices[0]]
            emoji = self._get_iconography(top_intent_raw)
            clean_label = self._format_label(top_intent_raw)
            confidence = top_probs[0] * 100
            result_html = f"""
            <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{emoji} {clean_label}</h2>
            <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {confidence:.1f}%</p>
            <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
            <h3 style='margin-bottom: 15px;'>📊 Top {actual_top_k} Predictions</h3>
            """
            # HTML progress bars, one per top-k prediction.
            for idx, prob in zip(top_indices, top_probs):
                intent_raw = id2label[idx]
                e = self._get_iconography(intent_raw)
                l = self._format_label(intent_raw)
                pct = prob * 100
                bar_html = f"""
                <div style="margin-bottom: 16px;">
                    <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                        <strong>{e} {l}</strong>
                        <span style="font-family:monospace;">{pct:.1f}%</span>
                    </div>
                    <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                        <div style="background-color: #8b5cf6; width: {pct}%; height: 100%; border-radius: 5px;"></div>
                    </div>
                </div>
                """
                result_html += bar_html
            # Full distribution for the gr.Label chart.
            chart_data = {
                self._format_label(id2label[i]): float(probs[i].item())
                for i in range(len(probs))
            }
            return result_html, chart_data
        except Exception:
            # Log the traceback; surface only a generic error to the UI.
            logger.exception("Inference error")
            return "<div style='color: #ef4444;'>❌ <b>System Error:</b> Inference failed. Check application logs.</div>", None
# Single shared backend instance powering the whole UI.
app_backend = MagicSupportClassifier()

# High-value test scenarios based on the Bitext taxonomy, each paired with a
# suggested Top-K value for the prediction display.
_EXAMPLE_SCENARIOS = (
    ("I need to cancel my order immediately, it was placed by mistake.", 5),
    ("Where can I find the invoice for my last purchase?", 3),
    ("The item arrived damaged and I want a full refund.", 5),
    ("How do I change the shipping address on my account?", 3),
    ("I forgot my password and cannot log in.", 3),
    ("Are there any hidden fees if I cancel my subscription now?", 5),
)
EXAMPLES = [[query, k] for query, k in _EXAMPLE_SCENARIOS]
# Build Gradio Interface.
# NOTE(review): emoji literals throughout this section were encoding-corrupted
# (mojibake); reconstructed to the most plausible intended glyphs.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"),
    title="MagicSupport Intent Classifier R&D Dashboard",
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .domain-badge { background: #ede9fe; color: #5b21b6; border: 1px solid #ddd6fe;}
    .metric-badge { background: #f1f5f9; color: #334155; border: 1px solid #cbd5e1;}
    footer { display: none !important; }
    """
) as demo:
    gr.HTML("""
    <div class="header-box">
        <h1>📧 MagicSupport Intent Classifier</h1>
        <p>
            High-precision semantic routing for automated customer support pipelines.
        </p>
        <div style="margin-top:12px;">
            <span class="badge domain-badge">E-commerce &amp; Retail</span>
            <span class="badge domain-badge">Account Management</span>
            <span class="badge domain-badge">Billing &amp; Refunds</span>
            <span class="badge metric-badge">Based on Bitext Taxonomy</span>
        </div>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Customer Query",
                placeholder="Type a customer message here (e.g., 'Where is my package?')...",
                lines=3,
            )
            with gr.Row():
                top_k_slider = gr.Slider(
                    # Follow the model's label set instead of a hard-coded 15,
                    # matching the dynamically-sized gr.Label chart below.
                    minimum=1, maximum=max(app_backend.num_classes, 5), value=5, step=1,
                    label="Display Top-K Predictions"
                )
            with gr.Row():
                predict_btn = gr.Button("🚀 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")
            gr.Examples(
                examples=EXAMPLES,
                inputs=[text_input, top_k_slider],
                label="Actionable Test Scenarios",
                examples_per_page=6,
            )
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")
            with gr.Row():
                chart_output = gr.Label(
                    label="Full Semantic Distribution Map",
                    num_top_classes=app_backend.num_classes  # Dynamically set based on model config
                )
    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Target Model:** `learn-abc/magicSupport-intent-classifier`
        * **Objective:** Multi-class text sequence classification for customer support routing.
        * **Dataset Lineage:** Trained on the comprehensive `bitext/Bitext-customer-support-llm-chatbot-training-dataset`.

        ### Pipeline Features
        * **Dynamic Label Resolution:** The UI heuristic engine automatically maps raw dataset labels (e.g., `change_shipping_address`) into clean, professional UI elements (e.g., Change Shipping Address) and assigns contextual iconography.
        * **Optimized Inference:** Utilizes PyTorch `inference_mode` for reduced memory footprint and accelerated compute during forward passes.
        """)

    # Event wiring: button click and textbox submit share the same handler.
    predict_btn.click(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    text_input.submit(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    # Reset all four components to their initial states.
    clear_btn.click(
        fn=lambda: ("", 5, "", None),
        outputs=[text_input, top_k_slider, result_output, chart_output],
    )

if __name__ == "__main__":
    demo.launch()