# Hugging Face Space page residue (not Python) converted to comments so the
# file parses: author "YAMITEK", commit "Update app.py", rev 6035a10 verified.
import streamlit as st
import torch
import re
import torch.nn as nn
import torch.nn.functional as F
import joblib
# Model parameters — must match the architecture the checkpoint was trained with
vocab_size = 37852  # tokenizer vocabulary size (embedding rows)
embedding_dim = 45  # size of each token's embedding vector
hidden_units = 25  # LSTM hidden size (output is doubled by bidirectionality)
num_classes = 2  # binary: fake vs. not fake
max_len = 55  # NOTE(review): unused — preprocess() pads/truncates to 20 instead; confirm which length the model was trained on
# Define the LSTM model
class LSTMModel(nn.Module):
    """Bidirectional single-layer LSTM text classifier.

    Embeds token ids, runs them through a bidirectional LSTM, and maps the
    last timestep's concatenated forward/backward hidden state through a
    linear layer + softmax to class probabilities.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The original passed dropout=0.2, but nn.LSTM applies dropout only
        # *between* stacked layers; with the default num_layers=1 it is a
        # no-op that merely emits a UserWarning, so it is removed here.
        # (dropout is not stored in the state_dict, so saved weights load.)
        self.lstm = nn.LSTM(embedding_dim, hidden_units,
                            batch_first=True, bidirectional=True)
        # Bidirectional output concatenates [forward; backward] hidden states,
        # hence 2 * hidden_units input features.
        self.fc = nn.Linear(hidden_units * 2, num_classes)

    def forward(self, x):
        """x: LongTensor of token ids, shape (batch, seq_len).

        Returns class probabilities of shape (batch, num_classes).
        """
        x = self.embedding(x)      # (batch, seq_len, embedding_dim)
        output, _ = self.lstm(x)   # (batch, seq_len, 2 * hidden_units)
        # Last timestep of both directions. NOTE(review): at t = -1 the
        # backward direction has only seen the final token — a common pattern,
        # but confirm it matches how the checkpoint was trained.
        x = output[:, -1, :]
        x = self.fc(x)
        # Softmax here is fine for inference/argmax; if this model were
        # trained with nn.CrossEntropyLoss, raw logits should be returned.
        return F.softmax(x, dim=1)
# Load model and tokenizer — runs once at app startup (CPU-only inference)
model = LSTMModel(vocab_size, embedding_dim, hidden_units, num_classes)
# Checkpoint filename contains a typo ("classfication") but must match the
# file shipped alongside the app, so it is left untouched.
model.load_state_dict(torch.load("news_classfication.pth", map_location=torch.device("cpu")))
model.eval()  # inference mode: disables dropout / batch-norm updates
# NOTE: joblib.load unpickles arbitrary objects — only load trusted files.
tokenizer = joblib.load("tokenizer.pkl")
# Preprocessing function
def preprocess(text, tok=None, max_length=20):
    """Clean raw news text and convert it to a fixed-length id tensor.

    Args:
        text: raw input string.
        tok: tokenizer exposing an ``encode(text).ids`` interface; defaults
            to the module-level ``tokenizer`` loaded at startup.
        max_length: fixed sequence length (0-padded / truncated).
            NOTE(review): this is 20 while the top-of-file ``max_len`` is 55;
            kept at 20 to preserve the deployed model's input size — confirm.

    Returns:
        torch.LongTensor of shape (1, max_length) — batch dim for the model.
    """
    if tok is None:
        tok = tokenizer  # module-level tokenizer loaded via joblib
    text = text.lower()
    # Raw strings fix the invalid escape sequences ('\S', '\W') that raise
    # SyntaxWarning on modern Python; patterns are otherwise unchanged.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)  # strip URLs
    text = re.sub(r'\W', ' ', text)  # non-word chars (incl. newlines) -> space
    text = re.sub(r'\n', '', text)   # kept from original (no-op after \W pass)
    text = re.sub(r' +', ' ', text)  # collapse runs of spaces
    text = text.strip(' ')           # trim the single leading/trailing space
    # Tokenization
    tokenized = tok.encode(text).ids
    # Padding (with id 0) or truncating to the fixed length
    if len(tokenized) > max_length:
        tokenized = tokenized[:max_length]
    else:
        tokenized += [0] * (max_length - len(tokenized))
    # Convert to tensor with a leading batch dimension
    return torch.tensor(tokenized, dtype=torch.long).unsqueeze(0)
# Streamlit UI
# The original emoji were mojibake ("πŸ“°" = UTF-8 📰 mis-decoded); repaired
# so the page icon/title render correctly.
st.set_page_config(page_title="Fake News Detector", page_icon="📰")
st.title("📰 Fake News Detector")
# Display an illustrative header image, fetched from an external URL
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)
# Styling the background (light blue) via injected CSS
st.markdown("""
<style>
.stApp {
background-color: #add8e6;
}
</style>
""", unsafe_allow_html=True)
# Text input, pre-filled with a sample headline for a quick demo
user_input = st.text_area(
    "Enter News Text:",
    value="Sheriff David Clarke Becomes An Internet Joke For Threatening To Poke People 'In The Eye'",
    height=100
)
# Predict button: classify the entered text on click
if st.button("Submit"):
    if user_input.strip() == "":
        st.warning("Please enter some text to classify.")
    else:
        input_tensor = preprocess(user_input)
        # Inference only — no_grad skips autograd bookkeeping (same output).
        with torch.no_grad():
            output = model(input_tensor)
        prediction = output.argmax().item()
        # Label convention appears to be 0 = fake, 1 = real.
        # NOTE(review): confirm this matches the training label encoding.
        # Original emoji were mojibake ("🚨"/"βœ…"); repaired to 🚨/✅.
        if prediction == 0:
            st.error("🚨 This is *Fake News*.")
        else:
            st.success("✅ This is *Not Fake News*.")