# ddgnn / app.py
# MrUtakata's picture
# Update app.py
# 5d3091e verified
# app.py
import streamlit as st
import torch
import joblib
import dill
import numpy as np
import gdown
import os
# Load assets with caching
@st.cache_resource
def load_assets():
    """Load the pickled preprocessing function, TF-IDF vectorizer and model.

    The result is wrapped in ``st.cache_resource`` so the artifacts are read
    from disk once and then reused across Streamlit reruns.

    Returns:
        A 3-tuple ``(preprocess_text, tfidf, model)`` loaded, in that order,
        from ``preprocess_function.pkl`` (via dill), ``tfidf_vectorizer.pkl``
        and ``sage_model.pkl`` (both via joblib).
    """
    # The preprocessing callable was serialised with dill, so it needs an
    # explicit binary file handle rather than joblib's path-based loader.
    with open("preprocess_function.pkl", "rb") as handle:
        text_cleaner = dill.load(handle)

    # Load the two joblib artifacts in the same order as they are returned.
    vectorizer, graph_model = (
        joblib.load(artifact_path)
        for artifact_path in ("tfidf_vectorizer.pkl", "sage_model.pkl")
    )
    return text_cleaner, vectorizer, graph_model
# Download KNN model from Google Drive if not present
@st.cache_resource
def ensure_knn_model():
    """Return the fitted k-NN model, downloading it from Google Drive if absent.

    The model file ``knn_model.pkl`` is looked up in the current working
    directory. When it is missing it is fetched with ``gdown`` first. The
    loaded object is cached with ``st.cache_resource`` (like the other
    assets) so the pickle is not re-read on every Streamlit rerun.

    Returns:
        The object stored in ``knn_model.pkl`` as loaded by ``joblib.load``.

    Raises:
        FileNotFoundError: If the file is not present and the download did
            not produce it.
    """
    knn_path = "knn_model.pkl"
    if not os.path.exists(knn_path):
        downloaded = gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False,
        )
        # Fail with an explicit message instead of letting joblib.load report
        # a bare missing-file error when the download did not succeed.
        if downloaded is None or not os.path.exists(knn_path):
            raise FileNotFoundError(
                f"Could not download the KNN model to '{knn_path}'."
            )
    # NOTE: joblib.load unpickles the file; only use trusted model files.
    return joblib.load(knn_path)
# Load models
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()


def _build_edge_index(neighbors):
    """Convert a (num_nodes, k) neighbour-index array into a (2, E) edge index.

    Row ``i`` of ``neighbors`` lists the neighbours of node ``i``; one edge
    ``[i, nbr]`` is produced per entry, skipping self-loops. Edges are
    ordered row by row, matching a nested loop over nodes then neighbours.
    """
    neighbors = np.asarray(neighbors)
    sources = np.repeat(np.arange(neighbors.shape[0]), neighbors.shape[1])
    targets = neighbors.reshape(-1)
    keep = sources != targets  # drop self-loops
    # Stacking keeps the result 2 x E even when no edge survives the filter.
    return torch.tensor(
        np.stack([sources[keep], targets[keep]]), dtype=torch.long
    )


# App UI
st.title("🧠 Disinformation Detection")
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# User input
user_input = st.text_area("📝 Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        # STEP 1: Preprocess user input and turn it into a TF-IDF feature row
        cleaned_text = preprocess_text(user_input)
        tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
        input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)

        # STEP 2: Get original feature set from KNN model
        fit_X = knn_model._fit_X
        if hasattr(fit_X, "toarray"):
            # A model fitted on a sparse matrix stores it sparse; torch.tensor
            # needs a dense array.
            fit_X = fit_X.toarray()
        original_features = torch.tensor(fit_X, dtype=torch.float)

        # STEP 3: Combine input with training data features (input goes last)
        combined_features = torch.cat([original_features, input_feature], dim=0)

        # STEP 4: Build edge index using k-NN
        # NOTE(review): edges are emitted as [node, neighbour]; confirm this
        # direction matches how the training graph was built.
        neighbors = knn_model.kneighbors(
            combined_features.numpy(), return_distance=False
        )
        edge_index = _build_edge_index(neighbors)

        # STEP 5: Run inference on the last node (user input)
        sage_model.eval()
        with torch.no_grad():
            logits = sage_model(combined_features, edge_index)
        pred_node_logits = logits[-1]  # Last node is the user input
        prediction = torch.argmax(pred_node_logits).item()
        # softmax yields a valid probability whether the model emits raw
        # logits or log-probabilities (for the latter it equals exp()).
        probabilities = torch.softmax(pred_node_logits, dim=-1)
        confidence = probabilities[prediction].item()

        # STEP 6: Display result (class 1 is shown as real news)
        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")