File size: 2,744 Bytes
5d3091e da55453 ac77892 da55453 5d3091e da55453 5d3091e db2d2c1 ac77892 da55453 5d3091e ac77892 5d3091e da55453 5d3091e ac77892 da55453 5d3091e ac77892 da55453 ac77892 da55453 5d3091e ac77892 5d3091e da55453 5d3091e ac77892 da55453 5d3091e da55453 5d3091e ac77892 5d3091e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# app.py
import streamlit as st
import torch
import joblib
import dill
import numpy as np
import gdown
import os
# Load assets with caching
@st.cache_resource
def load_assets():
    """Load the three pickled artefacts the app needs and return them.

    Reads, relative to the current working directory:

    * ``preprocess_function.pkl`` (via ``dill``) - the callable applied to
      the raw user text before vectorisation.
    * ``tfidf_vectorizer.pkl`` (via ``joblib``) - the object whose
      ``transform`` turns cleaned text into a feature row.
    * ``sage_model.pkl`` (via ``joblib``) - the model used for inference.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(preprocess_text, tfidf, model)`` tuple, in that order.
    """
    with open("preprocess_function.pkl", "rb") as pickled_preprocessor:
        preprocessor = dill.load(pickled_preprocessor)

    vectorizer = joblib.load("tfidf_vectorizer.pkl")
    classifier = joblib.load("sage_model.pkl")

    return preprocessor, vectorizer, classifier
# Download KNN model from Google Drive if not present
@st.cache_resource
def ensure_knn_model():
    """Return the fitted k-NN model, downloading its pickle if it is missing.

    If ``knn_model.pkl`` does not exist in the current working directory it
    is fetched from Google Drive with ``gdown`` first. The loaded object is
    cached with ``st.cache_resource`` so the pickle is not re-read from disk
    on every script rerun.

    Returns:
        The object stored in ``knn_model.pkl``.

    Raises:
        FileNotFoundError: If the file is still absent after the download
            attempt.
    """
    knn_path = "knn_model.pkl"
    if not os.path.exists(knn_path):
        gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False,
        )
        # Fail with a clear message instead of an opaque error from
        # joblib.load when the download did not produce a file.
        if not os.path.exists(knn_path):
            raise FileNotFoundError(
                f"Could not download the KNN model to {knn_path!r}."
            )
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only keep this if the Drive file is fully trusted.
    return joblib.load(knn_path)
# Load models
# Streamlit re-executes this script on every widget interaction.
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()


def _build_edge_index(neighbors):
    """Convert a (num_nodes, k) neighbour table into a (2, num_edges) tensor.

    Row ``i`` of ``neighbors`` lists the neighbour indices of node ``i``.
    Every pair ``(i, neighbour)`` becomes one column of the result, except
    self-pairs, which are dropped. Columns are ordered row by row, matching
    a nested ``for`` loop over ``neighbors``.

    Returns:
        A ``torch.long`` tensor of shape ``(2, num_edges)``; the shape is
        ``(2, 0)`` when no edges remain.
    """
    neighbors = np.asarray(neighbors)
    sources = np.repeat(np.arange(neighbors.shape[0]), neighbors.shape[1])
    targets = neighbors.reshape(-1)
    keep = sources != targets
    edges = np.stack([sources[keep], targets[keep]])
    return torch.from_numpy(edges).to(torch.long)


def _predict_last_node(model, features, edge_index):
    """Run ``model`` on the graph and classify its last node.

    Returns:
        ``(prediction, confidence)`` where ``prediction`` is the arg-max
        class index for the last row of the model output and ``confidence``
        is that class's probability.
    """
    model.eval()
    with torch.no_grad():
        logits = model(features, edge_index)
    node_scores = logits[-1]  # Last node is the user input
    # softmax yields a valid probability whether the model emits raw logits
    # or log-probabilities (softmax(log p) == p); a bare exp() is only
    # correct in the latter case.
    probabilities = torch.softmax(node_scores, dim=-1)
    prediction = int(torch.argmax(probabilities))
    return prediction, probabilities[prediction].item()


# App UI
st.title("🧠 Disinformation Detection")
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# User input
# NOTE(review): the leading glyph in this label is a mis-encoded emoji whose
# original code point cannot be recovered from this text - restore it from
# the upstream file.
user_input = st.text_area("๐ Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        # STEP 1: Preprocess user input
        cleaned_text = preprocess_text(user_input)
        tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
        input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)

        # STEP 2: Get original feature set from KNN model.
        # `_fit_X` is a private attribute; densify it first in case the
        # model was fitted on a sparse matrix, which torch.tensor rejects.
        fitted_features = knn_model._fit_X
        if hasattr(fitted_features, "toarray"):
            fitted_features = fitted_features.toarray()
        original_features = torch.tensor(fitted_features, dtype=torch.float)

        # STEP 3: Combine input with training data features
        combined_features = torch.cat([original_features, input_feature], dim=0)

        # STEP 4: Build edge index using k-NN
        # NOTE(review): kneighbors only returns indices of the fitted rows,
        # so the appended input node appears solely as an edge source. If
        # the model aggregates messages source -> target, the input node
        # receives none from its neighbours - confirm against how the
        # training graph was built.
        neighbors = knn_model.kneighbors(
            combined_features.numpy(), return_distance=False
        )
        edge_index = _build_edge_index(neighbors)

        # STEP 5: Run inference on the last node (user input)
        prediction, confidence = _predict_last_node(
            sage_model, combined_features, edge_index
        )

        # STEP 6: Display result
        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")
|