# Uploaded by ma4389 ("Upload 3 files"), commit 61c79c6 (verified); original size 2.69 kB.
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5EncoderModel
import gradio as gr
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# === NLTK Downloads ===
# word_tokenize relies on the 'punkt' sentence models in older NLTK releases
# but on the 'punkt_tab' tables in newer ones (3.8.2+ / 3.9), which raise a
# LookupError if only 'punkt' is present. Fetch both so startup works with
# either version. 'stopwords' feeds the stopword filter; 'wordnet' (and,
# presumably, 'omw-1.4') back the WordNetLemmatizer.
for _nltk_resource in ("punkt", "punkt_tab", "stopwords", "wordnet", "omw-1.4"):
    nltk.download(_nltk_resource)
# === Preprocessing Function ===
# NOTE(review): this pipeline presumably mirrors the preprocessing used when
# best_model.pth was trained, so it should not be "improved" in isolation.
# For example, NLTK's English stopword list appears to include negations such
# as "not"; keeping them would change model inputs. Confirm before editing.
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()


def preprocess_text(text):
    """Normalize a raw review into the space-separated token string fed to T5.

    Steps, in order:
      1. drop every character that is not an ASCII letter or whitespace;
      2. collapse whitespace runs to single spaces, trim, and lowercase;
      3. tokenize with NLTK's ``word_tokenize``;
      4. discard English stopwords;
      5. lemmatize each surviving token with WordNet.

    Args:
        text: The raw review string.

    Returns:
        The cleaned tokens joined by single spaces ("" if nothing survives).
    """
    letters_only = re.sub(r'[^A-Za-z\s]', '', text)
    normalized = re.sub(r'\s+', ' ', letters_only).strip().lower()
    kept_lemmas = (
        lemmatizer.lemmatize(token)
        for token in word_tokenize(normalized)
        if token not in stop_words
    )
    return ' '.join(kept_lemmas)
# === Model Definition ===
class T5Classifier(nn.Module):
    """Sequence classifier: a pretrained T5 encoder followed by a linear head.

    The attribute names ``encoder`` and ``classifier`` determine the
    state_dict keys, so they must stay in sync with any saved checkpoint.
    """

    def __init__(self, model_name='t5-small', num_labels=2):
        """Build the encoder and a ``d_model -> num_labels`` linear head.

        Args:
            model_name: Model id or path passed to
                ``T5EncoderModel.from_pretrained``.
            num_labels: Number of output classes (width of the logits).
        """
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids, presumably ``(batch, seq_len)`` as produced
                by the tokenizer with ``return_tensors="pt"`` — confirm.
            attention_mask: Mask aligned with ``input_ids`` (1 = real token).

        Returns:
            Raw logits from the linear head.
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state at sequence position 0.
        # NOTE(review): T5 tokenization does not appear to prepend a dedicated
        # classification token, so this is the first real token's state.
        # Changing the pooling would invalidate the trained head.
        first_token_state = encoder_output.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)
# === Load Model & Tokenizer ===
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Classifier(model_name='t5-small', num_labels=2)
# The model has not been moved to `device` yet, and load_state_dict copies the
# checkpoint tensors into its existing parameters. Mapping the checkpoint to
# the CPU avoids materializing a throwaway copy on the GPU first; the single
# `.to(device)` below then moves the weights once.
# weights_only=True limits unpickling to tensors and plain containers (a plain
# state_dict loads fine); it requires torch >= 1.13 and is the default from 2.6.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu", weights_only=True))
model.to(device)
model.eval()  # inference only: disables dropout
# === Prediction Function ===
label_map = {0: "Negative", 1: "Positive"}


def predict_sentiment(text):
    """Classify one review and return its label from ``label_map``.

    The text is cleaned with ``preprocess_text`` and tokenized to exactly
    128 tokens (padded or truncated). It is then scored by the model without
    gradient tracking, and the index of the largest logit selects the label.

    Args:
        text: The raw review string entered by the user.

    Returns:
        ``"Negative"`` or ``"Positive"``.
    """
    encoded = tokenizer(
        preprocess_text(text),
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    # Put the tensors on the same device the model was moved to.
    model_inputs = {name: encoded[name].to(device) for name in ("input_ids", "attention_mask")}
    with torch.no_grad():
        scores = model(**model_inputs)
    best_index = scores.argmax(dim=1).item()
    return label_map[best_index]
# === Gradio Interface ===
# Single text-in / text-out UI wired to predict_sentiment.
review_input = gr.Textbox(label="Enter Movie Review")
sentiment_output = gr.Text(label="Predicted Sentiment")
demo = gr.Interface(
    fn=predict_sentiment,
    inputs=review_input,
    outputs=sentiment_output,
    title="🎬 T5 Movie Review Classifier",
    description=(
        "Enter a movie review, and the model will predict whether "
        "the sentiment is Positive or Negative."
    ),
)
demo.launch()