| import gradio as gr |
| from optimum.onnxruntime import ORTModelForSequenceClassification |
| from transformers import AutoTokenizer |
| import numpy as np |
| import pandas as pd |
| import torch |
| from lime.lime_text import LimeTextExplainer |
| from collections import OrderedDict |
| from PIL import Image, ImageDraw, ImageFont |
|
|
| |
# Location of the exported ONNX sentiment model (local directory or Hub id)
# and the tokenizer saved alongside it.
model_id = "onnx_model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSequenceClassification.from_pretrained(
    model_id,
    file_name="model.onnx",
)

# Human-readable class names, indexed by the position of each logit.
# NOTE(review): assumed to match the order of model.config.id2label -- confirm
# against the exported model if it is ever retrained.
class_names = ["Negative", "Neutral", "Positive"]
|
|
| |
def predictor(texts):
    """Return softmax class probabilities for a batch of texts.

    This serves both the live prediction path and LIME, which passes it as
    ``classifier_fn``.

    Args:
        texts: List of input strings. Each is truncated to 64 tokens, and
            the batch is padded to its longest member.

    Returns:
        NumPy array with one row of probabilities per input text. Columns
        are assumed to line up with ``class_names``.
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    return logits.softmax(dim=-1).cpu().numpy()
|
|
| |
def create_custom_plot(lime_weights, label_name):
    """Render LIME word weights as a diverging horizontal bar chart.

    Positive weights are drawn as green bars to the right of a vertical zero
    axis; zero and negative weights are drawn as red bars to the left. Each
    bar is labelled with its word and its weight (3 decimals). Rows are drawn
    top to bottom in DataFrame order.

    Args:
        lime_weights: DataFrame with a ``words`` column (strings) and a
            numeric ``weights`` column, one row per bar.
        label_name: Class name shown in the chart title.

    Returns:
        An RGB ``PIL.Image.Image``, 850 px wide. It is 650 px tall unless
        there are more rows than fit, in which case it grows to fit them.
    """
    width = 850
    y_first = 100        # top edge of the first bar
    bar_height = 35
    row_step = 45        # vertical distance between consecutive bars
    center_x = 400       # x position of the zero axis
    max_bar_len = 250    # pixel length of the largest-magnitude bar

    # Size the canvas and axis from the number of rows so that long
    # explanations are not drawn off the bottom of the image. For the
    # default 10 rows this keeps the original 600 px axis / 650 px canvas.
    n_rows = len(lime_weights)
    last_bar_bottom = y_first + (n_rows - 1) * row_step + bar_height
    axis_bottom = max(600, last_bar_bottom + 25)
    height = max(650, axis_bottom + 50)

    img = Image.new('RGB', (width, height), color=(255, 255, 255))
    draw = ImageDraw.Draw(img)

    try:
        font_main = ImageFont.truetype("SolaimanLipi.ttf", 20)
        font_val = ImageFont.truetype("SolaimanLipi.ttf", 16)
        font_title = ImageFont.truetype("SolaimanLipi.ttf", 30)
    except (OSError, ImportError):
        # OSError: the font file is missing or unreadable; ImportError: Pillow
        # was built without FreeType. Fall back instead of failing the request.
        # NOTE(review): Pillow's default font most likely has no Bangla glyphs
        # (and old bitmap defaults may reject non-Latin-1 text), so ship the
        # TTF with the app to get a readable chart.
        font_main = ImageFont.load_default()
        font_val = ImageFont.load_default()
        font_title = ImageFont.load_default()

    draw.text((width // 2 - 120, 20), f"মডেলের ব্যাখ্যা: {label_name}", fill=(0, 0, 0), font=font_title)

    weights = lime_weights['weights']
    # Guard the empty case: Series.max() of nothing is NaN, which would make
    # the scale NaN.
    peak = float(weights.abs().max()) if n_rows else 0.0
    # Floor at 0.1 so uniformly tiny weights are not stretched to full length.
    scale = max_bar_len / max(peak, 0.1)

    y = y_first
    for word, weight in zip(lime_weights['words'], weights):
        if weight > 0:
            # Bar grows right from the axis; word sits left of the axis.
            x0, x1 = center_x, center_x + weight * scale
            color = (46, 204, 113)
            text_pos_x = center_x - 180
            val_pos_x = x1 + 10
        else:
            # Bar grows left from the axis; word sits right of the axis.
            x0, x1 = center_x + weight * scale, center_x
            color = (231, 76, 60)
            text_pos_x = center_x + 20
            val_pos_x = x0 - 55

        draw.rectangle([x0, y, x1, y + bar_height], fill=color, outline=(0, 0, 0))
        draw.text((text_pos_x, y + 5), word, fill=(0, 0, 0), font=font_main)
        draw.text((val_pos_x, y + 8), f"{weight:.3f}", fill=(50, 50, 50), font=font_val)
        y += row_step

    # Zero axis, drawn last so it sits on top of the bars.
    draw.line([center_x, 80, center_x, axis_bottom], fill=(0, 0, 0), width=3)
    return img
|
|
| |
def get_prediction(text):
    """Map the input text to a ``{class name: probability}`` dict for gr.Label.

    Empty, ``None`` or whitespace-only input returns an empty dict, which
    clears the label without running the model.
    """
    if not (text and text.strip()):
        return {}

    scores = predictor([text])[0]
    return {name: float(scores[idx]) for idx, name in enumerate(class_names)}
|
|
| |
def get_explanation(text, num_features=10, num_samples=1000):
    """Explain the model's predicted class for *text* with LIME and plot it.

    Args:
        text: Input text from the UI textbox.
        num_features: Maximum number of words kept in the explanation.
        num_samples: Number of perturbed copies of *text* that LIME scores
            with the model. More samples give steadier weights but cost more
            model calls.

    Returns:
        The chart image from ``create_custom_plot``. Returns ``None`` (which
        clears the image output) when *text* is empty, ``None`` or
        whitespace-only.
    """
    if not text or not text.strip():
        return None

    # Explain the class the model actually predicts for this text.
    probs = predictor([text])[0]
    label_idx = int(np.argmax(probs))
    label_name = class_names[label_idx]

    # Split on whitespace so each word is one interpretable feature.
    explainer = LimeTextExplainer(class_names=class_names, split_expression=r'\s+')
    exp = explainer.explain_instance(
        text,
        predictor,
        num_features=num_features,
        labels=(label_idx,),
        num_samples=num_samples,
    )

    # as_list() already yields (word, weight) pairs, so build the frame
    # directly, in the order LIME returns them.
    lime_weights = pd.DataFrame(exp.as_list(label=label_idx), columns=['words', 'weights'])
    return create_custom_plot(lime_weights, label_name)
|
|
|
|
| |
def clear_inputs():
    """Return reset values for (textbox, confidence label, explanation image).

    The textbox is emptied, the label gets an empty score dict, and the image
    is cleared with ``None``. A fresh dict is built on every call.
    """
    empty_text = ""
    empty_scores = {}
    no_image = None
    return empty_text, empty_scores, no_image
|
|
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# Real-time Bangla Political Sentiment Analysis with XAI")
    gr.Markdown("")

    with gr.Row():
        with gr.Column():
            # Per-class confidence scores, refreshed as the text changes.
            all_conf_out = gr.Label(
                label="প্রধান ফলাফল ও ক্লাসভিত্তিক কনফিডেন্স স্কোর",
                num_top_classes=3,
            )
            input_text = gr.Textbox(
                label="বাংলা টেক্সট লিখুন",
                lines=5,
                placeholder="এখানে আপনার মন্তব্যটি টাইপ করুন...",
                interactive=True,
            )
            explain_btn = gr.Button("বিস্তারিত ব্যাখ্যা (LIME) দেখুন", variant="secondary")
            clear_btn = gr.Button("সব মুছে ফেলুন (Reset)", variant="stop")

    with gr.Row():
        plot_out = gr.Image(label="XAI (LIME) Visualization")

    # Live prediction on every text change, without a progress overlay.
    input_text.change(
        fn=get_prediction,
        inputs=input_text,
        outputs=all_conf_out,
        show_progress="hidden",
    )

    # LIME scores many perturbed samples, so it only runs on demand.
    explain_btn.click(fn=get_explanation, inputs=input_text, outputs=plot_out)

    # Reset the textbox, the confidence label and the explanation image.
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[input_text, all_conf_out, plot_out],
    )

# NOTE(review): passing theme= to launch() needs a recent Gradio release;
# older versions expect gr.Blocks(theme=...) -- confirm the pinned version.
demo.launch(theme=gr.themes.Soft())