# app.py — News Classification Streamlit app (author: Mummia-99, commit 2989196)
import streamlit as st
import torch
import re
import torch.nn as nn
import joblib
import torch.nn.functional as F
st.title("News Classification")

# Header image (hot-linked thumbnail).
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)

# Solid light-blue page background.
# The original rule carried leftover background-image declarations with
# invalid values (background-repeat:no, background-position:full,
# background-attachment:auto) — browsers drop invalid declarations, so only
# background-color ever had any effect.  The dead declarations are removed.
st.markdown(
    """
<style>
/* Solid background colour for the whole app */
.stApp {
    background-color: #add8e6;
}
</style>
""",
    unsafe_allow_html=True,
)
## Model hyperparameters — must match the checkpoint in news_classfication.pth
vocab_size = 37852   # size of the tokenizer vocabulary
embedding_dim = 45   # dimensionality of token embeddings
hidden_units = 25    # LSTM hidden size (doubled at the FC layer: bidirectional)
num_classes = 2      # binary classification (fake / not fake)
max_len = 55         # NOTE(review): unused — preprocess() hard-codes max_length = 20; confirm which is intended
class LSTMModel(nn.Module):
    """Bidirectional single-layer LSTM text classifier.

    Token ids -> embedding -> BiLSTM -> last time step -> linear -> softmax.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding vector size.
        hidden_units: LSTM hidden size per direction (FC input is 2x this).
        num_classes: number of output classes.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # NOTE: the original passed dropout=0.2, but nn.LSTM applies dropout
        # only *between* stacked layers; with the default num_layers=1 it is
        # ignored and merely raises a UserWarning, so it is removed here.
        # Parameter names are unchanged, so existing checkpoints still load.
        self.lstm = nn.LSTM(embedding_dim, hidden_units,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_units * 2, num_classes)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_classes).

        x: LongTensor of token ids, shape (batch, seq_len).
        """
        x = self.embedding(x)
        output, _ = self.lstm(x)
        # Use the final time step's features (concatenated fwd/bwd states).
        x = output[:, -1, :]
        x = self.fc(x)
        # Softmax is fine for this inference-only app (predictions via argmax);
        # for training with CrossEntropyLoss, raw logits should be returned.
        return F.softmax(x, dim=1)
# Build the classifier and restore the trained weights (CPU-only inference).
model = LSTMModel(vocab_size, embedding_dim, hidden_units, num_classes)
state_dict = torch.load("news_classfication.pth", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.eval()  # disable dropout / switch batch-norm style layers to inference mode

# Restore the fitted tokenizer used at training time.
tokenizer = joblib.load("tokenizer.pkl")
def preprocess(words):
    """Clean a list of word tokens, encode them with the global tokenizer,
    and return a (1, 20) LongTensor of token ids for the model.

    Each word is lower-cased, stripped of URLs and non-word characters, and
    whitespace-collapsed.  The flattened id sequence is truncated or
    zero-padded to a fixed length of 20.

    Args:
        words: list of raw word strings (e.g. from str.split()).

    Returns:
        torch.LongTensor of shape (1, 20).
    """
    normalized = []
    for word in words:
        word = word.lower()
        # Strip URLs (http(s)://... or www....).
        # Raw strings fix the invalid-escape-sequence warnings the original
        # non-raw patterns ('\S', '\.') raise on modern Python.
        word = re.sub(r'https?://\S+|www\.\S+', '', word)
        # Replace non-word characters with spaces and drop newlines.
        word = re.sub(r'\W', ' ', word)
        word = re.sub(r'\n', '', word)
        # Collapse runs of spaces and trim the ends (equivalent to the
        # original '^ ' / ' $' substitutions after ' +' collapsing).
        word = re.sub(r' +', ' ', word)
        word = word.strip(' ')
        normalized.append(word)
    # Encode each cleaned word.  (.lower() here is redundant — words were
    # already lower-cased above — but kept for byte-identical ids.)
    encoded = [tokenizer.encode(token.lower()).ids for token in normalized]
    # Flatten per-word id lists into one sequence, then truncate/zero-pad.
    max_length = 20  # NOTE(review): differs from module-level max_len = 55 — confirm intended
    flattened = [token_id for ids in encoded for token_id in ids]
    if len(flattened) > max_length:
        flattened = flattened[:max_length]
    else:
        flattened += [0] * (max_length - len(flattened))
    # Add a batch dimension: (20,) -> (1, 20).
    return torch.tensor(flattened, dtype=torch.long).unsqueeze(0)
# --- UI: read a headline and classify it on submit ---
text = st.text_input(
    "Enter the news Title ",  # fixed user-facing typo ("Tittle")
    value="Sheriff David Clarke Becomes An Internet Joke For Threatening To Poke People 'In The Eye'",
)
if st.button("submit"):
    words = text.split()
    features = preprocess(words)
    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(features)
    # .item() shows the predicted class index (0/1) instead of a tensor repr.
    st.write(output.argmax(dim=1).item())
# review_status = {
# 0: ("✅ Its a Fake news","#FF4500" ), # Red-Orange
# 1: ("❌ Its not a Fake news ", "#32CD32") # Green
# }
# # Get message and color based on prediction
# message, color = review_status.get(output.argmax(), ("❓ Unknown Prediction", "#808080"))
# # Display styled result
# st.markdown(f"""
# <div style="
# padding: 15px;
# background-color: {color};
# border-radius: 10px;
# text-align: center;
# font-size: 18px;
# font-weight: bold;
# color: white;">
# {message}
# </div>
# """, unsafe_allow_html=True)