Mohammedmarzuk17 committed on
Commit
281b438
·
verified ·
1 Parent(s): fd336f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -154
app.py CHANGED
@@ -1,174 +1,91 @@
1
  import gradio as gr
2
- import requests, feedparser, time, threading, re, json, os
3
- from concurrent.futures import ThreadPoolExecutor
4
- from sentence_transformers import SentenceTransformer, util
5
- import nltk
6
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
7
 
8
# ---------------------------
# NLTK setup
# ---------------------------
nltk.download('punkt')

# ---------------------------
# Models
# ---------------------------
# Sentence-embedding model used for claim-vs-evidence semantic similarity.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Zero-shot classifier that decides whether a sentence is a factual claim.
claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
tokenizer = AutoTokenizer.from_pretrained(claim_model_name, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(claim_model_name)
claim_classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]

# Detector that scores whether a text looks machine-generated.
ai_detect_model_name = "roberta-base-openai-detector"
ai_detector = pipeline("text-classification", model=ai_detect_model_name)

# Natural-language-inference model for entailment between claim and evidence.
nli_model_name = "valhalla/distilbart-mnli-12-3"
nli_pipeline = pipeline("text-classification", model=nli_model_name, tokenizer=nli_model_name)

# ---------------------------
# Evidence sources
# ---------------------------
RSS_FEEDS = [
    "https://www.snopes.com/feed/",
    "https://www.politifact.com/rss/factchecks/",
    "https://www.factcheck.org/feed/",
]
RSS_CACHE = []           # most recently fetched fact-check articles
CACHE_TTL = 60 * 60 * 3  # cache lifetime: 3 hours
RSS_LAST_FETCH = 0       # epoch seconds of the last refresh
45
-
46
- # ---------------------------
47
- # Helpers
48
- # ---------------------------
49
def clean_text(text):
    """Return *text* with HTML tags removed and whitespace normalized.

    Tags are stripped with a non-greedy regex; any run of whitespace
    (spaces, tabs, newlines) collapses to a single space, and leading
    and trailing whitespace is dropped.
    """
    without_tags = re.sub(r'<.*?>', '', text)
    return ' '.join(without_tags.split())
53
-
54
def fetch_rss_articles():
    """Pull the latest entries from every configured fact-check feed.

    Returns a list of {"title", "summary"} dicts, at most 10 entries
    per feed. A feed that fails to parse is skipped silently - the
    refresh is best-effort, not a hard dependency.
    """
    collected = []
    for feed_url in RSS_FEEDS:
        try:
            parsed = feedparser.parse(feed_url)
            collected.extend(
                {
                    "title": clean_text(entry.get("title", "")),
                    "summary": clean_text(entry.get("summary", "")),
                }
                for entry in parsed.entries[:10]
            )
        except Exception:
            continue
    return collected
66
-
67
def refresh_rss_cache(force=False):
    """Re-fetch the RSS cache when it is stale, empty, or *force* is set.

    Updates the module-level RSS_CACHE and RSS_LAST_FETCH timestamp.
    """
    global RSS_CACHE, RSS_LAST_FETCH
    now = time.time()
    stale = (now - RSS_LAST_FETCH) > CACHE_TTL
    if force or stale or not RSS_CACHE:
        RSS_CACHE = fetch_rss_articles()
        RSS_LAST_FETCH = now
73
-
74
def start_rss_refresher():
    """Spawn a daemon thread that force-refreshes the RSS cache every CACHE_TTL seconds."""
    def _refresh_forever():
        while True:
            refresh_rss_cache(force=True)
            time.sleep(CACHE_TTL)

    # Daemon thread: it must not keep the process alive on shutdown.
    threading.Thread(target=_refresh_forever, daemon=True).start()
81
 
82
  # ---------------------------
83
- # Claim extraction
84
  # ---------------------------
85
def extract_claims(text):
    """Return up to 10 sentences from *text* classified as factual claims.

    Sentences are split on ., !, ?, ; and newlines. Fragments shorter
    than 15 characters are ignored. A sentence counts as a claim when
    the zero-shot classifier scores "factual claim" above 0.25.
    """
    claims = []
    for sentence in re.split(r'(?<=[.!?;\n])\s+', text):
        sentence = sentence.strip()
        if len(sentence) < 15:
            continue
        result = claim_classifier(sentence, claim_labels)
        if "factual claim" in result["labels"]:
            position = result["labels"].index("factual claim")
            if result["scores"][position] > 0.25:
                claims.append(sentence)
    return claims[:10]
