# sentiment / app.py
# Author: Makima57
# Upload app.py with huggingface_hub
# Commit 3f75dca (verified)
import streamlit as st
import torch
from transformers import BertTokenizer, BertModel
from torch import nn
from huggingface_hub import snapshot_download
# Define the BERTClassifier class
class BERTClassifier(nn.Module):
    """Sequence classifier: a pretrained BERT encoder plus a linear head.

    The submodule attribute names (``bert``, ``dropout``, ``fc``) form the
    prefixes of this module's ``state_dict`` keys, so they must match any
    checkpoint loaded into it.
    """

    def __init__(self, bert_model_name, num_classes):
        """Build the encoder, the dropout layer and the classification head.

        Args:
            bert_model_name: Identifier handed to ``BertModel.from_pretrained``.
            num_classes: Output size of the final linear layer.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.1)
        # The head's input width is read from the encoder's own config.
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the logits computed from the encoder's pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.fc(self.dropout(encoded.pooler_output))
# Model repository and inference configuration.
repo_id = "Makima57/sentiment-model"
bert_model_name = 'bert-base-uncased'
num_classes = 2
max_length = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_resources():
    """Download the repository snapshot and build the model and tokenizer.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, each button click would repeat the
    snapshot lookup, rebuild the BERT encoder and reload the checkpoint.
    ``st.cache_resource`` keeps a single shared instance across reruns.

    Returns:
        A ``(model_path, model, tokenizer)`` tuple: the local snapshot
        directory, the classifier in eval mode on ``device``, and the
        tokenizer loaded from the snapshot's ``tokenizer`` sub-directory.
    """
    snapshot_dir = snapshot_download(repo_id=repo_id)

    classifier = BERTClassifier(bert_model_name, num_classes)
    # NOTE(security): torch.load unpickles the downloaded checkpoint. Only
    # point repo_id at a trusted repository; consider weights_only=True if
    # the installed torch version supports it.
    state = torch.load(f"{snapshot_dir}/pytorch_model.bin", map_location=device)
    classifier.load_state_dict(state)
    classifier.to(device)
    classifier.eval()

    tok = BertTokenizer.from_pretrained(f"{snapshot_dir}/tokenizer")
    return snapshot_dir, classifier, tok


# Same module-level names as before, now backed by the cached loader.
model_path, model, tokenizer = _load_resources()
# Define prediction function
def predict_sentiment(text, model, tokenizer, device, max_length=128):
    """Classify ``text`` as ``"positive"`` or ``"negative"``.

    Args:
        text: Review text handed to ``tokenizer``.
        model: Callable accepting ``input_ids`` and ``attention_mask``
            keyword arguments and returning per-class scores.
        tokenizer: Callable returning a mapping that contains
            ``input_ids`` and ``attention_mask`` tensors.
        device: Device the encoded tensors are moved to before inference.
        max_length: Length the input is padded / truncated to.

    Returns:
        ``"positive"`` when the highest-scoring class index is 1,
        otherwise ``"negative"``.
    """
    batch = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model(input_ids=ids, attention_mask=mask)
        predicted = torch.max(scores, dim=1).indices

    if predicted.item() == 1:
        return "positive"
    return "negative"
# ---- Streamlit page ----
st.title("IMDB Movie Review Sentiment Analyzer")

# Free-text box for the review; starts empty.
user_input = st.text_area("Enter a movie review:", "")

# Run the classifier only when the button is clicked during this script run.
if st.button("Analyze Sentiment"):
    if not user_input.strip():
        # Empty or whitespace-only input: ask for real text instead.
        st.write("Please enter a valid review!")
    else:
        sentiment = predict_sentiment(user_input, model, tokenizer, device)
        st.write(f"The sentiment of the review is: **{sentiment}**")