| import gradio as gr |
| import torch |
| import logging |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
# Root logging: "timestamp | level | message", INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
# Module-level logger used by the classifier for load/inference diagnostics.
logger = logging.getLogger(__name__)
|
|
class BankingIntentClassifier:
    """
    Encapsulates the intent classification model for scalable deployment.

    The tokenizer and model are loaded once in ``__init__``. ``predict`` then
    classifies a single query and returns ``(html, chart_data)`` for the UI.
    """

    def __init__(self, model_id: str = "learn-abc/banking-multilingual-intent-classifier"):
        """Configure the classifier and eagerly load the model.

        Args:
            model_id: Model id passed to ``from_pretrained`` for both the
                tokenizer and the sequence-classification model.

        Raises:
            Exception: whatever ``_load_model`` re-raises if loading fails.
        """
        self.model_id = model_id
        # Token budget per query; longer inputs are truncated in predict().
        self.max_length = 64
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # label -> (emoji, human-readable name) used when rendering results.
        # Labels missing here are shown with a generic emoji and the raw label.
        self.intent_info = {
            "ACCOUNT_INFO": ("🏦", "Account Information"),
            "ATM_SUPPORT": ("🏧", "ATM Support"),
            "CARD_ISSUE": ("💳", "Card Issue"),
            "CARD_MANAGEMENT": ("⚙️", "Card Management"),
            "CARD_REPLACEMENT": ("🔄", "Card Replacement"),
            "CHECK_BALANCE": ("💰", "Check Balance"),
            "EDIT_PERSONAL_DETAILS": ("✏️", "Edit Personal Details"),
            "FAILED_TRANSFER": ("❌", "Failed Transfer"),
            "FALLBACK": ("🤔", "Out of Scope / Fallback"),
            "FEES": ("📋", "Fees & Charges"),
            "GREETING": ("👋", "Greeting"),
            "LOST_OR_STOLEN_CARD": ("🚨", "Lost or Stolen Card"),
            "MINI_STATEMENT": ("📄", "Mini Statement"),
            "TRANSFER": ("💸", "Transfer"),
        }
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model onto ``self.device`` in eval mode.

        float16 weights are used only on CUDA. On CPU the model is loaded in
        float32, because half-precision CPU kernels are slow and, in older
        PyTorch releases, missing for some ops.

        Raises:
            Exception: any loading failure is logged with its traceback and
                re-raised unchanged.
        """
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        logger.info("Initializing model %s on %s (%s)...", self.model_id, self.device, dtype)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise
        logger.info("Model loaded successfully.")

    def _intent_display(self, intent):
        """Return ``(emoji, readable name)`` for a label, with a generic fallback."""
        return self.intent_info.get(intent, ("🔹", intent))

    def _render_html(self, top_indices, top_probs, id2label):
        """Build the result panel: headline intent, its confidence and top-k bars.

        Args:
            top_indices: Class indices, highest probability first.
            top_probs: Matching probabilities in ``[0, 1]``.
            id2label: Mapping from class index to label string.
        """
        emoji, label = self._intent_display(id2label[top_indices[0]])
        confidence = top_probs[0] * 100

        parts = [f"""
        <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{emoji} {label}</h2>
        <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {confidence:.1f}%</p>
        <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
        <h3 style='margin-bottom: 15px;'>📊 Top {len(top_indices)} Predictions</h3>
        """]

        for idx, prob in zip(top_indices, top_probs):
            row_emoji, row_label = self._intent_display(id2label[idx])
            pct = prob * 100
            parts.append(f"""
        <div style="margin-bottom: 16px;">
            <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                <strong>{row_emoji} {row_label}</strong>
                <span style="font-family:monospace;">{pct:.1f}%</span>
            </div>
            <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                <div style="background-color: #3b82f6; width: {pct:.1f}%; height: 100%; border-radius: 5px;"></div>
            </div>
        </div>
        """)
        return "".join(parts)

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 3):
        """Classify one banking query and render the result.

        Args:
            text: Raw user query. Blank or whitespace-only input is rejected.
            top_k: Number of intents to list in the HTML panel. The value is
                converted with ``int()`` (slider widgets may deliver floats)
                and clamped to ``[1, number of classes]``.

        Returns:
            ``(html, chart_data)``, where ``chart_data`` maps every label to its
            probability. On blank input or an inference failure, returns
            ``(message_html, None)`` instead.
        """
        if not text or not text.strip():
            return "<div style='color: red; padding: 10px;'>⚠️ <b>Input Required:</b> Please enter a banking query.</div>", None

        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True,
            ).to(self.device)

            logits = self.model(**inputs).logits
            # One query in -> one row out. Upcast first so float16 logits
            # (CUDA path) are normalised in full precision.
            probs = F.softmax(logits.float(), dim=-1)[0]

            # torch.topk raises if k exceeds the number of classes or is not
            # an int, so clamp and convert before calling it.
            k = max(1, min(int(top_k), probs.numel()))
            top = torch.topk(probs, k=k)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            id2label = self.model.config.id2label

            result_html = self._render_html(top_indices, top_probs, id2label)
            chart_data = {id2label[i]: p for i, p in enumerate(probs.tolist())}
            return result_html, chart_data

        except Exception:
            logger.exception("Inference error")
            return "<div style='color: red;'>❌ <b>Error:</b> An internal error occurred during processing. Check logs.</div>", None
|
|
| |
# Built at import time, so the model is loaded once and the shared
# instance's predict() is wired into the UI below.
app_backend = BankingIntentClassifier()

# Demo queries for the UI, each paired with its language / scenario tag.
# The order here is the order shown in the Examples panel.
EXAMPLES = [
    [query, tag]
    for query, tag in (
        ("can you check how much money I have in my account?", "English"),
        ("আমার ব্যালেন্স কত?", "Bangla"),
        ("amar card hariye geche, block koro", "Banglish"),
        ("last 10 transaction dekhao", "Code-mixed"),
        ("ATM থেকে টাকা বের হচ্ছে না", "Bangla"),
        ("my transfer failed, what happened?", "English"),
        ("fee keno kete nilo without reason", "Code-mixed"),
        ("hello, I need some help", "English"),
        ("who won the cricket match yesterday?", "Fallback"),
        ("card ta replace korte chai", "Banglish"),
        ("আমার পিন পরিবর্তন করতে চাই", "Bangla"),
        ("I lost my debit card, please block it immediately", "English"),
    )
]
|
|
| |
# Number of classes the loaded checkpoint predicts. It drives the Top-K slider
# range and the distribution chart, so the UI cannot drift from the model.
NUM_INTENTS = len(app_backend.model.config.id2label)

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate"),
    title="Banking Intent Classifier R&D Dashboard",
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .lang-badge { background: #e0f2fe; color: #0369a1; border: 1px solid #bae6fd;}
    .metric-badge { background: #dcfce7; color: #15803d; border: 1px solid #bbf7d0;}
    footer { display: none !important; }
    """,
) as demo:

    # Page header: title, tagline, supported languages and headline metrics.
    gr.HTML("""
    <div class="header-box">
        <h1>🏦 Multilingual Banking Intent Classifier</h1>
        <p>
            Advanced fine-tuned classification for 14 production banking intents.
        </p>
        <div style="margin-top:12px;">
            <span class="badge lang-badge">🇬🇧 English</span>
            <span class="badge lang-badge">🇧🇩 Bangla</span>
            <span class="badge lang-badge">✍️ Banglish</span>
            <span class="badge lang-badge">🔀 Code-mixed</span>
            <span class="badge metric-badge">✅ Acc: ~98.4%</span>
            <span class="badge metric-badge">📊 F1: ~0.982</span>
        </div>
    </div>
    """)

    with gr.Row():
        # Left column: query input, Top-K control, action buttons, examples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Banking Query",
                placeholder="Type a query here (e.g., 'check my balance' or 'amar card block koro')...",
                lines=3,
            )

            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=NUM_INTENTS,
                    value=min(3, NUM_INTENTS),
                    step=1,
                    label="Display Top-K Predictions",
                )

            with gr.Row():
                predict_btn = gr.Button("🔍 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")

            # Only the query text is fed to the textbox; the tags stay in EXAMPLES.
            gr.Examples(
                examples=[[example[0]] for example in EXAMPLES],
                inputs=text_input,
                label="Test Scenarios",
                examples_per_page=4,
            )

        # Right column: rendered prediction summary and full class distribution.
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")

            with gr.Row():
                chart_output = gr.Label(
                    label="Full Distribution Map",
                    num_top_classes=NUM_INTENTS,
                )

    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Base Architecture:** `google/muril-base-cased` (Optimized for Indian vernaculars).
        * **Objective:** Multi-class text sequence classification.
        * **Supported Lexicons:** English, Standard Bengali, Romanized Bengali (Banglish), Code-mixed syntax.

        ### Intent Taxonomy
        | ID | Description |
        |---|---|
        | ACCOUNT_INFO | General account status and balance queries |
        | ATM_SUPPORT | Hardware issues and withdrawal failures |
        | CARD_ISSUE | Malfunctions and terminal declines |
        | CARD_MANAGEMENT | PIN resets, activation, and security controls |
        | CARD_REPLACEMENT | Physical card replacement requests |
        | CHECK_BALANCE | Direct balance inquiries |
        | EDIT_PERSONAL_DETAILS | KYC and profile updates |
        | FAILED_TRANSFER | Routing and transaction failures |
        | FALLBACK | Non-banking or ambiguous queries |
        | FEES | Service charge inquiries |
        | GREETING | Conversational anchors |
        | LOST_OR_STOLEN_CARD | Emergency block requests |
        | MINI_STATEMENT | Ledger and transaction history |
        | TRANSFER | Fund routing instructions |

        ### Data Lineage
        Derived from the `learn-abc/banking14-intents-en-bn-banglish` dataset. The original 77 intents were aggressively condensed into 14 high-value production categories. The corpus was subsequently augmented with synthetically generated multilingual data translated via advanced neural machine translation models.
        """)

    # The button and Enter in the textbox trigger the same prediction.
    predict_event = dict(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    predict_btn.click(**predict_event)
    text_input.submit(**predict_event)

    # Reset the query, the HTML panel and the chart.
    clear_btn.click(
        fn=lambda: ("", "", None),
        outputs=[text_input, result_output, chart_output],
    )


if __name__ == "__main__":
    demo.launch()