hellofoysal101's picture
Update app.py
0cbecef verified
import gradio as gr
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
import pandas as pd
import torch
from lime.lime_text import LimeTextExplainer
from collections import OrderedDict
from PIL import Image, ImageDraw, ImageFont
# ১. ONNX মডেল ও টোকেনাইজার লোড (onnx_model ফোল্ডার থেকে)
model_id = "onnx_model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSequenceClassification.from_pretrained(model_id, file_name="model.onnx")
class_names = ['Negative', 'Neutral', 'Positive']
# ২. প্রেডিক্টর ফাংশন (Internal use for LIME and Real-time)
def predictor(texts):
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
return probs.cpu().numpy()
# ৩. Pillow গ্রাফ তৈরির ফাংশন (ব্যাখ্যা দেখুন বাটনের জন্য)
def create_custom_plot(lime_weights, label_name):
width, height = 850, 650
img = Image.new('RGB', (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(img)
try:
font_main = ImageFont.truetype("SolaimanLipi.ttf", 20)
font_val = ImageFont.truetype("SolaimanLipi.ttf", 16)
font_title = ImageFont.truetype("SolaimanLipi.ttf", 30)
except:
font_main = ImageFont.load_default()
font_val = ImageFont.load_default()
font_title = ImageFont.load_default()
draw.text((width//2 - 120, 20), f"মডেলের ব্যাখ্যা: {label_name}", fill=(0, 0, 0), font=font_title)
y_start = 100
bar_height = 35
center_x = 400
max_val = max(lime_weights['weights'].abs().max(), 0.1)
scale = 250 / max_val
for i, row in lime_weights.iterrows():
word = row['words']
weight = row['weights']
val_text = f"{weight:.3f}"
if weight > 0:
x0, x1 = center_x, center_x + (weight * scale)
color = (46, 204, 113)
text_pos_x = center_x - 180
val_pos_x = x1 + 10
else:
x0, x1 = center_x + (weight * scale), center_x
color = (231, 76, 60)
text_pos_x = center_x + 20
val_pos_x = x0 - 55
draw.rectangle([x0, y_start, x1, y_start + bar_height], fill=color, outline=(0,0,0))
draw.text((text_pos_x, y_start + 5), word, fill=(0, 0, 0), font=font_main)
draw.text((val_pos_x, y_start + 8), val_text, fill=(50, 50, 50), font=font_val)
y_start += 45
draw.line([center_x, 80, center_x, 600], fill=(0, 0, 0), width=3)
return img
# ৪. ফাংশন ১: রিয়েল-টাইম প্রেডিকশন
def get_prediction(text):
if not text or not text.strip():
return {} # খালি ডিকশনারি
probs = predictor([text])[0]
confidences = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
return confidences
# ৫. ফাংশন ২: LIME ব্যাখ্যা (আলাদা বাটন ক্লিকের জন্য)
def get_explanation(text):
if not text or not text.strip():
return None
probs = predictor([text])[0]
label_idx = np.argmax(probs)
label_name = class_names[label_idx]
explainer = LimeTextExplainer(class_names=class_names, split_expression=r'\s+')
# লাইভ মোডে যাতে স্লো না হয়, তাই স্যাম্পল সংখ্যা ১০০০ রাখা হয়েছে
exp = explainer.explain_instance(text, predictor, num_features=10, labels=(label_idx,), num_samples=1000)
weights = OrderedDict(exp.as_list(label=label_idx))
lime_weights = pd.DataFrame({'words': list(weights.keys()), 'weights': list(weights.values())})
chart_img = create_custom_plot(lime_weights, label_name)
return chart_img
# ৬. রিস্টার্ট বা ক্লিয়ার করার ফাংশন
def clear_inputs():
return "", {}, None
# ৭. Gradio Layout (Live Update Configuration)
with gr.Blocks() as demo:
gr.Markdown("# Real-time Bangla Political Sentiment Analysis with XAI")
gr.Markdown("")
with gr.Row():
with gr.Column():
all_conf_out = gr.Label(label="প্রধান ফলাফল ও ক্লাসভিত্তিক কনফিডেন্স স্কোর", num_top_classes=3)
input_text = gr.Textbox(
label="বাংলা টেক্সট লিখুন",
lines=5,
placeholder="এখানে আপনার মন্তব্যটি টাইপ করুন...",
interactive=True
)
# ব্যাখ্যা দেখার বাটনটি আগের মতোই আলাদা থাকবে
explain_btn = gr.Button("বিস্তারিত ব্যাখ্যা (LIME) দেখুন", variant="secondary")
clear_btn = gr.Button("সব মুছে ফেলুন (Reset)", variant="stop") # লাল রঙের বাটন
with gr.Row():
plot_out = gr.Image(label="XAI (LIME) Visualization")
# [ম্যাজিক লাইন] - টাইপ করার সাথে সাথে প্রেডিকশন হবে
input_text.change(
fn=get_prediction,
inputs=input_text,
outputs=all_conf_out,
show_progress="hidden" # টাইপিং এর সময় লোডিং বার দেখাবে না যাতে ডিস্টার্ব না হয়
)
# ব্যাখ্যা বাটন ক্লিক করলে গ্রাফ আসবে
explain_btn.click(get_explanation, inputs=input_text, outputs=plot_out)
# [রিসেট বাটন লজিক]
clear_btn.click(
fn=clear_inputs,
inputs=[],
outputs=[input_text, all_conf_out, plot_out]
)
demo.launch(theme=gr.themes.Soft())