File size: 2,744 Bytes
5d3091e
 
da55453
 
ac77892
 
da55453
5d3091e
 
da55453
5d3091e
db2d2c1
ac77892
 
 
 
 
 
da55453
5d3091e
 
 
 
 
 
 
 
 
 
 
 
ac77892
5d3091e
da55453
5d3091e
ac77892
 
da55453
5d3091e
ac77892
da55453
ac77892
 
 
da55453
5d3091e
ac77892
 
5d3091e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da55453
5d3091e
ac77892
da55453
5d3091e
 
 
 
da55453
5d3091e
ac77892
 
5d3091e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# app.py

import streamlit as st
import torch
import joblib
import dill
import numpy as np
import gdown
import os

# Load assets with caching
@st.cache_resource
def load_assets():
    """Load and cache the inference artefacts shipped alongside the app.

    Streamlit's ``cache_resource`` keeps the returned objects around, so the
    files are not re-read on every script rerun.

    Returns:
        tuple: ``(preprocess_text, tfidf, model)`` -- the dill-serialised text
        preprocessing callable, the joblib-saved TF-IDF vectorizer, and the
        joblib-saved GraphSAGE model, in that order.
    """
    # The preprocessing step is stored with dill rather than joblib,
    # presumably because it is a plain Python function -- confirm.
    with open("preprocess_function.pkl", "rb") as handle:
        preprocessor = dill.load(handle)

    vectorizer = joblib.load("tfidf_vectorizer.pkl")
    graph_model = joblib.load("sage_model.pkl")

    return preprocessor, vectorizer, graph_model

# Download KNN model from Google Drive if not present
@st.cache_resource
def ensure_knn_model():
    """Return the fitted k-NN model, downloading it first if it is missing.

    Cached with ``st.cache_resource`` (as ``load_assets`` already is) so the
    pickle is deserialised once and reused. Without it, every Streamlit rerun
    -- i.e. every widget interaction -- re-read the model from disk.
    Exceptions are not cached, so a failed download is retried on the next
    rerun.

    Returns:
        The object stored in ``knn_model.pkl``.

    Raises:
        FileNotFoundError: if the file is still missing after the download.
    """
    knn_path = "knn_model.pkl"
    if not os.path.exists(knn_path):
        gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False
        )
        # Some gdown versions signal failure by returning None rather than
        # raising. Fail here with a clear message instead of an opaque
        # error from joblib further down.
        if not os.path.exists(knn_path):
            raise FileNotFoundError(
                f"Could not download the KNN model to {knn_path!r}"
            )
    # SECURITY: joblib.load unpickles arbitrary objects (code execution on
    # load). This file comes from a remote Google Drive link, so that link
    # must stay under trusted control; consider pinning a checksum.
    return joblib.load(knn_path)

# Load models
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()


@st.cache_resource
def _training_graph(_knn):
    """Build the training-set part of the k-NN graph once and cache it.

    Edges among training rows depend only on the fitted KNN index, never on
    the user's text. Recomputing them on every click (a k-NN query for every
    training row) is wasted work. The leading underscore in ``_knn`` tells
    Streamlit not to hash this argument.

    Returns:
        (features, edge_index): the float32 feature matrix of the indexed
        rows, and a ``(2, E)`` long tensor of ``[row, neighbour]`` edges
        with self-edges dropped.
    """
    # NOTE(review): relies on scikit-learn's private ``_fit_X`` attribute and
    # assumes the index was fit on a dense array -- confirm.
    features = torch.tensor(_knn._fit_X, dtype=torch.float)
    neighbors = _knn.kneighbors(features.numpy(), return_distance=False)
    pairs = [
        (row, int(nbr))
        for row, row_nbrs in enumerate(neighbors)
        for nbr in row_nbrs
        if row != nbr  # each row is usually its own nearest neighbour
    ]
    # reshape keeps a (2, 0) shape even when there are no edges
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return features, edge_index


def _query_edges(query_idx, nbrs):
    """Return a ``(2, 2k)`` edge tensor linking node ``query_idx`` to ``nbrs``.

    Edges are added in both directions. The ``[query, neighbour]`` direction
    alone leaves the query node with no incoming edges. If ``sage_model``
    follows PyTorch Geometric's default ``source_to_target`` message flow (as
    its ``(x, edge_index)`` call suggests -- confirm), a node only aggregates
    messages arriving on incoming edges. The prediction would then ignore the
    graph entirely.
    """
    pairs = []
    for nbr in nbrs:
        nbr = int(nbr)
        pairs.append((query_idx, nbr))
        pairs.append((nbr, query_idx))
    return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()


def _predict(text):
    """Classify ``text``; return ``(predicted class index, its probability)``."""
    # STEP 1: Preprocess user input into a single TF-IDF feature row
    cleaned_text = preprocess_text(text)
    tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
    input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)

    # STEP 2-3: Append the input as a new node after the training nodes
    train_features, train_edge_index = _training_graph(knn_model)
    query_idx = train_features.shape[0]
    combined_features = torch.cat([train_features, input_feature], dim=0)

    # STEP 4: Connect the new node to its k nearest training rows. The KNN
    # index contains only training rows, so no neighbour can be the query
    # node itself.
    nbrs = knn_model.kneighbors(input_feature.numpy(), return_distance=False)[0]
    edge_index = torch.cat(
        [train_edge_index, _query_edges(query_idx, nbrs)], dim=1
    )

    # STEP 5: Run inference on the last node (user input)
    sage_model.eval()
    with torch.no_grad():
        logits = sage_model(combined_features, edge_index)
        pred_node_logits = logits[-1]  # Last node is the user input
        prediction = torch.argmax(pred_node_logits).item()
        # softmax gives proper probabilities whether the model emits raw
        # logits or log-probabilities (softmax(log p) == p). A bare exp()
        # is only correct in the log-probability case.
        confidence = torch.softmax(pred_node_logits, dim=-1)[prediction].item()
    return prediction, confidence


# App UI
st.title("🧠 Disinformation Detection")
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# User input
user_input = st.text_area("📝 Enter a news article or headline:")

if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        prediction, confidence = _predict(user_input)

        # STEP 6: Display result
        # NOTE(review): class index 1 is treated as "real" -- confirm against
        # the label encoding used in training.
        label = "🟢 Real News" if prediction == 1 else "🔴 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")