Spaces:
Sleeping
Sleeping
File size: 2,883 Bytes
18e75ed | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import streamlit as st
import onnxruntime as ort
import torch
import numpy as np
import pickle
# Default location of the exported ONNX model.
# TODO(review): this absolute Windows Desktop path only exists on the author's
# machine; on a deployed host pass ``path=`` or point MODEL_PATH at a file that
# ships with the app.
MODEL_PATH = r"C:\Users\ADMIN\Desktop\lstm_news_classifier (1).onnx"


@st.cache_resource
def load_model(path=MODEL_PATH):
    """Create the ONNX Runtime inference session for the news classifier.

    Wrapped in ``st.cache_resource`` so the session is built once per distinct
    *path* and reused on later reruns instead of being rebuilt on every click.

    Args:
        path: Filesystem path of the ``.onnx`` model. Defaults to MODEL_PATH.

    Returns:
        An ``onnxruntime.InferenceSession`` loaded from *path*.
    """
    return ort.InferenceSession(path)
# Default location of the pickled tokenizer.
# TODO(review): machine-specific absolute path; on a deployed host pass
# ``path=`` or point TOKENIZER_PATH at a file that ships with the app.
TOKENIZER_PATH = r"C:\Users\ADMIN\Desktop\tokenizer.pkl"


@st.cache_resource
def load_tokenizer(path=TOKENIZER_PATH):
    """Unpickle and return the tokenizer used to encode input text.

    ``st.cache_resource`` keeps one shared tokenizer object per *path*.
    (``st.cache_data`` would hand back a freshly deserialized copy of the
    cached value on every call, which is wasted work for a read-only object.)

    Security: ``pickle.load`` can execute arbitrary code, so *path* must point
    to a trusted file produced by this project.

    Args:
        path: Filesystem path of the pickled tokenizer. Defaults to
            TOKENIZER_PATH.

    Returns:
        Whatever object was pickled into the file.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
# Default location of the optional pickled vocabulary. Relative paths are
# resolved against the process working directory, not this script's folder.
VOCAB_PATH = "vocab.pkl"


@st.cache_data
def load_vocab(path=VOCAB_PATH):
    """Load an optional vocabulary mapping from a pickle file.

    A missing file is not an error: the function returns None, and ``predict``
    treats a falsy vocab as "no vocabulary" and uses ``tokenizer.encode``.

    Security: ``pickle.load`` can execute arbitrary code, so *path* must point
    to a trusted file produced by this project.

    Args:
        path: Filesystem path of the pickled vocabulary. Defaults to
            VOCAB_PATH.

    Returns:
        The unpickled object (presumably a token -> id dict, since ``predict``
        calls ``.get`` on it — confirm), or None if *path* does not exist.
    """
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
def get_input_length(session, default=55):
    """Return the sequence length expected by the model's first input.

    Reads ``session.get_inputs()[0].shape`` and returns dimension 1 (the axis
    after the batch axis) when it is a concrete positive integer. When the
    dimension is symbolic/dynamic (a non-int value), or the input has fewer
    than two dimensions, *default* is returned instead.

    Args:
        session: An ONNX Runtime session (anything whose ``get_inputs()[0]``
            exposes a ``shape`` sequence).
        default: Fallback length. 55 is presumably the padding length used
            during training — confirm.

    Returns:
        The sequence length to pad/truncate token ids to.
    """
    shape = session.get_inputs()[0].shape
    if len(shape) > 1 and isinstance(shape[1], int) and shape[1] > 0:
        return shape[1]
    return default
# ONNX Runtime reports an input's element type as a string such as
# "tensor(int64)"; map the common ones to the numpy dtype the session accepts.
_ONNX_INPUT_DTYPES = {
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}


def _token_ids(encoding):
    """Extract a list of token ids from a tokenizer ``encode`` result."""
    if hasattr(encoding, "ids"):
        # Encoding objects (e.g. from the `tokenizers` library) expose `.ids`.
        return list(encoding.ids)
    try:
        # Mapping-style results such as {"input_ids": [...]}.
        return list(encoding["input_ids"])
    except (TypeError, IndexError):
        # A bare sequence of ids (e.g. transformers' `encode` returns a list),
        # which cannot be indexed with a string key.
        return list(encoding)


def predict(text, session, tokenizer, vocab=None):
    """Classify *text* with the ONNX model.

    Token ids are produced in one of two ways:
      * if *vocab* is truthy, ``tokenizer(text)`` must return tokens, which
        are looked up in *vocab* (unknown tokens map to ``vocab["<unk>"]``,
        or 0 when that key is absent);
      * otherwise ``tokenizer.encode(text)`` is used; its result may expose
        ``.ids``, be indexable by ``"input_ids"``, or be a plain id sequence.

    The ids are truncated / right-padded with 0 to the length reported by
    ``get_input_length`` and fed to the model as a batch of one.

    Args:
        text: Raw input text.
        session: ONNX Runtime inference session.
        tokenizer: Tokenizer object (see above for the two accepted styles).
        vocab: Optional token -> id mapping.

    Returns:
        ``(pred, confidence)``: the argmax class index and its softmax
        probability.
    """
    max_length = get_input_length(session)

    if vocab:
        unk_id = vocab.get("<unk>", 0)
        indices = [vocab.get(token, unk_id) for token in tokenizer(text)]
    else:
        indices = _token_ids(tokenizer.encode(text))

    # Right-pad with 0 — assumes id 0 is the padding id used in training
    # (TODO confirm). A negative repeat count yields [], so long inputs are
    # simply truncated.
    padded = indices[:max_length] + [0] * (max_length - len(indices))

    input_meta = session.get_inputs()[0]
    # Match the dtype the model declares; int64 when the type is unknown.
    dtype = _ONNX_INPUT_DTYPES.get(getattr(input_meta, "type", None), np.int64)
    input_array = np.array([padded], dtype=dtype)

    output = session.run(None, {input_meta.name: input_array})[0]
    # NOTE(review): softmax over dim=1 assumes the first output holds raw
    # logits shaped (batch, num_classes). If the model already applies
    # softmax/sigmoid, or has a single output unit, the reported confidence
    # is distorted — confirm against the exported graph.
    probs = torch.softmax(torch.tensor(output), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()
    return pred, confidence
# 🖼 Streamlit UI
# set_page_config must be the first Streamlit command the script executes.
st.set_page_config(page_title="Fake News Detector", page_icon="📰")
st.title("📰 Fake News Detector")
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)

# Tint the whole app light blue. Only background-color is set: there is no
# background image, so size/repeat/attachment/position properties would have
# no effect.
st.markdown(
    """
<style>
/* Light-blue background for the entire app */
.stApp {
background-color: #add8e6;
}
</style>
""",
    unsafe_allow_html=True,
)

user_input = st.text_area("Enter News Text:", height=100)
if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text.")
    else:
        with st.spinner("Analyzing..."):
            # The loaders are Streamlit-cached, so only the first click pays
            # the loading cost.
            session = load_model()
            tokenizer = load_tokenizer()
            vocab = load_vocab()
            label, confidence = predict(user_input, session, tokenizer, vocab)
        # NOTE(review): class index 1 is treated as "Fake" — confirm this
        # matches the label encoding used when the model was trained.
        is_fake = label == 1
        label_name = "Fake" if is_fake else "Real"
        color = "🔴" if is_fake else "🟢"
        st.markdown(f"### Prediction: {color} **{label_name} News**")
        st.markdown(f"**Confidence:** {confidence:.2%}")
|