# Streamlit app: binary fake-news classification of news headlines
# (Hugging Face Spaces page residue removed).
import re

import joblib
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- Page header ---------------------------------------------------------
st.title("News Classification")

# Banner image (hot-linked thumbnail; consider bundling a local asset).
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)

# Page-wide background styling. The original declarations used invalid CSS
# values ("no" / "full"); replaced with valid keywords. Only
# background-color has a visible effect since no background image is set.
# (Plain string, not an f-string: there are no placeholders to interpolate.)
st.markdown(
    """
    <style>
    /* Set the background for the entire app */
    .stApp {
        background-color: #add8e6;
        background-size: 100px;
        background-repeat: no-repeat;
        background-attachment: scroll;
        background-position: center;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# --- Model hyper-parameters (must match the trained checkpoint) ----------
vocab_size = 37852    # tokenizer vocabulary size
embedding_dim = 45    # embedding vector length per token
hidden_units = 25     # LSTM hidden size (doubled in output by bidirectionality)
num_classes = 2       # binary: fake vs. real
max_len = 55          # nominal maximum sequence length
class LSTMModel(nn.Module):
    """Bidirectional single-layer LSTM text classifier.

    Embeds token ids, runs a bi-LSTM, and classifies from the hidden
    output at the final time step.

    Args:
        vocab_size: number of entries in the embedding table.
        embedding_dim: embedding vector length per token.
        hidden_units: LSTM hidden state size per direction.
        num_classes: number of output classes.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # NOTE: the original passed dropout=0.2, which PyTorch ignores
        # (with a warning) for a single-layer LSTM; removing it changes
        # neither behavior nor the state-dict layout.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_units,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional: forward + backward hidden states are concatenated.
        self.fc = nn.Linear(hidden_units * 2, num_classes)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len).

        Returns:
            FloatTensor of class probabilities, shape (batch, num_classes);
            each row sums to 1 (softmax over classes).
        """
        x = self.embedding(x)        # (batch, seq_len, embedding_dim)
        output, _ = self.lstm(x)     # (batch, seq_len, 2 * hidden_units)
        x = output[:, -1, :]         # hidden output at the last time step
        x = self.fc(x)
        return F.softmax(x, dim=1)
# --- Instantiate the network and restore the trained weights -------------
model = LSTMModel(vocab_size, embedding_dim, hidden_units, num_classes)

## load the weights (checkpoint filename preserved as shipped)
state_dict = torch.load("news_classfication.pth", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.eval()  # inference mode

# Tokenizer fitted during training (pickled tokenizers object).
tokenizer = joblib.load("tokenizer.pkl")
def preprocess(words, max_length=20):
    """Clean, tokenize, and pad a list of words into a model-ready tensor.

    Args:
        words: list of raw word strings (e.g. from ``headline.split()``).
        max_length: fixed sequence length the output is truncated or
            zero-padded to (default 20, preserving the original behavior;
            note this differs from the module-level ``max_len`` — confirm
            which length the checkpoint was trained with).

    Returns:
        LongTensor of token ids, shape (1, max_length), 0-padded.
    """
    normalized = []
    for word in words:
        word = word.lower()
        # Strip URLs. Raw strings: '\S'-style escapes in plain strings are
        # deprecated and will become errors.
        word = re.sub(r"https?://\S+|www\.\S+", "", word)
        # Replace non-word characters with spaces, drop newlines, then
        # collapse runs of spaces and trim a leading/trailing space.
        word = re.sub(r"\W", " ", word)
        word = word.replace("\n", "")
        word = re.sub(r" +", " ", word).strip(" ")
        normalized.append(word)

    # Encode each cleaned word and flatten into one id sequence.
    # (Words were already lowercased above, so no second .lower() needed.)
    token_ids = [tid for text in normalized for tid in tokenizer.encode(text).ids]

    # Truncate or zero-pad to the fixed length.
    if len(token_ids) > max_length:
        token_ids = token_ids[:max_length]
    else:
        token_ids += [0] * (max_length - len(token_ids))

    # Add a batch dimension: (1, max_length).
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
# --- UI: headline input + prediction -------------------------------------
text = st.text_input(
    "Enter the news Title ",
    value="Sheriff David Clarke Becomes An Internet Joke For Threatening To Poke People 'In The Eye'",
)
if st.button("submit"):
    words = text.split()
    v = preprocess(words)
    with torch.no_grad():  # inference only; no gradients needed
        output = model(v)
    # Convert the 0-d argmax tensor to a plain int. The original wrote the
    # raw tensor to the page, and the commented-out version used the tensor
    # itself as a dict key, which never matches the int keys below.
    prediction = int(output.argmax())

    # Label and banner color per class.
    # NOTE(review): class 0 => fake, 1 => real is taken from the original
    # comments — confirm against the training label encoding.
    review_status = {
        0: ("Its a Fake news", "#FF4500"),      # red-orange
        1: ("Its not a Fake news", "#32CD32"),  # green
    }
    message, color = review_status.get(prediction, ("Unknown Prediction", "#808080"))

    # Display the styled result banner.
    st.markdown(
        f"""
        <div style="
            padding: 15px;
            background-color: {color};
            border-radius: 10px;
            text-align: center;
            font-size: 18px;
            font-weight: bold;
            color: white;">
            {message}
        </div>
        """,
        unsafe_allow_html=True,
    )