|
|
|
|
|
import streamlit as st |
|
|
import torch |
|
|
from transformers import BertTokenizer, BertModel |
|
|
from torch import nn |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
class BERTClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The sub-module attribute names (``bert``, ``dropout``, ``fc``) are part of
    the checkpoint format: ``load_state_dict`` matches parameters by these
    names, so they must not be renamed.
    """

    def __init__(self, bert_model_name, num_classes, dropout_prob=0.1):
        """Build the encoder and the classification head.

        Args:
            bert_model_name: Name or path handed to
                ``BertModel.from_pretrained``.
            num_classes: Number of output scores produced by the head.
            dropout_prob: Dropout probability applied to the pooled encoder
                output before the linear layer. Defaults to ``0.1``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # The head's input width follows the encoder's configured hidden size.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class scores (logits) for the given inputs.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            The output of the linear head applied to the (dropped-out)
            ``pooler_output`` of the encoder; its last dimension has
            ``num_classes`` entries.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        return self.fc(pooled)
|
|
|
|
|
|
|
|
# --- Model configuration -----------------------------------------------------
repo_id = "Makima57/sentiment-model"
bert_model_name = 'bert-base-uncased'
num_classes = 2
max_length = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_model_and_tokenizer():
    """Download the checkpoint once and return ``(model_path, model, tokenizer)``.

    Streamlit re-executes this script on every widget interaction. Caching the
    loaded objects with ``st.cache_resource`` keeps the snapshot download, the
    BERT construction and the weight loading from being repeated on each
    rerun.
    """
    path = snapshot_download(repo_id=repo_id)

    classifier = BERTClassifier(bert_model_name, num_classes)
    # NOTE(security): torch.load unpickles the file. The checkpoint comes from
    # a remote repository; consider passing ``weights_only=True`` if the
    # installed torch version supports it.
    state_dict = torch.load(f"{path}/pytorch_model.bin", map_location=device)
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    classifier.eval()  # inference only: disables dropout

    tok = BertTokenizer.from_pretrained(f"{path}/tokenizer")
    return path, classifier, tok


model_path, model, tokenizer = _load_model_and_tokenizer()
|
|
|
|
|
|
|
|
def predict_sentiment(text, model, tokenizer, device, max_length=128):
    """Classify a single piece of text as ``"positive"`` or ``"negative"``.

    Args:
        text: The review text to classify.
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning a tensor of class scores.
        tokenizer: Callable that encodes ``text`` and returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        device: Device the encoded tensors are moved to before inference.
        max_length: Length the encoded input is padded / truncated to.

    Returns:
        ``"positive"`` when the highest-scoring class index is 1, otherwise
        ``"negative"``.
    """
    encoding = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # .item() requires exactly one prediction, i.e. a single input text.
    predicted_class = logits.argmax(dim=1).item()
    return "positive" if predicted_class == 1 else "negative"
|
|
|
|
|
|
|
|
# --- Streamlit UI ------------------------------------------------------------
st.title("IMDB Movie Review Sentiment Analyzer")

user_input = st.text_area("Enter a movie review:", "")

if st.button("Analyze Sentiment"):
    # Whitespace-only input is treated the same as no input.
    if user_input.strip():
        # Forward the module-level max_length so that constant actually
        # controls tokenisation instead of the function's own default.
        sentiment = predict_sentiment(
            user_input, model, tokenizer, device, max_length=max_length
        )
        st.write(f"The sentiment of the review is: **{sentiment}**")
    else:
        st.write("Please enter a valid review!")
|
|
|
|
|
|