# streamlitcsv02 / app.py
# Last change by kelvinleong: commit 6948d18 ("Update app.py")
import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Hugging Face checkpoint used for both the tokenizer and the classifier.
# NOTE(review): 'bert-base-uncased' is a pre-trained base checkpoint. Unless it
# ships sequence-classification weights, transformers initializes the
# classification head randomly, so "Positive"/"Negative" may be arbitrary.
# Confirm, or swap in a sentiment-fine-tuned BERT checkpoint.
MODEL_NAME = 'bert-base-uncased'


@st.cache_resource
def _load_model(model_name=MODEL_NAME):
    """Load and return ``(tokenizer, model)`` for *model_name*.

    Streamlit re-executes the whole script on every user interaction.
    Caching the loader with ``st.cache_resource`` loads the weights once per
    server process instead of on every rerun.
    """
    tok = BertTokenizer.from_pretrained(model_name)
    mdl = BertForSequenceClassification.from_pretrained(model_name)
    mdl.eval()  # inference only: make sure dropout is disabled
    return tok, mdl


# Module-level handles used by predict_sentiment().
tokenizer, model = _load_model()
# Define the prediction function
def predict_sentiment(text):
    """Classify *text* as "Positive" or "Negative" using the module-level BERT model.

    Args:
        text: The text to classify. A missing value (``None`` or a float NaN,
            which is how pandas reads an empty CSV cell) is treated as an
            empty string. Plain numbers are converted with ``str()``.
            Neither would otherwise be accepted by the tokenizer.

    Returns:
        "Positive" if the highest-scoring class index is 1, else "Negative".
    """
    if text is None or (isinstance(text, float) and pd.isna(text)):
        text = ""
    elif isinstance(text, (int, float)):
        text = str(text)

    # Tokenize to a fixed-length (128) batch of one, with an attention mask.
    encoded_text = tokenizer.encode_plus(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Inference only: disable autograd so no computation graph is built for
    # each row, saving memory and time.
    with torch.no_grad():
        output = model(
            encoded_text['input_ids'],
            attention_mask=encoded_text['attention_mask'],
        )
    prediction = torch.argmax(output.logits, dim=1).item()

    # NOTE(review): assumes class index 1 means "positive" — confirm against
    # the checkpoint's label mapping.
    return "Positive" if prediction == 1 else "Negative"
# Define the Streamlit app
def app():
    """Render the CSV sentiment-analysis page.

    Asks the user for a CSV file. If the file parses and has a ``text``
    column, adds a ``sentiment`` column (one ``predict_sentiment`` result
    per row) and displays the resulting DataFrame. Problems with the upload
    are reported with ``st.error`` instead of crashing the app.
    """
    # Fixed: the original title was missing its closing parenthesis.
    st.title("BERT Sentiment Analysis (non pipeline)")
    st.write("Upload a CSV file with a 'text' column and I'll predict the sentiment for each row.")

    file = st.file_uploader("Upload CSV file", type=["csv"])
    if file is None:
        return  # nothing uploaded yet

    try:
        df = pd.read_csv(file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded CSV file: {err}")
        return

    if 'text' not in df.columns:
        st.error("The uploaded CSV has no 'text' column.")
        return

    # Empty cells are read as NaN floats and numeric cells as numbers.
    # Normalize everything to strings before tokenizing.
    texts = df['text'].fillna('').astype(str)
    with st.spinner("Predicting sentiment..."):
        df['sentiment'] = texts.apply(predict_sentiment)

    # Display the results
    st.write(df)
# Render the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()