"""Streamlit app that classifies a tweet as describing a real disaster or not.

The script loads a vocabulary (``word2idx.pkl``), a pre-trained embedding
matrix (``weights_matrix.npy``) and the trained BiLSTM weights
(``state_dict.pt``) from the working directory, then serves a one-field form.
"""

import pickle
import re

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so everything loaded at module level below is reloaded on each click.
# Consider wrapping the loading code in a cached loader (st.cache_resource)
# if the installed Streamlit version provides it.

# Load word2idx.
# NOTE: pickle.load can execute arbitrary code from the file; only ship
# artifacts that come from a trusted training run.
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)

# Patterns used by clean_text, compiled once at import time.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
    "\U00002702-\U000027B0"  # dingbats
    "\U000024C2-\U0001F251"  # very broad range, also covers many non-emoji code points
    "]+",
    flags=re.UNICODE,
)
_URL_PATTERN = re.compile(r'https?://\S+|www\.\S+')
_SYMBOLS_PATTERN = re.compile(r'[^A-Za-z0-9 ]')


def clean_text(text: str) -> str:
    """Normalise raw tweet text for tokenisation.

    Steps, in order: strip emoji, strip URLs, replace ``#`` and ``@`` with a
    space, drop every character that is not an ASCII letter, digit or space,
    and lowercase the result.

    Args:
        text: Raw tweet text.

    Returns:
        The cleaned, lowercased text (possibly empty).
    """
    text = _EMOJI_PATTERN.sub('', text)
    # URLs must be removed before the symbol filter, which would otherwise
    # strip "://" and "." and leave the URL's letters behind as fake words.
    text = _URL_PATTERN.sub('', text)
    text = text.replace('#', ' ').replace('@', ' ')
    text = _SYMBOLS_PATTERN.sub('', text)
    return text.lower()


def text_to_sequence(text: str, word2idx: dict, maxlen: int = 55) -> np.ndarray:
    """Convert cleaned text into a fixed-length array of vocabulary indices.

    The text is split on whitespace. Words missing from ``word2idx`` map to
    index 0, which is also the padding index. Sequences longer than
    ``maxlen`` keep their first ``maxlen`` tokens; shorter ones are padded
    with zeros on the left.

    Args:
        text: Cleaned text to tokenise.
        word2idx: Mapping from word to integer index.
        maxlen: Length of the returned sequence.

    Returns:
        A 1-D ``int64`` array of length ``maxlen``.
    """
    seq = [word2idx.get(word, 0) for word in text.split()]
    if len(seq) > maxlen:
        seq = seq[:maxlen]
    else:
        seq = [0] * (maxlen - len(seq)) + seq
    # Pin the dtype so the result does not depend on the platform's default
    # integer width (and is not float64 when the list is empty).
    return np.array(seq, dtype=np.int64)


