# Provenance: commit b2021c8 — "update model loading logic" (Abhishek Singh)
import gradio as gr
import torch
import logging
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure professional logging
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
class BankingIntentClassifier:
    """
    Encapsulates the intent classification model for scalable deployment.

    Loads a Hugging Face sequence-classification checkpoint once at
    construction time and exposes :meth:`predict`, which returns a rendered
    HTML report plus a label -> probability mapping suitable for a Gradio
    ``Label`` component.
    """

    def __init__(self, model_id: str = "learn-abc/banking-multilingual-intent-classifier"):
        self.model_id = model_id
        # Banking queries are short; a 64-token cap keeps inference cheap.
        self.max_length = 64
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Maps the model's raw label strings to (emoji, display name) pairs
        # for the HTML report; unknown labels fall back to a generic marker.
        self.intent_info = {
            "ACCOUNT_INFO": ("🏦", "Account Information"),
            "ATM_SUPPORT": ("🏧", "ATM Support"),
            "CARD_ISSUE": ("💳", "Card Issue"),
            "CARD_MANAGEMENT": ("⚙️", "Card Management"),
            "CARD_REPLACEMENT": ("🔄", "Card Replacement"),
            "CHECK_BALANCE": ("💰", "Check Balance"),
            "EDIT_PERSONAL_DETAILS": ("✏️", "Edit Personal Details"),
            "FAILED_TRANSFER": ("❌", "Failed Transfer"),
            "FALLBACK": ("🤔", "Out of Scope / Fallback"),
            "FEES": ("📋", "Fees & Charges"),
            "GREETING": ("👋", "Greeting"),
            "LOST_OR_STOLEN_CARD": ("🚨", "Lost or Stolen Card"),
            "MINI_STATEMENT": ("📄", "Mini Statement"),
            "TRANSFER": ("💸", "Transfer"),
        }
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model onto ``self.device`` in eval mode.

        Raises:
            Exception: re-raised (after logging) when the download or load
                fails, so the process fails fast instead of serving a
                half-initialized backend.
        """
        logger.info(f"Initializing model {self.model_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            # Critical Memory Optimization: Force float16 to prevent Free Tier OOM kills
            # NOTE(review): float16 inference on CPU is slow and not every op
            # supports it — confirm the deployment target before relying on this.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16
            )
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model loaded successfully.")
        except Exception as e:
            # logger.exception keeps the full traceback in the logs.
            logger.exception(f"Failed to load model: {e}")
            raise

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 3):
        """Classify ``text`` and render the top-k intents as HTML.

        Args:
            text: Raw user query (any supported language/script). Blank or
                whitespace-only input short-circuits to an error message.
            top_k: Number of ranked predictions to display. Coerced to int
                and clamped to [1, num_classes] — Gradio sliders may deliver
                floats, and torch.topk raises if k exceeds the class count.

        Returns:
            Tuple ``(result_html, chart_data)`` where ``chart_data`` maps
            every label to its probability, or ``(error_html, None)`` on
            invalid input or inference failure.
        """
        if not text or not text.strip():
            return "<div style='color: red; padding: 10px;'>⚠️ <b>Input Required:</b> Please enter a banking query.</div>", None
        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True
            ).to(self.device)
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1).squeeze()
            # Ensure proper tensor dimensions for single-batch inputs
            if probs.dim() == 0:
                probs = probs.unsqueeze(0)
            # Fix: clamp/coerce k, and compute topk ONCE (it was previously
            # evaluated twice — once for indices, once for values).
            k = max(1, min(int(top_k), probs.numel()))
            top = torch.topk(probs, k=k)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            id2label = self.model.config.id2label
            # Primary Prediction Formatting using Pure HTML
            top_intent = id2label[top_indices[0]]
            emoji, label = self.intent_info.get(top_intent, ("🔹", top_intent))
            confidence = top_probs[0] * 100
            result_html = f"""
            <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{emoji} {label}</h2>
            <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {confidence:.1f}%</p>
            <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
            <h3 style='margin-bottom: 15px;'>📊 Top {k} Predictions</h3>
            """
            # HTML-based progress bars for visual clarity
            for idx, prob in zip(top_indices, top_probs):
                intent = id2label[idx]
                e, l = self.intent_info.get(intent, ("🔹", intent))
                pct = prob * 100
                bar_html = f"""
                <div style="margin-bottom: 16px;">
                    <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                        <strong>{e} {l}</strong>
                        <span style="font-family:monospace;">{pct:.1f}%</span>
                    </div>
                    <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                        <div style="background-color: #3b82f6; width: {pct}%; height: 100%; border-radius: 5px;"></div>
                    </div>
                </div>
                """
                result_html += bar_html
            # Format data for Gradio Label component
            chart_data = {
                id2label[i]: float(probs[i].item())
                for i in range(len(probs))
            }
            return result_html, chart_data
        except Exception as e:
            # logger.exception keeps the full traceback in the logs.
            logger.exception(f"Inference error: {e}")
            return f"<div style='color: red;'>❌ <b>Error:</b> An internal error occurred during processing. Check logs.</div>", None
# Initialize application backend
# NOTE: this runs at import time — the first start downloads the checkpoint
# from the Hugging Face Hub, so startup latency is dominated by this line.
app_backend = BankingIntentClassifier()

# Define UI Dataset Examples
# Each entry is [query_text, language_tag]. Only the query text is wired into
# the gr.Examples widget below; the tag is informational for maintainers.
EXAMPLES = [
    ["can you check how much money I have in my account?", "English"],
    ["আমার ব্যালেন্স কত?", "Bangla"],
    ["amar card hariye geche, block koro", "Banglish"],
    ["last 10 transaction dekhao", "Code-mixed"],
    ["ATM থেকে টাকা বের হচ্ছে না", "Bangla"],
    ["my transfer failed, what happened?", "English"],
    ["fee keno kete nilo without reason", "Code-mixed"],
    ["hello, I need some help", "English"],
    ["who won the cricket match yesterday?", "Fallback"],
    ["card ta replace korte chai", "Banglish"],
    ["আমার পিন পরিবর্তন করতে চাই", "Bangla"],
    ["I lost my debit card, please block it immediately", "English"],
]
# Build Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate"),
    title="Banking Intent Classifier R&D Dashboard",
    # Custom CSS uses Gradio theme variables (var(--...)) so the header and
    # badges follow light/dark mode automatically; `footer` hides branding.
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .lang-badge { background: #e0f2fe; color: #0369a1; border: 1px solid #bae6fd;}
    .metric-badge { background: #dcfce7; color: #15803d; border: 1px solid #bbf7d0;}
    footer { display: none !important; }
    """
) as demo:
    # Static header: title, supported-language badges, headline metrics.
    gr.HTML("""
    <div class="header-box">
        <h1>🏦 Multilingual Banking Intent Classifier</h1>
        <p>
        Advanced fine-tuned classification for 14 production banking intents.
        </p>
        <div style="margin-top:12px;">
            <span class="badge lang-badge">🇬🇧 English</span>
            <span class="badge lang-badge">🇧🇩 Bangla</span>
            <span class="badge lang-badge">✍️ Banglish</span>
            <span class="badge lang-badge">🔀 Code-mixed</span>
            <span class="badge metric-badge">✅ Acc: ~98.4%</span>
            <span class="badge metric-badge">📊 F1: ~0.982</span>
        </div>
    </div>
    """)
    with gr.Row():
        # Left column: query input, top-k control, action buttons, examples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Banking Query",
                placeholder="Type a query here (e.g., 'check my balance' or 'amar card block koro')...",
                lines=3,
            )
            with gr.Row():
                # Max 14 matches the number of intent classes.
                top_k_slider = gr.Slider(
                    minimum=1, maximum=14, value=3, step=1,
                    label="Display Top-K Predictions"
                )
            with gr.Row():
                predict_btn = gr.Button("🔍 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")
            # EXAMPLES rows are [text, tag]; only the text feeds the textbox.
            gr.Examples(
                examples=[[e[0]] for e in EXAMPLES],
                inputs=text_input,
                label="Test Scenarios",
                examples_per_page=4,
            )
        # Right column: rendered HTML report from predict().
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")
    # Chart output moved to its own full-width row
    with gr.Row():
        # gr.Label renders the full probability distribution (chart_data dict).
        chart_output = gr.Label(
            label="Full Distribution Map",
            num_top_classes=14
        )
    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Base Architecture:** `google/muril-base-cased` (Optimized for Indian vernaculars).
        * **Objective:** Multi-class text sequence classification.
        * **Supported Lexicons:** English, Standard Bengali, Romanized Bengali (Banglish), Code-mixed syntax.
        ### Intent Taxonomy
        | ID | Description |
        |---|---|
        | ACCOUNT_INFO | General account status and balance queries |
        | ATM_SUPPORT | Hardware issues and withdrawal failures |
        | CARD_ISSUE | Malfunctions and terminal declines |
        | CARD_MANAGEMENT | PIN resets, activation, and security controls |
        | CARD_REPLACEMENT | Physical card replacement requests |
        | CHECK_BALANCE | Direct balance inquiries |
        | EDIT_PERSONAL_DETAILS | KYC and profile updates |
        | FAILED_TRANSFER | Routing and transaction failures |
        | FALLBACK | Non-banking or ambiguous queries |
        | FEES | Service charge inquiries |
        | GREETING | Conversational anchors |
        | LOST_OR_STOLEN_CARD | Emergency block requests |
        | MINI_STATEMENT | Ledger and transaction history |
        | TRANSFER | Fund routing instructions |
        ### Data Lineage
        Derived from the `learn-abc/banking14-intents-en-bn-banglish` dataset. The original 77 intents were aggressively condensed into 14 high-value production categories. The corpus was subsequently augmented with synthetically generated multilingual data translated via advanced neural machine translation models.
        """)
    # Event Wiring
    # Both the button click and textbox Enter-submit trigger the same handler.
    predict_btn.click(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    text_input.submit(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    # Clear resets the textbox, the HTML report, and the distribution chart.
    clear_btn.click(
        fn=lambda: ("", "", None),
        outputs=[text_input, result_output, chart_output],
    )
# Script entry point: launch the Gradio server (blocking call).
if __name__ == "__main__":
    demo.launch()