96
-
97
- # ---------------------------
98
- # Semantic RSS matching
99
- # ---------------------------
100
def match_rss_semantic(claim, top_k=2):
    """Return up to *top_k* cached RSS summaries semantically close to *claim*.

    Similarity is cosine distance between sentence embeddings; only
    summaries scoring above 0.3 are kept. Returns an empty list when
    the cache holds nothing.
    """
    if not RSS_CACHE:
        return []
    summaries = [article["summary"] for article in RSS_CACHE]
    claim_vec = embedding_model.encode(claim, convert_to_tensor=True)
    summary_vecs = embedding_model.encode(summaries, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(claim_vec, summary_vecs).cpu().numpy()[0]
    ranked = similarities.argsort()[::-1][:top_k]
    return [summaries[i] for i in ranked if similarities[i] > 0.3]
110
-
111
- # ---------------------------
112
- # NLI & AI detection
113
- # ---------------------------
114
def process_evidence_pair(claim, evidence):
    """Score one claim/evidence pair with NLI and AI-text detection.

    Returns a dict with a <=300-char evidence snippet, a simplified
    True/False/Uncertain verdict, the raw NLI score, and a combined
    "trustworthiness" percentage (70% NLI confidence, 30% AI score).
    """
    nli_out = nli_pipeline(f"{claim} </s></s> {evidence}")[0]
    nli_label = nli_out['label']
    nli_score = nli_out['score']

    # Only commit to True/False when the NLI model is reasonably confident.
    if nli_score > 0.6 and nli_label == "ENTAILMENT":
        verdict = "True"
    elif nli_score > 0.6 and nli_label == "CONTRADICTION":
        verdict = "False"
    else:
        verdict = "Uncertain"

    # NOTE(review): a "Fake" label *raises* ai_score here, which then
    # raises trustworthiness - that polarity looks inverted; confirm
    # the intended meaning of the detector's labels.
    detector_out = ai_detector(claim)[0]
    if detector_out['label'] == "Fake":
        ai_score = detector_out['score']
    else:
        ai_score = 1 - detector_out['score']

    trustworthiness = round((nli_score * 0.7 + ai_score * 0.3) * 100, 1)
    snippet = evidence[:300] + "..." if len(evidence) > 300 else evidence
    return {
        "text": snippet,
        "label": verdict,
        "score": round(nli_score, 3),
        "trustworthiness": trustworthiness,
    }
132
-
133
- # ---------------------------
134
- # Fact-checking
135
- # ---------------------------
136
def fact_check(claims):
    """Fact-check each claim against semantically matched RSS evidence.

    Refreshes the RSS cache if stale, then scores each claim's matched
    evidence concurrently (up to 5 workers). Claims with no matching
    evidence get an empty evidence list and 0.0 trustworthiness.
    """
    refresh_rss_cache()
    results = []
    with ThreadPoolExecutor(max_workers=5) as pool:
        for claim in claims:
            matched = match_rss_semantic(claim)
            if not matched:
                results.append({"claim": claim, "evidence": [], "trustworthiness": 0.0})
                continue
            futures = [pool.submit(process_evidence_pair, claim, ev) for ev in matched]
            results.append({"claim": claim, "evidence": [f.result() for f in futures]})
    return results
149
 
150
- # ---------------------------
151
- # Predict function
152
- # ---------------------------
153
def predict(page_text=""):
    """Extract factual claims from *page_text* and fact-check them.

    Returns {"claims": [...], "fact_checking": [...]}; fact-checking is
    skipped (empty list) when no claims were found.
    """
    claims = extract_claims(page_text)
    return {
        "claims": claims,
        "fact_checking": fact_check(claims) if claims else [],
    }
157
-
158
- # ---------------------------
159
- # Gradio UI
160
- # ---------------------------
161
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## EduShield AI - Fact-Checking with AI Models")
    # Single text box in, JSON report out.
    page_input = gr.Textbox(label="Paste page text", lines=10)
    predict_btn = gr.Button("Run Predict")
    output_json = gr.JSON(label="Results")
    predict_btn.click(fn=predict, inputs=[page_input], outputs=output_json)
167
 
168
  # ---------------------------
169
- # Launch
170
- # ---------------------------
171
# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    # Warm the cache once before serving, then keep it fresh in the background.
    refresh_rss_cache(force=True)
    start_rss_refresher()
    demo.launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
 
 
3
 
4
# ---------------------------
# Load Models
# ---------------------------
import torch

# Fix: the original hard-coded device=0, which raises at import time on
# CPU-only hosts. Use the GPU when one is available, else fall back to CPU.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Zero-shot classifier that decides whether a sentence is a factual claim.
claim_model_name = "microsoft/deberta-v3-base-zeroshot-v1.1"
claim_classifier = pipeline("zero-shot-classification", model=claim_model_name, device=_DEVICE)
claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]