class BiLSTM(nn.Module):
    """Bidirectional LSTM classifier on top of a frozen pre-trained embedding.

    Args:
        weights_matrix: 2-D tensor of pre-trained embeddings,
            ``(num_embeddings, embedding_dim)``.
        output_size: Number of outputs produced per time step.
        hidden_dim: Hidden size of each LSTM direction.
        hidden_dim2: Width of the intermediate fully connected layer.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout applied between stacked LSTM layers.
    """

    def __init__(self, weights_matrix, output_size, hidden_dim, hidden_dim2,
                 n_layers, drop_prob=0.5):
        super().__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # Embedding layer, initialised from the pre-trained matrix and frozen.
        # copy_ (rather than Embedding.from_pretrained) keeps the layer's own
        # dtype even if weights_matrix was saved with a different one.
        num_embeddings, embedding_dim = weights_matrix.size()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.copy_(weights_matrix)
        self.embedding.weight.requires_grad = False

        # BiLSTM layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=drop_prob, bidirectional=True,
                            batch_first=True)

        # Dropout layer
        self.dropout = nn.Dropout(0.3)

        # Fully connected layers; the LSTM output is 2 * hidden_dim wide
        # because forward and backward directions are concatenated.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim2)
        self.fc2 = nn.Linear(hidden_dim2, output_size)

        # Activation function
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        """Run the network on a batch of index sequences.

        Args:
            x: Integer tensor of token indices, batch dimension first.
            hidden: ``(h_0, c_0)`` tuple as returned by ``init_hidden``.

        Returns:
            ``(sig_out, hidden)``: ``sig_out`` holds one sigmoid score per
            batch element, taken from the last time step, and ``hidden`` is
            the LSTM's final ``(h_n, c_n)`` state.
        """
        batch_size = x.size(0)

        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)

        # Flatten (batch, seq) into one axis so the dense layers are applied
        # to every time step independently.
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)

        out = self.dropout(lstm_out)
        out = F.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)
        sig_out = self.sigmoid(out)

        # Restore the batch axis and keep only the last time step's score.
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1]

        return sig_out, hidden

    def init_hidden(self, batch_size, train_on_gpu=False):
        """Create zeroed initial hidden and cell states for the LSTM.

        Args:
            batch_size: Number of sequences in the batch.
            train_on_gpu: If True, move the states to the CUDA device.

        Returns:
            Tuple ``(h_0, c_0)``, each of shape
            ``(n_layers * 2, batch_size, hidden_dim)``.
        """
        # Borrow dtype/device from an existing parameter.
        weight = next(self.parameters()).data
        # Multiply by 2 for bidirectionality.
        shape = (self.n_layers * 2, batch_size, self.hidden_dim)
        hidden = (weight.new_zeros(shape), weight.new_zeros(shape))
        if train_on_gpu:
            hidden = tuple(state.cuda() for state in hidden)
        return hidden


# Load the embedding weights matrix
weights_matrix = torch.tensor(np.load('weights_matrix.npy'))

# Instantiate the model; these hyperparameters must match the checkpoint.
output_size = 1
hidden_dim = 128
hidden_dim2 = 64
n_layers = 2

net = BiLSTM(weights_matrix, output_size, hidden_dim, hidden_dim2, n_layers)

# Load the model's state_dict onto the CPU and switch to inference mode
# (disables dropout).
net.load_state_dict(torch.load('state_dict.pt', map_location=torch.device('cpu')))
net.eval()


def main():
    """Render the Streamlit page and classify the entered tweet on demand."""
    st.title("Disaster Tweet Classifier")
    st.write("Enter a tweet to classify whether it's about a real disaster or not.")

    user_input = st.text_area("Enter Tweet Text:")

    if not st.button("Classify"):
        return

    # Whitespace-only input is treated the same as no input.
    if not user_input or not user_input.strip():
        st.warning("Please enter some text to classify.")
        return

    # Preprocess input
    clean_input = clean_text(user_input)
    if not clean_input.split():
        # Only emoji, URLs or punctuation were entered: the model would be
        # fed pure padding and its score would be meaningless.
        st.warning("No words were left after cleaning the tweet. "
                   "Please enter some text to classify.")
        return

    seq = text_to_sequence(clean_input, word2idx)
    input_tensor = torch.from_numpy(seq).unsqueeze(0).type(torch.LongTensor)

    # Make prediction with a fresh zero hidden state for this single sample.
    with torch.no_grad():
        h = net.init_hidden(1, train_on_gpu=False)
        output, _ = net(input_tensor, h)

    prob = output.item()
    pred = int(torch.round(output).item())

    # Display result; the probability shown is always the model's score for
    # the "real disaster" class.
    if pred == 1:
        st.success(f"This tweet is about a **real disaster**. (Probability: {prob:.4f})")
    else:
        st.info(f"This tweet is **not about a real disaster**. (Probability: {prob:.4f})")


if __name__ == '__main__':
    main()