junaid17 committed on
Commit
6779b8d
·
verified ·
1 Parent(s): 5ccd013

Upload 6 files

Browse files
BERT_CLASSIFIER1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
BERT_MODEL.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47318e40a47b689c5d0cc90d41b345a4e3b0f15a2a4a01fd7916763fc5873e52
3
+ size 266456825
TOKENIZER/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
TOKENIZER/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "extra_special_tokens": {},
49
+ "mask_token": "[MASK]",
50
+ "model_max_length": 512,
51
+ "never_split": null,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "strip_accents": null,
55
+ "tokenize_chinese_chars": true,
56
+ "tokenizer_class": "DistilBertTokenizer",
57
+ "unk_token": "[UNK]"
58
+ }
TOKENIZER/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import DistilBertTokenizer, DistilBertModel
5
+ import time
6
+
7
# Set page config with dark theme.
# NOTE: st.set_page_config must be the first Streamlit command executed in the
# script, before any other st.* call.
st.set_page_config(
    page_title="TwittoBERT",
    page_icon="🐦",
    layout="centered",
    initial_sidebar_state="expanded"
)

# Custom CSS for dark theme: defines the palette as CSS variables, then styles
# the text input, buttons, the prediction result box (with per-sentiment
# left-border colours: .positive / .neutral / .negative) and sample-tweet cards.
st.markdown("""
<style>
:root {
--primary-color: #1DA1F2;
--background-color: #0F0F0F;
--secondary-background: #1E1E1E;
--text-color: #FFFFFF;
--font: sans-serif;
}

body {
background-color: var(--background-color);
color: var(--text-color);
font-family: var(--font);
}

.stApp {
background-color: var(--background-color);
}

.stTextInput>div>div>input {
background-color: var(--secondary-background);
color: var(--text-color);
border: 1px solid #333;
}

.stButton>button {
background-color: var(--primary-color);
color: white;
border-radius: 8px;
padding: 0.5rem 1rem;
border: none;
font-weight: bold;
transition: all 0.3s;
}

.stButton>button:hover {
background-color: #1991db;
transform: scale(1.02);
}

.prediction-box {
padding: 1.5rem;
border-radius: 10px;
margin: 1.5rem 0;
background-color: var(--secondary-background);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
border-left: 5px solid var(--primary-color);
}

.header {
color: var(--primary-color);
}

.positive {
border-left-color: #4CAF50;
}

.neutral {
border-left-color: #FFCC00;
}

.negative {
border-left-color: #FF4D4D;
}

.sample-tweet {
padding: 0.5rem;
margin: 0.5rem 0;
border-radius: 5px;
background-color: var(--secondary-background);
cursor: pointer;
transition: all 0.2s;
}

.sample-tweet:hover {
background-color: #2A2A2A;
}
</style>
""", unsafe_allow_html=True)
96
+
97
# Sentiment classifier: a frozen DistilBERT encoder with a trainable MLP head.
class SentimentClassifier(torch.nn.Module):
    """3-way (Negative/Neutral/Positive) classifier over DistilBERT embeddings."""

    def __init__(self):
        super(SentimentClassifier, self).__init__()
        # Pretrained encoder, fully frozen: only the head below is trained.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        for weight in self.bert.parameters():
            weight.requires_grad = False

        # Build the 768 -> 256 -> 128 -> 64 -> 3 head. The loop produces the
        # exact same Sequential module order (Linear, BatchNorm1d, ReLU,
        # Dropout per stage, final Linear) so saved state_dict keys match.
        stages = []
        width = 768
        for next_width in (256, 128, 64):
            stages.extend([
                torch.nn.Linear(width, next_width),
                torch.nn.BatchNorm1d(next_width),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.3),
            ])
            width = next_width
        stages.append(torch.nn.Linear(width, 3))
        self.classifier = torch.nn.Sequential(*stages)

    def forward(self, input_ids, attention_mask):
        """Return raw 3-class logits for a batch of encoded tweets."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the embedding at position 0 ([CLS]-style token) as the sentence vector.
        sentence_vector = encoded.last_hidden_state[:, 0, :]
        return self.classifier(sentence_vector)
124
+
125
# Load model and tokenizer (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    """Build the classifier and load trained weights from BERT_MODEL.pth (CPU).

    Returns:
        SentimentClassifier in eval mode (dropout off, BatchNorm uses running stats).
    """
    model = SentimentClassifier()
    # weights_only=True restricts unpickling to tensors/primitives — a plain
    # state_dict needs nothing more, and this avoids arbitrary-code execution
    # from a tampered checkpoint (and matches the torch>=2.6 default).
    state_dict = torch.load(
        'BERT_MODEL.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
132
+
133
@st.cache_resource
def load_tokenizer():
    """Load the DistilBERT tokenizer, preferring the local TOKENIZER/ files.

    The repo ships a TOKENIZER/ directory (vocab.txt, tokenizer_config.json,
    special_tokens_map.json); using it keeps the app working offline. Falls
    back to downloading from the Hub when the local files are absent.
    """
    try:
        return DistilBertTokenizer.from_pretrained('TOKENIZER')
    except OSError:
        # Local files missing/unreadable — fetch the identical tokenizer remotely.
        return DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
136
+
137
# Prediction function
def predict_sentiment(model, tokenizer, tweet):
    """Classify one tweet.

    Args:
        model: callable taking (input_ids, attention_mask) and returning logits.
        tokenizer: HF-style tokenizer returning PyTorch tensors.
        tweet: the raw tweet text.

    Returns:
        (label, confidence_percent) where label is one of
        "Negative" / "Neutral" / "Positive" and confidence is in [0, 100].
    """
    encoded = tokenizer(
        tweet,
        padding="max_length",
        max_length=200,
        truncation=True,
        return_tensors="pt"
    )

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"])
        probabilities = F.softmax(logits, dim=1)
        top_prob, top_index = torch.max(probabilities, dim=1)

    # Index order must match the training label encoding: 0=Neg, 1=Neu, 2=Pos.
    labels = ("Negative", "Neutral", "Positive")
    return labels[top_index.item()], top_prob.item() * 100
160
+
161
def main():
    """Render the TwittoBERT UI: header, sample tweets, free-text input,
    the sentiment result box, and the sidebar."""
    st.title("🐦 TwittoBERT")
    st.markdown("Analyze the sentiment of tweets using a fine-tuned BERT model", unsafe_allow_html=True)

    # Load model and tokenizer (cached by @st.cache_resource); show a visible
    # error and halt instead of crashing on e.g. a missing weights file.
    try:
        model = load_model()
        tokenizer = load_tokenizer()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

    # Sample tweets: a click stores the text in session state so the text
    # area below is pre-filled on the rerun Streamlit triggers for the click.
    st.subheader("Try these sample tweets:")
    sample_tweets = [
        "I love this product! It's absolutely amazing! 😍",
        "The service was okay, nothing special.",
        "This is the worst experience I've ever had. Terrible!",
        "Just had the best coffee of my life at this new café!",
        "The movie was decent but could have been better.",
        "I'm so frustrated with this terrible customer service!"
    ]

    cols = st.columns(2)
    for i, sample in enumerate(sample_tweets):
        with cols[i % 2]:
            button_label = sample[:50] + "..." if len(sample) > 50 else sample
            if st.button(button_label,
                         key=f"sample_{i}",
                         help="Click to analyze this tweet"):
                st.session_state.sample_tweet = sample

    # Tweet input (pre-filled with the chosen sample, if any)
    tweet = st.text_area("Or enter your own tweet to analyze:",
                         height=100,
                         placeholder="Type your tweet here...",
                         value=st.session_state.get("sample_tweet", ""))

    if st.button("Analyze Sentiment") and tweet:
        with st.spinner("Analyzing sentiment..."):
            time.sleep(0.5)  # brief pause so the spinner is visible
            label, confidence = predict_sentiment(model, tokenizer, tweet)

        # One HTML template instead of three duplicated per-label branches:
        # the label only selects the CSS modifier class for the box colour.
        css_class = {"Negative": "negative", "Neutral": "neutral"}.get(label, "positive")
        st.markdown(f"""
        <div class="prediction-box {css_class}">
            <h3>Sentiment: {label}</h3>
            <p>Confidence: {confidence:.2f}%</p>
        </div>
        """, unsafe_allow_html=True)

    # Sidebar info
    st.sidebar.header("About")
    st.sidebar.markdown("""
    This app uses a fine-tuned DistilBERT model to analyze sentiment in tweets.
    It can classify tweets as Positive, Negative, or Neutral with confidence scores.
    """)

    st.sidebar.header("Model Info")
    st.sidebar.text("Model: DistilBERT-base-uncased")
    st.sidebar.text("Classes: Negative, Neutral, Positive")

if __name__ == "__main__":
    main()