# Author: YAMITEK
# Commit 26d7bb7 (verified): "Rename app.py to app2.py"
import streamlit as st
import onnxruntime as ort
import torch
import numpy as np
import pickle
@st.cache_resource
def load_model(
    model_path=r"C:\Users\ADMIN\Desktop\lstm_news_classifier (1).onnx",
    providers=None,
):
    """Create the ONNX Runtime inference session for the news classifier.

    The result is cached with ``st.cache_resource``, so the session is built
    once per distinct argument set instead of on every Streamlit rerun.

    Args:
        model_path: Path to the ``.onnx`` model file. The default is the
            original developer's local Windows path; pass (or change) it when
            running on any other machine.
        providers: Execution providers for ``InferenceSession``. When
            ``None``, every provider this onnxruntime build offers
            (``ort.get_available_providers()``) is used. Passing them
            explicitly is required since onnxruntime 1.9 on builds that ship
            more than the CPU provider (e.g. CUDA builds), where omitting the
            argument raises ``ValueError``.

    Returns:
        An ``onnxruntime.InferenceSession``.
    """
    if providers is None:
        providers = ort.get_available_providers()
    return ort.InferenceSession(model_path, providers=providers)
@st.cache_resource
def load_tokenizer(tokenizer_path=r"C:\Users\ADMIN\Desktop\tokenizer.pkl"):
    """Load the pickled tokenizer used to turn input text into token ids.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``: a
    tokenizer is a model-like object, and ``cache_data`` would re-serialize
    it and hand back a fresh copy on every rerun.

    Args:
        tokenizer_path: Path to the pickle file. The default is the original
            developer's local Windows path; pass (or change) it elsewhere.

    Returns:
        Whatever object was pickled. ``predict`` either calls it directly or
        calls its ``encode`` method, depending on whether a vocab is loaded.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this at a tokenizer file you produced/trust.
    with open(tokenizer_path, "rb") as f:
        return pickle.load(f)
@st.cache_data
def load_vocab(vocab_path="vocab.pkl"):
    """Load an optional pickled token -> id vocabulary.

    Args:
        vocab_path: Path to the pickle file. The default is relative, so it
            resolves against the current working directory of the Streamlit
            process.

    Returns:
        The unpickled vocabulary, or ``None`` when the file does not exist.
        ``None`` tells ``predict`` to use the tokenizer's own ``encode``
        method instead of a vocab lookup.
    """
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # a vocab file you produced/trust.
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        # The vocab is optional: a missing file is an expected configuration.
        return None
# πŸ” Extract max_length from ONNX input shape
def get_input_length(session, fallback=55):
    """Return the sequence length expected by the model's first input.

    Reads ``session.get_inputs()[0].shape`` and uses its second dimension
    (index 1). The input is presumably laid out as (batch, seq_len); confirm
    this against the exported model.

    Args:
        session: An ``onnxruntime.InferenceSession``, or any object whose
            ``get_inputs()`` returns items with a ``shape`` sequence.
        fallback: Length returned when that dimension is not a concrete
            positive integer (symbolic/dynamic dims are reported as strings
            or ``None``) or when the input has fewer than two dimensions.
            Defaults to 55, the value this app has always used.

    Returns:
        The sequence length as an ``int``.
    """
    shape = session.get_inputs()[0].shape
    # Guard rank < 2 so a 1-D input falls back instead of raising IndexError.
    dim = shape[1] if len(shape) > 1 else None
    if isinstance(dim, int) and dim > 0:
        return dim
    return fallback
def predict(text, session, tokenizer, vocab=None, max_length=None):
    """Classify *text* with the ONNX model.

    Args:
        text: Raw input text.
        session: An ``onnxruntime.InferenceSession`` whose first input
            receives a (1, max_length) int64 array of token ids.
        tokenizer: When *vocab* is given, a callable mapping text to a list of
            string tokens. Otherwise an object with an ``encode(text)``
            method that returns one of three things:
            - an object with an ``ids`` attribute (e.g. ``tokenizers.Encoding``);
            - a mapping with an ``"input_ids"`` key;
            - a plain sequence of ids (e.g. transformers' ``encode``).
        vocab: Optional token -> id mapping. Tokens missing from it map to the
            id of ``"<unk>"`` if present, else 0.
        max_length: Length to truncate/pad the ids to. When ``None`` it is
            read from the model via ``get_input_length(session)``.

    Returns:
        ``(pred, confidence)``: the argmax class index (``int``) and its
        softmax probability (``float``) for the single input row.
    """
    from collections.abc import Mapping

    if max_length is None:
        max_length = get_input_length(session)

    if vocab is not None:
        unk_id = vocab.get("<unk>", 0)  # hoisted: same for every token
        indices = [vocab.get(token, unk_id) for token in tokenizer(text)]
    else:
        encoding = tokenizer.encode(text)
        if hasattr(encoding, "ids"):
            indices = encoding.ids
        elif isinstance(encoding, Mapping):
            indices = encoding["input_ids"]
        else:
            # transformers' tokenizer.encode returns a plain list of ids;
            # indexing it with "input_ids" would raise TypeError.
            indices = encoding
    indices = list(indices)

    # Truncate, then right-pad with id 0. Assumes 0 is the padding id used in
    # training — TODO confirm.
    padded = indices[:max_length] + [0] * (max_length - len(indices))
    input_array = np.array([padded], dtype=np.int64)

    inputs = {session.get_inputs()[0].name: input_array}
    logits = np.asarray(session.run(None, inputs)[0], dtype=np.float64)

    # Numerically stable softmax over axis 1 (same axis as the former
    # torch.softmax(dim=1)). NOTE(review): if the model already ends in a
    # softmax/sigmoid layer, this re-normalizes probabilities — confirm the
    # model outputs logits of shape (1, num_classes).
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)

    pred = int(np.argmax(probs[0]))
    confidence = float(probs[0, pred])
    return pred, confidence
# 🖼 Streamlit UI
st.set_page_config(page_title="Fake News Detector", page_icon="📰")
st.title("📰 Fake News Detector")

# NOTE(review): this is a hot-linked Bing thumbnail; it may disappear or be
# rate-limited. Consider shipping a local image with the app instead.
url = "https://tse1.mm.bing.net/th?id=OIP.P_-960Qckr5FUEU3KvjCMwHaEc&pid=Api&rs=1&c=1&qlt=95&w=208&h=124"
st.image(url, width=400)

# Light-blue background for the whole app. The former rule also set
# background-size/-repeat/-attachment/-position, but there is no
# background-image, and "no"/"full" are not valid CSS values, so those
# declarations had no effect and were dropped.
st.markdown(
    """
<style>
.stApp {
    background-color: #add8e6;
}
</style>
""",
    unsafe_allow_html=True,
)

user_input = st.text_area("Enter News Text:", height=100)

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text.")
    else:
        with st.spinner("Analyzing..."):
            # All three loaders are cached, so these are cheap after the first run.
            session = load_model()
            tokenizer = load_tokenizer()
            vocab = load_vocab()
            label, confidence = predict(user_input, session, tokenizer, vocab)
        # Assumes class index 1 = fake and 0 = real, per the training label
        # encoding — TODO confirm against the training code.
        label_name = "Fake" if label == 1 else "Real"
        color = "🔴" if label == 1 else "🟢"
        st.markdown(f"### Prediction: {color} **{label_name} News**")
        st.markdown(f"**Confidence:** {confidence:.2%}")