MrUtakata committed on
Commit
db2d2c1
·
verified ·
1 Parent(s): da55453

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -7,59 +7,59 @@ import numpy as np
7
 
8
  from nltk.corpus import stopwords
9
  from nltk.tokenize import RegexpTokenizer
10
- from sklearn.neighbors import NearestNeighbors
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
 
13
- # ——— 1) Download NLTK data & set up tokenizer/stopwords ———
14
  nltk.download('stopwords')
15
- STOP_WORDS = set(stopwords.words('english'))
16
- TOKENIZER = RegexpTokenizer(r'\w+')
17
 
18
  def preprocess_text(text: str) -> str:
19
- tokens = TOKENIZER.tokenize(text.lower())
20
- return " ".join([t for t in tokens if t not in STOP_WORDS])
21
 
22
- # ——— 2) Load saved artifacts once ———
23
- @st.cache(allow_output_mutation=True)
24
- def load_artifacts():
25
  tfidf: TfidfVectorizer = joblib.load("tfidf_vectorizer.pkl")
26
- knn: NearestNeighbors = joblib.load("knn_model.pkl")
27
  sage_model: torch.nn.Module = joblib.load("sage_model.pkl")
28
  sage_model.eval()
29
- return tfidf, knn, sage_model
30
 
31
- tfidf, knn, sage_model = load_artifacts()
32
 
33
  # ——— 3) Streamlit UI ———
34
  st.title("Disinformation Detection")
35
  st.write(
36
- """Enter a snippet of text below and click **Predict** to see
37
- whether it is more likely **True Information** or **Disinformation**."""
 
 
38
  )
39
 
40
- user_input = st.text_area("Article text", height=200)
41
 
42
  if st.button("Predict"):
43
  if not user_input.strip():
44
  st.warning("Please enter some text first.")
45
  else:
46
  # Preprocess & vectorize
47
- clean = preprocess_text(user_input)
48
- vec = tfidf.transform([clean]).toarray()
49
- x = torch.from_numpy(vec).float() # shape [1, D]
50
 
51
- # Build an “empty” graph so SAGEConv still runs (no neighbor messages)
52
  edge_index = torch.empty((2, 0), dtype=torch.long)
53
 
54
  # Inference
55
  with torch.no_grad():
56
- out = sage_model(x, edge_index) # [1, 2]
57
- probs = torch.exp(out).numpy()[0] # turn log‑softmax → probs
58
 
59
- lst = [f"πŸ”΅ True information: {probs[1]:.2%}",
60
- f"πŸ”΄ Disinformation: {probs[0]:.2%}"]
61
  st.markdown("### Prediction probabilities")
62
- st.write("\n\n".join(lst))
 
63
 
64
- pred = "βœ… Likely TRUE" if probs[1] > probs[0] else "❌ Likely DISINFORMATION"
65
- st.markdown(f"## **{pred}**")
 
7
 
8
  from nltk.corpus import stopwords
9
  from nltk.tokenize import RegexpTokenizer
 
10
  from sklearn.feature_extraction.text import TfidfVectorizer
11
 
12
+ # ——— 1) NLTK setup ———
13
  nltk.download('stopwords')
14
+ _STOP_WORDS = set(stopwords.words('english'))
15
+ _TOKENIZER = RegexpTokenizer(r'\w+')
16
 
17
  def preprocess_text(text: str) -> str:
18
+ tokens = _TOKENIZER.tokenize(text.lower())
19
+ return " ".join([t for t in tokens if t not in _STOP_WORDS])
20
 
21
+ # ——— 2) Load artifacts once ———
22
+ @st.cache_resource
23
+ def load_resources():
24
  tfidf: TfidfVectorizer = joblib.load("tfidf_vectorizer.pkl")
 
25
  sage_model: torch.nn.Module = joblib.load("sage_model.pkl")
26
  sage_model.eval()
27
+ return tfidf, sage_model
28
 
29
+ tfidf, sage_model = load_resources()
30
 
31
  # ——— 3) Streamlit UI ———
32
  st.title("Disinformation Detection")
33
  st.write(
34
+ """
35
+ Paste or type a snippet of text below and click **Predict**.
36
+ The model will output the probability it’s **True Information** vs. **Disinformation**.
37
+ """
38
  )
39
 
40
+ user_input = st.text_area("Your text here", height=200)
41
 
42
  if st.button("Predict"):
43
  if not user_input.strip():
44
  st.warning("Please enter some text first.")
45
  else:
46
  # Preprocess & vectorize
47
+ cleaned = preprocess_text(user_input)
48
+ vec = tfidf.transform([cleaned]).toarray()
49
+ x = torch.from_numpy(vec).float() # shape [1, D]
50
 
51
+ # Build an “empty” graph so the SAGEConv layers run (no neighbor messages)
52
  edge_index = torch.empty((2, 0), dtype=torch.long)
53
 
54
  # Inference
55
  with torch.no_grad():
56
+ logits = sage_model(x, edge_index) # [1, 2]
57
+ probs = torch.exp(logits).numpy()[0] # turn log‑softmax → probs
58
 
59
+ # Display
 
60
  st.markdown("### Prediction probabilities")
61
+ st.write(f"β€’ πŸ”΅ True information: {probs[1]:.2%}")
62
+ st.write(f"β€’ πŸ”΄ Disinformation: {probs[0]:.2%}")
63
 
64
+ label = "βœ… Likely TRUE" if probs[1] > probs[0] else "❌ Likely DISINFORMATION"
65
+ st.markdown(f"## **{label}**")