# app.py — Bengali sentiment chatbot (author: Subha95)
# Last change: "Update app.py" (commit 1e1a05b, verified)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
# Multi-label Bengali sentiment checkpoint; both the tokenizer and the
# sequence-classification model are fetched from it once, at import time.
model_name = "Subha95/bengali-sentiment-model"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# Display text (name + emoji) for each generic "LABEL_<i>" id, listed in
# index order; ids missing from this mapping are displayed as-is by callers
# that use label_map.get(label, label).
_LABEL_TEXT = (
    "Assertive 💬",
    "Negative 🙁",
    "Doubtful 🤔",
    "Happiness 🙂",
    "Sadness 😢",
)
label_map = {f"LABEL_{i}": text for i, text in enumerate(_LABEL_TEXT)}
def is_bengali(text):
    """Return True if *text* contains at least one character from the
    Bengali Unicode block (U+0980 to U+09FF), otherwise False."""
    return any("\u0980" <= ch <= "\u09FF" for ch in text)
def predict_labels(text, threshold=0.5):
    """Score *text* with the multi-label sentiment model.

    Every logit goes through its own sigmoid, so any number of labels can be
    reported for one input.

    Args:
        text: Input string to classify.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        A list of strings like ``"Happiness 🙂 (0.87)"`` for every label whose
        probability is >= ``threshold``, in label-index order. If no label
        reaches the threshold, a one-item list holding the highest-scoring
        label is returned instead, so the result is never empty.
    """
    # truncation=True caps the input at the tokenizer's maximum length; without
    # it, long messages exceed the model's position-embedding limit and the
    # forward pass fails. NOTE(review): relies on the checkpoint's tokenizer
    # defining model_max_length -- confirm for this model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Batch of one -> take row 0 as a 1-D array of per-label probabilities.
    probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    def _format(idx):
        """Render label *idx* with its probability, e.g. 'Sadness 😢 (0.64)'."""
        label = model.config.id2label[idx]
        return f"{label_map.get(label, label)} ({probs[idx]:.2f})"

    results = [_format(i) for i, p in enumerate(probs) if p >= threshold]
    if not results:
        # Nothing crossed the threshold: fall back to the single best label.
        # int() turns the numpy integer from argmax into a plain dict key.
        results = [_format(int(probs.argmax()))]
    return results
def chatbot(message, history):
    """Handle one chat turn: classify *message* and append the exchange.

    Args:
        message: Text submitted from the input box.
        history: Current chat history as a list of ``(user, bot)`` pairs, or
            ``None``/empty for a fresh conversation. It is not mutated.

    Returns:
        ``(history, history, "")``: the updated history twice, followed by an
        empty string used to clear the textbox.
    """
    # Work on a copy so the caller's list is never modified in place.
    history = list(history or [])

    # Ignore empty / whitespace-only submissions rather than answering them
    # with the "Bengali only" warning.
    if not message or not message.strip():
        return history, history, ""

    try:
        if not is_bengali(message):
            response = "⚠️ This chatbot only works for Bengali text. Please write in বাংলা."
        else:
            results = predict_labels(message, threshold=0.5)
            response = " / ".join(results)
    except Exception as e:
        # UI boundary: record the full traceback server-side and show a short
        # message in the chat instead of failing the event handler.
        import logging

        logging.getLogger(__name__).exception("Sentiment prediction failed")
        response = f"⚠️ Error: {str(e)}"

    history.append((message, response))
    return history, history, ""  # clear textbox after sending
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page title.
    gr.Markdown("Bengali Sentiment Chatbot (Multi-Label)")

    # Conversation pane.
    chatbot_ui = gr.Chatbot(height=400)

    # Input row: a wide textbox next to a narrow send button.
    with gr.Row():
        msg = gr.Textbox(placeholder="বাংলায় কিছু লিখুন...", scale=9)
        send = gr.Button("Send", scale=1)

    clear = gr.Button("🗑️ Clear Chat")

    # Clicking "Send" and submitting the textbox run the same handler; its
    # three return values feed these three outputs in order.
    for trigger in (send.click, msg.submit):
        trigger(
            fn=chatbot,
            inputs=[msg, chatbot_ui],
            outputs=[chatbot_ui, chatbot_ui, msg],
        )

    # Writing None into the chat pane empties it.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot_ui, queue=False)

demo.launch()