# NOTE: the three lines below are Hugging Face Hub page residue (uploader,
# commit message, commit hash), not Python — kept as a comment for provenance.
# ma4389 — "Update app.py" — commit 566ef70 (verified)
import gradio as gr
import torch
import re
import emoji
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import T5Tokenizer, T5EncoderModel
import torch.nn as nn
# ========== Download NLTK data ==========
# Fetch the tokenizer model ('punkt_tab'), the English stopword list, and
# the WordNet data used by the lemmatizer.
for _resource in ("punkt_tab", "stopwords", "wordnet"):
    nltk.download(_resource)
# ========== Correct Label Mapping ==========
# Class index -> human-readable sentiment name; order matches training labels.
label_map = dict(enumerate(["Irrelevant", "Negative", "Neutral", "Positive"]))
# ========== Preprocessing ==========
def preprocess_text(text):
    """Clean raw text for classification.

    Pipeline: strip URLs -> demojize -> lowercase -> keep letters only ->
    tokenize -> drop stopwords and short tokens -> lemmatize.

    Args:
        text: Raw input string (may contain URLs, emojis, punctuation).

    Returns:
        A single space-joined string of cleaned, lemmatized tokens.
    """
    # Remove URLs first so none of their characters survive later steps.
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)
    # Convert emojis to ':name:' words BEFORE stripping non-letters.
    # BUG FIX: the strip previously ran first and deleted every emoji,
    # so demojize was a guaranteed no-op.
    text = emoji.demojize(text)
    # Lowercase, then keep only letters and whitespace; this also removes
    # the ':' and '_' that demojize inserts (emoji names merge into one word).
    text = text.lower()
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # Tokenize into words.
    tokens = word_tokenize(text)
    # Drop English stopwords and tokens of two characters or fewer.
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word not in stop_words and len(word) > 2]
    # Reduce each surviving token to its lemma.
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(word) for word in tokens]
    return ' '.join(tokens)
# ========== Model Definition ==========
class T5Classifier(nn.Module):
    """T5 encoder with a linear classification head on top.

    NOTE(review): T5 has no dedicated [CLS] token; the head reads the
    encoder's first-position hidden state, which is how the saved
    checkpoint was trained — do not change the pooling.
    """

    def __init__(self, model_name='t5-small', num_labels=4):
        super().__init__()
        # Attribute names must stay `encoder` / `classifier` so the keys in
        # the saved state_dict still resolve.
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw class logits of shape (batch, num_labels)."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)
# ========== Load Model ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Classifier(model_name="t5-small", num_labels=4)
# SECURITY FIX: weights_only=True restricts torch.load to plain tensor /
# state-dict data, preventing arbitrary code execution via a tampered
# pickle checkpoint (requires torch >= 1.13).
state_dict = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables dropout etc.
# ========== Prediction Function ==========
def predict_sentiment(text):
    """Classify one piece of text.

    Args:
        text: Raw user input string.

    Returns:
        (predicted_label, preprocessed_text, per_class_probabilities) —
        the triple wired to the three Gradio output components.
    """
    cleaned = preprocess_text(text)
    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    # Run the encoder + head without tracking gradients.
    with torch.no_grad():
        logits = model(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
        )
    scores = torch.softmax(logits, dim=1).cpu().numpy().flatten()
    best = int(np.argmax(scores))
    # Round each probability to 3 decimals for display.
    per_class = {label_map[i]: float(f"{scores[i]:.3f}") for i in range(len(scores))}
    return label_map[best], cleaned, per_class
# ========== Gradio Interface ==========
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 T5 Sentiment Classifier — Punkt Tab Layout")
    with gr.Tabs():
        # Main tab: text input + classify button on the left, results on the right.
        with gr.Tab("🔍 Predict"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Enter Text", lines=4, placeholder="Type your sentence here...")
                    btn = gr.Button("Classify")
                with gr.Column():
                    output_label = gr.Text(label="Predicted Sentiment")
                    cleaned_text = gr.Text(label="Preprocessed Text")
                    prob_output = gr.Label(label="Prediction Probabilities")
            btn.click(fn=predict_sentiment, inputs=input_text, outputs=[output_label, cleaned_text, prob_output])
        # Static documentation tab.
        with gr.Tab("ℹ️ Info"):
            gr.Markdown("""
### How it works:
- **Model**: T5 encoder with a linear classification head
- **Preprocessing**: Remove URLs, clean text, lowercase, tokenize, remove stopwords, lemmatize
- **Prediction**: Returns one of four classes (Irrelevant, Negative, Neutral, Positive)
""")

# ========== Launch ==========
demo.launch()