File size: 10,458 Bytes
dcafbca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import gradio as gr
import torch
import logging
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Configure professional logging.
# basicConfig is a no-op if the root logger is already configured (e.g. when
# embedded in a larger app), so this is safe at module import time.
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO)
# Module-level logger shared by MagicSupportClassifier below.
logger = logging.getLogger(__name__)

class MagicSupportClassifier:
    """
    Encapsulates the customer support intent classification model.

    Loads a Hugging Face sequence-classification checkpoint once at
    construction time and exposes `predict`, which returns both a rendered
    HTML summary and a raw label->probability mapping for a chart widget.
    Engineered for dynamic label resolution (labels are read from the model
    config, never hard-coded) and rapid inference (eval mode plus
    `torch.inference_mode`).
    """

    # Ordered (keywords, icon) rules for _get_iconography; the FIRST rule whose
    # keyword appears in the label wins, so e.g. "cancel_order" maps to the
    # order icon, matching the original if/elif chain's precedence.
    _ICON_RULES = (
        (("order", "delivery", "track"), "\U0001F4E6"),
        (("refund", "payment", "invoice", "fee"), "\U0001F4B3"),
        (("account", "password", "register", "profile"), "\U0001F464"),
        (("cancel", "delete", "problem", "issue"), "\u26A0\uFE0F"),
        (("contact", "service", "support"), "\U0001F3A7"),
    )

    def __init__(self, model_id: str = "learn-abc/magicSupport-intent-classifier"):
        # Checkpoint identifier on the Hugging Face Hub.
        self.model_id = model_id
        # Queries longer than this are truncated by the tokenizer.
        self.max_length = 128
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load tokenizer + model, move to device, and cache the class count.

        Raises:
            Exception: re-raised (with full traceback logged) if the
                checkpoint cannot be loaded — the app cannot run without it.
        """
        # Lazy %-style args: formatting only happens if the record is emitted.
        logger.info("Initializing model %s on %s...", self.model_id, self.device)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()

            # Extract number of classes dynamically so the UI keeps working
            # if the model is retrained with a different label taxonomy.
            self.num_classes = len(self.model.config.id2label)
            logger.info("Model loaded successfully with %d intent classes.", self.num_classes)
        except Exception as e:
            # logger.exception preserves the traceback (logger.error did not).
            logger.exception("Failed to load model: %s", e)
            raise

    def _get_iconography(self, label: str) -> str:
        """
        Dynamically assigns UI icons based on intent keywords.
        Future-proofs the application against retrained label sets.
        """
        label_lower = label.lower()
        for keywords, icon in self._ICON_RULES:
            if any(keyword in label_lower for keyword in keywords):
                return icon
        return "\U0001F539"

    def _format_label(self, label: str) -> str:
        """Cleans up raw dataset labels (snake_case) for UI presentation."""
        return label.replace("_", " ").title()

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 5):
        """Classify a customer query and render the results.

        Args:
            text: Raw customer message; surrounding whitespace is ignored.
            top_k: Number of top predictions to render (capped at the
                model's class count).

        Returns:
            Tuple of (HTML result string, {pretty_label: probability} dict
            for the distribution chart). The dict is None on empty input or
            inference failure.
        """
        if not text or not text.strip():
            return "<div style='color: #ef4444; padding: 10px;'>\u26A0\uFE0F <b>Input Required:</b> Please enter a customer query.</div>", None

        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True
            ).to(self.device)

            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1).squeeze()

            # A 1-class model squeezes down to a 0-d tensor; restore 1-d.
            if probs.dim() == 0:
                probs = probs.unsqueeze(0)

            # Cap top_k to the maximum number of available classes.
            actual_top_k = min(top_k, self.num_classes)

            # Single topk call (original computed indices and values in two
            # separate, redundant torch.topk invocations).
            top_values, top_index_tensor = torch.topk(probs, k=actual_top_k)
            top_indices = top_index_tensor.tolist()
            top_probs = top_values.tolist()
            id2label = self.model.config.id2label

            # Primary Prediction Formatting
            top_intent_raw = id2label[top_indices[0]]
            emoji = self._get_iconography(top_intent_raw)
            clean_label = self._format_label(top_intent_raw)
            confidence = top_probs[0] * 100

            result_html = f"""
            <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{emoji} {clean_label}</h2>
            <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {confidence:.1f}%</p>
            <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
            <h3 style='margin-bottom: 15px;'>\U0001F4CA Top {actual_top_k} Predictions</h3>
            """

            # HTML progress bars, one per top-k prediction.
            for idx, prob in zip(top_indices, top_probs):
                intent_raw = id2label[idx]
                icon = self._get_iconography(intent_raw)
                pretty = self._format_label(intent_raw)
                pct = prob * 100

                bar_html = f"""
                <div style="margin-bottom: 16px;">
                    <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                        <strong>{icon} {pretty}</strong> 
                        <span style="font-family:monospace;">{pct:.1f}%</span>
                    </div>
                    <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                        <div style="background-color: #8b5cf6; width: {pct}%; height: 100%; border-radius: 5px;"></div>
                    </div>
                </div>
                """
                result_html += bar_html

            # Full distribution for the chart widget: one .tolist() transfer
            # instead of a per-element .item() device sync.
            chart_data = {
                self._format_label(id2label[i]): prob
                for i, prob in enumerate(probs.tolist())
            }

            return result_html, chart_data

        except Exception as e:
            # Keep the UI alive on failure; full traceback goes to the logs.
            logger.exception("Inference error: %s", e)
            return "<div style='color: #ef4444;'>\u274C <b>System Error:</b> Inference failed. Check application logs.</div>", None

# Initialize application backend.
# NOTE: instantiating the classifier loads (and may download) the checkpoint
# at import time, so the UI layout below can read `app_backend.num_classes`.
app_backend = MagicSupportClassifier()

# High-value test scenarios based on Bitext taxonomy.
# Each entry is [query_text, top_k], matching the order of the
# (text_input, top_k_slider) inputs wired into gr.Examples below.
EXAMPLES = [
    ["I need to cancel my order immediately, it was placed by mistake.", 5],
    ["Where can I find the invoice for my last purchase?", 3],
    ["The item arrived damaged and I want a full refund.", 5],
    ["How do I change the shipping address on my account?", 3],
    ["I forgot my password and cannot log in.", 3],
    ["Are there any hidden fees if I cancel my subscription now?", 5],
]

# Build Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"),
    title="MagicSupport Intent Classifier R&D Dashboard",
    css="""
        .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
        .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
        .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
        .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
        .domain-badge { background: #ede9fe; color: #5b21b6; border: 1px solid #ddd6fe;}
        .metric-badge { background: #f1f5f9; color: #334155; border: 1px solid #cbd5e1;}
        footer { display: none !important; }
    """
) as demo:

    gr.HTML("""
    <div class="header-box">
        <h1>🎧 MagicSupport Intent Classifier</h1>
        <p>
            High-precision semantic routing for automated customer support pipelines.
        </p>
        <div style="margin-top:12px;">
            <span class="badge domain-badge">E-commerce & Retail</span>
            <span class="badge domain-badge">Account Management</span>
            <span class="badge domain-badge">Billing & Refunds</span>
            <span class="badge metric-badge">Based on Bitext Taxonomy</span>
        </div>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Customer Query",
                placeholder="Type a customer message here (e.g., 'Where is my package?')...",
                lines=3,
            )
            
            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1, maximum=15, value=5, step=1,
                    label="Display Top-K Predictions"
                )
            
            with gr.Row():
                predict_btn = gr.Button("πŸ” Execute Prediction", variant="primary")
                clear_btn = gr.Button("πŸ—‘οΈ Clear Interface", variant="secondary")
                
            gr.Examples(
                examples=EXAMPLES,
                inputs=[text_input, top_k_slider],
                label="Actionable Test Scenarios",
                examples_per_page=6,
            )

        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")

    with gr.Row():
        chart_output = gr.Label(
            label="Full Semantic Distribution Map",
            num_top_classes=app_backend.num_classes  # Dynamically set based on model config
        )

    with gr.Accordion("βš™οΈ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Target Model:** `learn-abc/magicSupport-intent-classifier`
        * **Objective:** Multi-class text sequence classification for customer support routing.
        * **Dataset Lineage:** Trained on the comprehensive `bitext/Bitext-customer-support-llm-chatbot-training-dataset`.
        
        ### Pipeline Features
        * **Dynamic Label Resolution:** The UI heuristic engine automatically maps raw dataset labels (e.g., `change_shipping_address`) into clean, professional UI elements (e.g., Change Shipping Address) and assigns contextual iconography.
        * **Optimized Inference:** Utilizes PyTorch `inference_mode` for reduced memory footprint and accelerated compute during forward passes.
        """)

    # Event Wiring
    predict_btn.click(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    text_input.submit(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    clear_btn.click(
        fn=lambda: ("", 5, "", None),
        outputs=[text_input, top_k_slider, result_output, chart_output],
    )

if __name__ == "__main__":
    demo.launch()