# Fake News Detector — Streamlit app serving an ONNX LSTM classifier.
# Standard library
import pickle

# Third-party
import numpy as np
import onnxruntime as ort
import streamlit as st
import torch
def load_model(path=r"C:\Users\ADMIN\Desktop\lstm_news_classifier (1).onnx"):
    """Create an ONNX Runtime inference session for the LSTM classifier.

    Args:
        path: Location of the exported ``.onnx`` model file. Defaults to the
            original hard-coded development path so existing callers keep
            working; pass an explicit path for other deployments.

    Returns:
        onnxruntime.InferenceSession ready for ``run()``.
    """
    # NOTE(review): consider decorating with @st.cache_resource so the
    # session is built once per process instead of on every button press.
    return ort.InferenceSession(path)
def load_tokenizer(path=r"C:\Users\ADMIN\Desktop\tokenizer.pkl"):
    """Load the pickled tokenizer used at training time.

    Args:
        path: Location of the pickled tokenizer. Defaults to the original
            hard-coded development path for backward compatibility.

    Returns:
        The unpickled tokenizer object (either a callable token splitter
        used with a vocab dict, or an object exposing ``encode()``).
    """
    # SECURITY: pickle.load executes arbitrary code; only load files you
    # produced yourself — never untrusted uploads.
    with open(path, "rb") as f:
        return pickle.load(f)
def load_vocab(path="vocab.pkl"):
    """Load the token→index vocabulary, or return None if the file is absent.

    A missing vocab file is an expected condition — predict() then falls
    back to the tokenizer's own ``encode()`` — so it yields None rather
    than raising.

    Args:
        path: Location of the pickled vocab dict. Defaults to the original
            relative path for backward compatibility.

    Returns:
        dict mapping token -> index, or None when *path* does not exist.
    """
    # SECURITY: pickle.load executes arbitrary code; only load trusted files.
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
# 📏 Extract max_length from the ONNX model's declared input shape.
def get_input_length(session, fallback=55):
    """Return the model's fixed sequence length from its first input.

    ONNX graphs may declare the sequence axis symbolically (a string such
    as "seq_len") instead of a concrete int; in that case *fallback* is
    returned.

    Args:
        session: Inference session (duck-typed: needs ``get_inputs()``
            returning objects with a ``shape`` list of at least 2 axes).
        fallback: Length to assume for a dynamic axis. Defaults to 55,
            matching the original hard-coded value.

    Returns:
        int sequence length.
    """
    seq_axis = session.get_inputs()[0].shape[1]
    return seq_axis if isinstance(seq_axis, int) else fallback
def predict(text, session, tokenizer, vocab=None):
    """Classify *text* with the ONNX session and return (label, confidence).

    Args:
        text: Raw news text to classify.
        session: ONNX Runtime session (duck-typed: ``get_inputs()``/``run()``)
            whose first input is an int64 tensor of shape (1, max_length)
            and whose first output is a (1, num_classes) logits array.
        tokenizer: Either a callable returning a token list (used together
            with *vocab*) or an object with ``encode()`` returning ids.
        vocab: Optional token→index dict; unknown tokens map to '<unk>'
            (index 0 if '<unk>' itself is missing).

    Returns:
        (pred, confidence): predicted class index (int) and its softmax
        probability (float). Assumes batch size 1.
    """
    # Fetch the input metadata once (the original queried get_inputs()
    # twice: once for the length, once for the name).
    input_meta = session.get_inputs()[0]
    seq_axis = input_meta.shape[1]
    max_length = seq_axis if isinstance(seq_axis, int) else 55  # fallback

    # Two tokenizer flavours: callable + vocab dict, or HF-style encode().
    if vocab:
        tokens = tokenizer(text)
        indices = [vocab.get(token, vocab.get('<unk>', 0)) for token in tokens]
    else:
        encoding = tokenizer.encode(text)
        indices = encoding.ids if hasattr(encoding, "ids") else encoding["input_ids"]

    # Truncate, then right-pad with 0 up to the model's fixed length.
    padded = indices[:max_length] + [0] * (max_length - len(indices))
    input_array = np.array([padded], dtype=np.int64)

    logits = session.run(None, {input_meta.name: input_array})[0]

    # Numerically stable softmax in NumPy — no need to round-trip through
    # torch just for softmax/argmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)

    pred = int(np.argmax(probs, axis=1)[0])
    confidence = float(probs[0, pred])
    return pred, confidence
# 🖼 Streamlit UI (emoji restored from mojibake in the scraped source)
st.set_page_config(page_title="Fake News Detector", page_icon="📰")
st.title("📰 Fake News Detector")

# Header illustration (hot-linked thumbnail).
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)

# Page-wide background colour. The original also set background-size/
# repeat/position/attachment, but those are no-ops without a
# background-image and used invalid CSS values ("no", "full") anyway;
# the f-string was likewise unnecessary (no interpolation), so plain
# braces replace the {{ }} escapes.
st.markdown("""
<style>
/* Set the background colour for the entire app */
.stApp {
    background-color: #add8e6;
}
</style>
""", unsafe_allow_html=True)
user_input = st.text_area("Enter News Text:", height=100)

if st.button("Detect"):
    # Idiomatic emptiness check (was: user_input.strip() == "").
    if not user_input.strip():
        st.warning("Please enter some text.")
    else:
        with st.spinner("Analyzing..."):
            session = load_model()
            tokenizer = load_tokenizer()
            vocab = load_vocab()
            label, confidence = predict(user_input, session, tokenizer, vocab)
            # Label 1 = Fake, 0 = Real — assumed from the original mapping;
            # confirm against the training label encoding.
            label_name = "Fake" if label == 1 else "Real"
            # Indicator emoji restored from mojibake (π΄/π’ -> 🔴/🟢).
            color = "🔴" if label == 1 else "🟢"
            st.markdown(f"### Prediction: {color} **{label_name} News**")
            st.markdown(f"**Confidence:** {confidence:.2%}")