# Detector that scores whether a text looks machine-generated.
ai_detect_model_name = "roberta-base-openai-detector"
ai_detector = pipeline("text-classification", model=ai_detect_model_name, device=_DEVICE)

# Natural-language-inference model for checking claims against evidence.
nli_model_name = "valhalla/distilbart-mnli-12-3"
nli_pipeline = pipeline("text-classification", model=nli_model_name, tokenizer=nli_model_name, device=_DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # ---------------------------
18
+ # Functions
19
  # ---------------------------
20
def extract_claims(page_text):
    """Return up to 5 sentences from *page_text* that classify as factual claims.

    Fix/generalization: the original split only on ".", so sentences
    ending in "?" or "!" were merged with their neighbours; we now split
    on any of . ! ?. Fragments of 5 characters or fewer are ignored.
    A sentence is kept when the zero-shot classifier ranks
    "factual claim" as its top label.
    """
    import re  # function-scope import keeps the fix self-contained

    fragments = re.split(r'[.!?]', page_text)
    sentences = [f.strip() for f in fragments if len(f.strip()) > 5]
    claims = []
    for sentence in sentences:
        out = claim_classifier(sentence, claim_labels)
        if out["labels"][0] == "factual claim":
            claims.append(sentence)
        if len(claims) == 5:
            break  # quota reached - skip further (expensive) classifier calls
    return claims
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
def detect_ai(texts):
    """Run the AI-text detector over one string or a list of strings.

    Returns one {"text", "label", "score"} dict per input, with the
    detector's score rounded to 3 decimals.
    """
    if isinstance(texts, str):
        texts = [texts]
    report = []
    for passage in texts:
        prediction = ai_detector(passage)[0]
        report.append({
            "text": passage,
            "label": prediction["label"],
            "score": round(prediction["score"], 3),
        })
    return report
37
 
38
def fact_check(claims, evidence_text):
    """Score each claim against *evidence_text* with the NLI pipeline.

    Accepts a single claim string or a list of claims; returns one
    {"claim", "label", "score"} dict per claim.

    Fix: the original called
    ``nli_pipeline(hypothesis=c, sequence_pair=evidence_text)`` -
    these are not accepted arguments for a text-classification
    pipeline and raise a TypeError at runtime. Sentence pairs are
    passed as a {"text", "text_pair"} dict instead.
    """
    if isinstance(claims, str):
        claims = [claims]
    results = []
    for c in claims:
        # premise = evidence, hypothesis = claim
        out = nli_pipeline({"text": evidence_text, "text_pair": c})[0]
        results.append({"claim": c, "label": out["label"], "score": round(out["score"], 3)})
    return results
 
 
 
 
 
 
 
 
 
46
 
47
# ---------------------------
# Unified Predict Function
# ---------------------------
def predict(page_text="", selected_text="", evidence_text=""):
    """Run the full pipeline over the supplied texts.

    1. Extract up to 5 claims from *page_text* (if given).
    2. Run AI detection on those claims plus *selected_text* (if given).
    3. Fact-check claims (+ selected_text) against *evidence_text*,
       only when evidence was provided.
    """
    claims = extract_claims(page_text) if page_text else []

    # AI detection covers the extracted claims plus any selected text.
    ai_targets = claims.copy()
    if selected_text:
        ai_targets.append(selected_text)
    ai_results = detect_ai(ai_targets) if ai_targets else []

    # Fact-checking only makes sense when evidence text was supplied.
    if evidence_text:
        fc_targets = claims + ([selected_text] if selected_text else [])
        fc_results = fact_check(fc_targets, evidence_text)
    else:
        fc_results = []

    return {
        "claims": claims,
        "ai_detection": ai_results,
        "fact_checking": fc_results,
    }
74
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## EduShield AI Backend - Predict API & UI")

    # Three inputs feeding the unified predict() entry point.
    page_text_input = gr.Textbox(label="Full Page Text", lines=10, placeholder="Paste page text here...")
    selected_text_input = gr.Textbox(label="Selected Text", lines=5, placeholder="Paste selected text here...")
    evidence_input = gr.Textbox(label="Evidence Text", lines=5, placeholder="Paste evidence text here...")
    predict_btn = gr.Button("Run Predict")
    output_json = gr.JSON(label="Predict Results")

    predict_btn.click(predict, inputs=[page_text_input, selected_text_input, evidence_input], outputs=output_json)
86
+ ...
87
# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    # Bind to all interfaces so the hosting container exposes the app.
    demo.launch(server_name="0.0.0.0")