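# NOTE: the next three lines are residue from the hosting page this file was
# copied from; they are not part of the program.
# Wendgan's picture
# Upload 8 files
# 1544939 verified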
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import re
import pickle
# Load the vocabulary mapping (word -> integer index) used to encode tweets.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so word2idx.pkl must only ever come from a trusted source.
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
# Text-cleaning patterns, compiled once at import time rather than on every
# call to clean_text().
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
    "\U00002702-\U000027B0"  # dingbats
    # NOTE: this last range is very broad (it also spans e.g. CJK text).
    # It is kept as-is so the preprocessing stays unchanged.
    "\U000024C2-\U0001F251"
    "]+",
    flags=re.UNICODE,
)
_URL_PATTERN = re.compile(r'https?://\S+|www\.\S+')
_NON_ALNUM_PATTERN = re.compile(r'[^A-Za-z0-9 ]')


def clean_text(text):
    """Normalize raw tweet text before it is tokenized.

    Steps, in order:
      1. strip emoji and the other characters covered by ``_EMOJI_PATTERN``;
      2. strip URLs (``http://``, ``https://`` and ``www.`` forms);
      3. replace ``#`` and ``@`` with a space;
      4. drop every character that is not an ASCII letter, digit or space;
      5. lowercase the result.

    Runs of spaces are NOT collapsed. Tabs and newlines are removed in step 4
    rather than turned into spaces, so words separated only by a newline end
    up joined together.

    Args:
        text: The raw tweet text (a ``str``).

    Returns:
        The cleaned, lowercased text.
    """
    # URLs must be removed before the generic symbol filter, otherwise the
    # filter would strip ':' '/' '.' and leave URL fragments behind as words.
    text = _EMOJI_PATTERN.sub('', text)
    text = _URL_PATTERN.sub('', text)
    text = text.replace('#', ' ').replace('@', ' ')
    text = _NON_ALNUM_PATTERN.sub('', text)
    return text.lower()
# Text to sequence function
def text_to_sequence(text, word2idx, maxlen=55):
    """Encode text as a fixed-length array of vocabulary indices.

    The text is split on whitespace and each word is looked up in
    ``word2idx``. Words missing from the vocabulary map to 0, the same value
    that is used for padding. Sequences longer than ``maxlen`` keep only their
    first ``maxlen`` words; shorter ones are left-padded with zeros.

    Args:
        text: The (already cleaned) text to encode.
        word2idx: Mapping from word to integer index.
        maxlen: Length of the returned sequence.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype ``int64`` with ``maxlen`` entries.
    """
    seq = [word2idx.get(word, 0) for word in text.split()][:maxlen]
    padding = [0] * (maxlen - len(seq))
    # The dtype is pinned so the result does not depend on NumPy's inference:
    # an empty list would otherwise become float64, and some platforms
    # default to a 32-bit integer.
    return np.array(padding + seq, dtype=np.int64)
# Define the BiLSTM class
class BiLSTM(nn.Module):
    """Bidirectional LSTM classifier on top of a frozen, pretrained embedding.

    Pipeline: embedding -> bidirectional LSTM -> dropout -> linear -> ReLU ->
    dropout -> linear -> sigmoid. ``forward`` returns the sigmoid output of
    the last time step for each sequence in the batch.
    """

    def __init__(self, weights_matrix, output_size, hidden_dim, hidden_dim2, n_layers, drop_prob=0.5):
        """Build the network.

        Args:
            weights_matrix: 2-D tensor ``(num_embeddings, embedding_dim)``
                holding the pretrained embedding weights.
            output_size: Number of output units of the final linear layer.
            hidden_dim: Hidden size of the LSTM (per direction).
            hidden_dim2: Width of the intermediate fully connected layer.
            n_layers: Number of stacked LSTM layers.
            drop_prob: Dropout probability between stacked LSTM layers.
        """
        super().__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # Embedding layer. copy_ (rather than sharing the tensor) converts the
        # matrix to the embedding's own dtype, e.g. when it was loaded from a
        # float64 NumPy array.
        num_embeddings, embedding_dim = weights_matrix.size()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.copy_(weights_matrix)
        self.embedding.weight.requires_grad = False  # Freeze embedding layer

        # BiLSTM layer. Inter-layer dropout only exists with 2+ layers;
        # passing a non-zero value for a single layer makes PyTorch emit a
        # warning, so it is zeroed in that case.
        lstm_dropout = drop_prob if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=lstm_dropout,
                            bidirectional=True, batch_first=True)

        # Dropout layer
        self.dropout = nn.Dropout(0.3)

        # Fully connected layers (the LSTM output is 2 * hidden_dim wide
        # because both directions are concatenated)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim2)
        self.fc2 = nn.Linear(hidden_dim2, output_size)

        # Activation function
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden=None):
        """Run a batch of index sequences through the network.

        Args:
            x: Integer tensor of vocabulary indices, batch dimension first.
            hidden: Optional ``(h_0, c_0)`` tuple as produced by
                ``init_hidden``. When omitted, the LSTM starts from zeros.

        Returns:
            A tuple ``(sig_out, hidden)``: ``sig_out`` has one value per
            batch element (the last sigmoid output of the last time step),
            and ``hidden`` is the LSTM's final ``(h_n, c_n)`` state.
        """
        batch_size = x.size(0)

        # Embedding
        embeds = self.embedding(x)

        # LSTM
        lstm_out, hidden = self.lstm(embeds, hidden)

        # Flatten (batch, seq) into one axis so the linear layers see 2-D input
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)

        # Dropout and fully connected layers
        out = self.dropout(lstm_out)
        out = F.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)

        # Sigmoid activation
        sig_out = self.sigmoid(out)

        # Restore the batch dimension, then keep only the final entry of each
        # row, i.e. the output for the last time step
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1]

        return sig_out, hidden

    def init_hidden(self, batch_size, train_on_gpu=False):
        """Create a zeroed ``(h_0, c_0)`` state for the LSTM.

        Args:
            batch_size: Number of sequences in the batch.
            train_on_gpu: When true, the tensors are moved to CUDA.

        Returns:
            A tuple of two zero tensors of shape
            ``(n_layers * 2, batch_size, hidden_dim)``, created with the
            dtype and device of the model's first parameter.
        """
        weight = next(self.parameters()).detach()
        layers = self.n_layers * 2  # Multiply by 2 for bidirectionality
        shape = (layers, batch_size, self.hidden_dim)

        hidden = (weight.new_zeros(shape), weight.new_zeros(shape))
        if train_on_gpu:
            hidden = tuple(state.cuda() for state in hidden)
        return hidden
# Pretrained embedding matrix, stored on disk as a NumPy array
weights_matrix = torch.tensor(np.load('weights_matrix.npy'))

# Model hyperparameters; the layer shapes they produce have to agree with the
# checkpoint restored below
output_size, hidden_dim, hidden_dim2, n_layers = 1, 128, 64, 2

# Build the network and restore its trained parameters on the CPU
net = BiLSTM(
    weights_matrix=weights_matrix,
    output_size=output_size,
    hidden_dim=hidden_dim,
    hidden_dim2=hidden_dim2,
    n_layers=n_layers,
)
net.load_state_dict(
    torch.load('state_dict.pt', map_location=torch.device('cpu'))
)

# Evaluation mode (turns dropout off for inference)
net.eval()
# Streamlit app
def main():
    """Render the classifier page and handle one "Classify" click.

    Reads the tweet from a text area. Once the button is pressed, the text is
    cleaned, encoded with the module-level ``word2idx`` vocabulary and passed
    through the module-level ``net``. The rounded model output selects which
    message is shown; blank input produces a warning instead of a prediction.
    """
    st.title("Disaster Tweet Classifier")
    st.write("Enter a tweet to classify whether it's about a real disaster or not.")
    user_input = st.text_area("Enter Tweet Text:")

    if not st.button("Classify"):
        return

    # Whitespace-only input is treated like empty input; otherwise it would be
    # classified as a sequence made up purely of padding.
    if not user_input or not user_input.strip():
        st.warning("Please enter some text to classify.")
        return

    # Preprocess input into a (1, maxlen) batch of vocabulary indices
    clean_input = clean_text(user_input)
    seq = text_to_sequence(clean_input, word2idx)
    input_tensor = torch.from_numpy(seq).unsqueeze(0).long()

    # Fresh zero hidden state for a batch of one
    h = net.init_hidden(1, train_on_gpu=False)

    # Make prediction
    with torch.no_grad():
        output, _ = net(input_tensor, h)
    prob = output.item()
    pred = int(torch.round(output).item())

    # Display result; the probability shown is the raw model output in both
    # branches
    if pred == 1:
        st.success(f"This tweet is about a **real disaster**. (Probability: {prob:.4f})")
    else:
        st.info(f"This tweet is **not about a real disaster**. (Probability: {prob:.4f})")
# Start the app when this file is executed as a script
if __name__ == '__main__':
    main()