|
|
|
|
|
|
|
|
import streamlit as st |
|
|
import torch |
|
|
import joblib |
|
|
import dill |
|
|
import numpy as np |
|
|
import gdown |
|
|
import os |
|
|
|
|
|
|
|
|
@st.cache_resource
def load_assets():
    """Load the serialized preprocessing function, vectorizer and model.

    ``preprocess_function.pkl`` is read with ``dill``; ``tfidf_vectorizer.pkl``
    and ``sage_model.pkl`` are read with ``joblib``. All three paths are
    relative to the current working directory. The result is wrapped in
    ``st.cache_resource``, so the files are read once rather than on every
    Streamlit rerun.

    Returns:
        A ``(preprocess_text, tfidf, model)`` tuple, in that order.
    """
    with open("preprocess_function.pkl", "rb") as handle:
        preprocessor = dill.load(handle)

    # Both remaining artefacts are plain joblib dumps; load them in the same
    # order as they appear in the returned tuple.
    vectorizer, classifier = (
        joblib.load(filename)
        for filename in ("tfidf_vectorizer.pkl", "sage_model.pkl")
    )
    return preprocessor, vectorizer, classifier
|
|
|
|
|
|
|
|
@st.cache_resource
def ensure_knn_model():
    """Return the fitted k-NN model, downloading its pickle first if needed.

    The pickle is expected at ``knn_model.pkl`` in the current working
    directory. When it is absent it is fetched from Google Drive with
    ``gdown``. The loaded object is cached with ``st.cache_resource`` (the
    same way ``load_assets`` is), so Streamlit reruns triggered by widget
    interaction do not deserialize the file again.

    Returns:
        The object stored in ``knn_model.pkl``, as loaded by ``joblib``.

    Raises:
        FileNotFoundError: If the file is still missing after the download
            attempt.
    """
    knn_path = "knn_model.pkl"

    if not os.path.exists(knn_path):
        downloaded = gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False,
        )
        # Depending on the gdown version a failed download may return None
        # instead of raising; fail here with a clear message rather than
        # letting joblib report an opaque error on a missing file.
        if downloaded is None or not os.path.exists(knn_path):
            raise FileNotFoundError(
                f"Could not download the k-NN model to {knn_path!r}; "
                "check network access and the sharing settings of the file."
            )

    # NOTE(security): joblib.load unpickles arbitrary objects, and this file
    # may have just been fetched from a remote URL. Only safe while the
    # Drive file is controlled by a trusted party.
    return joblib.load(knn_path)
|
|
|
|
|
|
|
|
# Module-level setup: Streamlit executes this script top to bottom on every
# rerun, so these names are (re)bound each time before the widgets below run.
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()

# Page header. The emoji is written as a real Unicode character; the previous
# literal was the mis-decoded byte sequence of the same emoji.
st.title("🧠 Disinformation Detection")
st.write(
    "This app predicts whether a given news article is **real** or "
    "**disinformation** using a trained GraphSAGE model."
)
|
|
|
|
|
|
|
|
def _to_dense_float32(matrix):
    """Return *matrix* as a dense ``float32`` NumPy array.

    Objects exposing ``toarray()`` (e.g. SciPy sparse matrices) are densified
    first, so this works whether the data is stored sparse or dense.
    """
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return np.asarray(matrix, dtype=np.float32)


def _build_edge_index(neighbors):
    """Convert a k-NN index table into a ``(2, num_edges)`` edge tensor.

    Row ``i`` of *neighbors* lists the neighbours of node ``i``. One edge
    ``i -> j`` is emitted per listed neighbour ``j`` in row-major order,
    skipping self-loops. The result has shape ``(2, 0)`` when no edge remains.
    """
    neighbors = np.asarray(neighbors)
    sources = np.repeat(np.arange(neighbors.shape[0]), neighbors.shape[1])
    targets = neighbors.reshape(-1)
    keep = sources != targets
    return torch.as_tensor(
        np.stack([sources[keep], targets[keep]]), dtype=torch.long
    )


@st.cache_resource
def _training_graph():
    """Return ``(features, neighbors)`` for the data the k-NN model was fitted on.

    Both values depend only on ``knn_model``, so they are computed once and
    cached instead of being rebuilt on every click of the button.

    Returns:
        A float tensor with one row per fitted sample, and the index array
        returned by ``knn_model.kneighbors`` for those same rows.
    """
    # _fit_X is a private attribute (presumably of a scikit-learn neighbors
    # estimator -- confirm); it may be sparse, hence the densifying helper.
    train_x = _to_dense_float32(knn_model._fit_X)
    train_neighbors = knn_model.kneighbors(train_x, return_distance=False)
    return torch.tensor(train_x, dtype=torch.float), train_neighbors


user_input = st.text_area("📝 Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        # Vectorize the submitted text into a single feature row.
        cleaned_text = preprocess_text(user_input)
        query_x = _to_dense_float32(tfidf_vectorizer.transform([cleaned_text]))

        # Append the new article as the last node of the graph. Only its own
        # neighbours have to be searched; the rows for the fitted samples are
        # identical on every run and come from the cache.
        train_features, train_neighbors = _training_graph()
        query_neighbors = knn_model.kneighbors(query_x, return_distance=False)
        combined_features = torch.cat(
            [train_features, torch.from_numpy(query_x)], dim=0
        )

        # NOTE(review): edges run from each node to its neighbours, so the
        # new node only ever appears as an edge source. Confirm this matches
        # the direction used when the model's training graph was built.
        edge_index = _build_edge_index(
            np.vstack([train_neighbors, query_neighbors])
        )

        sage_model.eval()
        with torch.no_grad():
            logits = sage_model(combined_features, edge_index)
            pred_node_logits = logits[-1]  # the appended node is the last row
            prediction = torch.argmax(pred_node_logits).item()
            # softmax yields a valid probability whether the model emits raw
            # logits or log-probabilities (for the latter it equals exp()).
            confidence = torch.softmax(pred_node_logits, dim=-1)[prediction].item()

        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")
|
|
|