import time

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer

# Page configuration (title, icon, layout). st.set_page_config must be the
# first Streamlit command executed by the script.
st.set_page_config(
    page_title="TwittoBERT",
    page_icon="🐦",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Custom CSS hook (currently an empty stylesheet).
st.markdown(""" """, unsafe_allow_html=True)

# Output labels: index i is the label for logit i of the classifier head.
CLASS_NAMES = ["Negative", "Neutral", "Positive"]


class SentimentClassifier(torch.nn.Module):
    """DistilBERT encoder with a frozen backbone and a 3-class MLP head.

    The first-token ([CLS]) embedding of the last hidden layer is passed
    through a 768 -> 256 -> 128 -> 64 -> 3 MLP. The layer layout must match
    the checkpoint loaded in ``load_model`` exactly, because
    ``load_state_dict`` is strict by default.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Freeze the backbone: only the classifier head requires gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(128, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(64, 3),
        )

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalized) class logits for a batch of token ids."""
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Take the hidden state of the first token of every sequence.
        sentence_embeddings = bert_output.last_hidden_state[:, 0, :]
        return self.classifier(sentence_embeddings)


@st.cache_resource
def load_model():
    """Build the classifier, load trained weights, and switch to eval mode.

    Cached by Streamlit so the model is loaded once per server process.
    """
    model = SentimentClassifier()
    # NOTE(review): torch.load unpickles the file. It is a local checkpoint
    # here; never point it at an untrusted path (consider weights_only=True
    # on torch versions that support it).
    model.load_state_dict(torch.load('BERT_MODEL.pth', map_location=torch.device('cpu')))
    # eval() disables dropout and makes BatchNorm use running statistics,
    # which single-tweet (batch size 1) inference requires.
    model.eval()
    return model


@st.cache_resource
def load_tokenizer():
    """Return the DistilBERT tokenizer matching the model's backbone."""
    return DistilBertTokenizer.from_pretrained('distilbert-base-uncased')


def predict_sentiment(model, tokenizer, tweet, max_length=200):
    """Classify a single tweet.

    Args:
        model: A ``SentimentClassifier`` in eval mode.
        tokenizer: The matching DistilBERT tokenizer.
        tweet: Tweet text to classify.
        max_length: Token length the input is padded/truncated to.

    Returns:
        ``(label, confidence_percent)``: ``label`` is one of ``CLASS_NAMES``
        and ``confidence_percent`` is that label's softmax probability * 100.
    """
    inputs = tokenizer(
        tweet,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])
        probs = F.softmax(logits, dim=1)
        confidence, predicted_class = torch.max(probs, dim=1)
    label = CLASS_NAMES[predicted_class.item()]
    return label, confidence.item() * 100


def _render_sample_buttons():
    """Show sample-tweet buttons; a click stores that tweet in session state."""
    st.subheader("Try these sample tweets:")
    sample_tweets = [
        "I love this product! It's absolutely amazing! 😍",
        "The service was okay, nothing special.",
        "This is the worst experience I've ever had. Terrible!",
        "Just had the best coffee of my life at this new café!",
        "The movie was decent but could have been better.",
        "I'm so frustrated with this terrible customer service!",
    ]
    cols = st.columns(2)
    for i, sample in enumerate(sample_tweets):
        # Only the button caption is truncated; the full tweet is stored.
        caption = sample[:50] + "..." if len(sample) > 50 else sample
        with cols[i % 2]:
            if st.button(caption, key=f"sample_{i}", help="Click to analyze this tweet"):
                st.session_state.sample_tweet = sample


def _render_result(label, confidence):
    """Render the predicted label and its confidence percentage."""
    # Every label currently uses the same template; per-label styling
    # (e.g. a CSS class) can be keyed off ``label`` here.
    st.markdown(f"""

Sentiment: {label}

Confidence: {confidence:.2f}%

""", unsafe_allow_html=True)


def _render_sidebar():
    """Populate the sidebar with app and model information."""
    st.sidebar.header("About")
    st.sidebar.markdown(""" This app uses a fine-tuned DistilBERT model to analyze sentiment in tweets. It can classify tweets as Positive, Negative, or Neutral with confidence scores. """)
    st.sidebar.header("Model Info")
    st.sidebar.text("Model: DistilBERT-base-uncased")
    st.sidebar.text("Classes: " + ", ".join(CLASS_NAMES))


def main():
    """Render the TwittoBERT page: samples, input box, prediction, sidebar."""
    st.title("🐦 TwittoBERT")
    st.markdown("Analyze the sentiment of tweets using a fine-tuned BERT model", unsafe_allow_html=True)

    # Load model and tokenizer; this is the UI boundary, so report and stop.
    try:
        model = load_model()
        tokenizer = load_tokenizer()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()

    _render_sample_buttons()

    # Pre-filled with the last clicked sample tweet, if any.
    tweet = st.text_area(
        "Or enter your own tweet to analyze:",
        height=100,
        placeholder="Type your tweet here...",
        value=st.session_state.get("sample_tweet", ""),
    )

    if st.button("Analyze Sentiment"):
        text = tweet.strip()
        if not text:
            # Empty or whitespace-only input: tell the user instead of
            # silently doing nothing or classifying blank text.
            st.warning("Please enter a tweet to analyze.")
        else:
            with st.spinner("Analyzing sentiment..."):
                time.sleep(0.5)  # Cosmetic delay so the spinner is visible
                label, confidence = predict_sentiment(model, tokenizer, text)
            _render_result(label, confidence)

    _render_sidebar()


if __name__ == "__main__":
    main()