Update app.py
app.py CHANGED
Before (deletions marked with `-`):

```diff
@@ -1,28 +1,40 @@
 import gradio as gr
 from transformers import pipeline
 
 # ---------------------------
 # Load Models
 # ---------------------------
-# Claim Extraction → Zero-Shot Classifier
-# (using MoritzLaurer public DeBERTa MNLI model)
 claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
 claim_classifier = pipeline("zero-shot-classification", model=claim_model_name, device=-1)
 claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
 
-# AI Text Detection → OpenAI Detector (Roberta-based)
 ai_detect_model_name = "roberta-base-openai-detector"
 ai_detector = pipeline("text-classification", model=ai_detect_model_name, device=-1)
 
-# Fact-Checking (NLI) → DistilBART MNLI
 nli_model_name = "valhalla/distilbart-mnli-12-3"
 nli_pipeline = pipeline("text-classification", model=nli_model_name, tokenizer=nli_model_name, device=-1)
 
 # ---------------------------
-#
 # ---------------------------
 def extract_claims(page_text):
-    """Extract top 5 factual claims from page text."""
     sentences = [s.strip() for s in page_text.split(".") if len(s.strip()) > 5]
     results = []
     for s in sentences:
```
```diff
@@ -32,8 +44,10 @@ def extract_claims(page_text):
     return results[:5]
 
 
 def detect_ai(texts):
-    """Detect whether input text is AI-generated or human-written."""
     if isinstance(texts, str):
         texts = [texts]
     results = []
```
```diff
@@ -43,62 +57,104 @@ def detect_ai(texts):
     return results
 
 
-# ---------------------------
-# Fact-Checking
-# ---------------------------
-def fact_check(claims, evidence_text):
     results = []
     for c in claims:
-
-
     return results
 
 
 # ---------------------------
-# Unified Predict
 # ---------------------------
-def predict(page_text="", selected_text="", evidence_text=""):
-    """
-    1. Extract top 5 claims from page_text
-    2. Run AI Detection on claims + selected_text
-    3. Run Fact-Checking on claims + evidence_text if provided
-    """
-    # Extract claims
     claims = extract_claims(page_text) if page_text else []
 
-
-    ai_input = claims.copy()
-    if selected_text:
-        ai_input.append(selected_text)
-    ai_results = detect_ai(ai_input) if ai_input else []
-
-    # Fact-checking: only if evidence is provided
-    fc_results = fact_check(claims + ([selected_text] if selected_text else []), evidence_text) if evidence_text else []
-
-    return {
-        "claims": claims,
-        "ai_detection": ai_results,
-        "fact_checking": fc_results
-    }
 
 
 # ---------------------------
 # Gradio UI
 # ---------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## EduShield AI Backend - Predict API & UI")
 
     page_text_input = gr.Textbox(label="Full Page Text", lines=10, placeholder="Paste page text here...")
-    selected_text_input = gr.Textbox(label="Selected Text", lines=5, placeholder="Paste selected text here...")
-    evidence_input = gr.Textbox(label="Evidence Text", lines=5, placeholder="Paste evidence text here...")
     predict_btn = gr.Button("Run Predict")
     output_json = gr.JSON(label="Predict Results")
 
-    predict_btn.click(
-        predict,
-        inputs=[page_text_input, selected_text_input, evidence_input],
-        outputs=output_json
-    )
 
 # ---------------------------
 # Launch
```
After (additions marked with `+`):

```diff
@@ -1,28 +1,40 @@
 import gradio as gr
 from transformers import pipeline
+import feedparser, requests, re, wikipedia, time
+from concurrent.futures import ThreadPoolExecutor
 
 # ---------------------------
 # Load Models
 # ---------------------------
 claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
 claim_classifier = pipeline("zero-shot-classification", model=claim_model_name, device=-1)
 claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
 
 ai_detect_model_name = "roberta-base-openai-detector"
 ai_detector = pipeline("text-classification", model=ai_detect_model_name, device=-1)
 
 nli_model_name = "valhalla/distilbart-mnli-12-3"
 nli_pipeline = pipeline("text-classification", model=nli_model_name, tokenizer=nli_model_name, device=-1)
 
 # ---------------------------
+# Fact-check sources
+# ---------------------------
+FACT_FEEDS = {
+    "Snopes": "https://www.snopes.com/feed/",
+    "PolitiFact": "https://www.politifact.com/rss/factchecks/",
+    "FactCheck.org": "https://www.factcheck.org/feed/",
+    "AP News Fact Check": "https://apnews.com/hub/ap-fact-check.rss",
+    "Reuters Fact Check": "https://www.reuters.com/fact-check/rss",
+}
+
+GOOGLE_API_KEY = "AIzaSyAC56onKwR17zd_djUPEfGXQACy9qRjDxw"
+GOOGLE_CX = "YOUR_SEARCH_ENGINE_ID"  # you need to set up a CSE at Google
+
```
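Committing a live API key to the repo is risky: anyone who views the file can use (and exhaust) the key's quota, so it should be rotated and moved out of source. A common alternative, not part of this commit, is to read both values from the environment (on Hugging Face Spaces, repository secrets are exposed as environment variables):

```python
import os

# Hypothetical replacement for the hardcoded constants above;
# set GOOGLE_API_KEY and GOOGLE_CX as Space secrets / env vars.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
GOOGLE_CX = os.environ.get("GOOGLE_CX", "")
```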
```diff
+# ---------------------------
+# Claim Extraction
 # ---------------------------
 def extract_claims(page_text):
     sentences = [s.strip() for s in page_text.split(".") if len(s.strip()) > 5]
     results = []
     for s in sentences:
@@ -32,8 +44,10 @@ def extract_claims(page_text):
     return results[:5]
```
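The loop body falls between the first two hunks, so the diff doesn't show it. Judging from the surrounding code, it presumably scores each sentence with the zero-shot classifier and keeps the ones ranked as factual claims — a hypothetical sketch of that hidden body:

```python
# Hypothetical reconstruction of the hidden loop body in extract_claims:
# rank each sentence against claim_labels and keep likely factual claims.
for s in sentences:
    out = claim_classifier(s, claim_labels)   # {"labels": [...], "scores": [...]}, sorted by score
    if out["labels"][0] == "factual claim":   # top-ranked label wins
        results.append(s)
```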
```diff
 
 
+# ---------------------------
+# AI Detection
+# ---------------------------
 def detect_ai(texts):
     if isinstance(texts, str):
         texts = [texts]
     results = []
@@ -43,62 +57,104 @@ def detect_ai(texts):
     return results
```
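detect_ai's loop body is likewise hidden between hunks. The roberta-base-openai-detector pipeline returns a label ("Real" or "Fake") with a confidence score per input, so the hidden body is presumably along these lines:

```python
# Hypothetical reconstruction of the hidden loop body in detect_ai:
for t in texts:
    out = ai_detector(t, truncation=True)   # truncate inputs longer than the model's window
    results.append({"text": t, "label": out[0]["label"], "score": round(out[0]["score"], 3)})
```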
```diff
 
 
+# ---------------------------
+# Evidence Fetchers
+# ---------------------------
+def fetch_rss_evidence(claim):
+    evidence = []
+    for name, url in FACT_FEEDS.items():
+        try:
+            feed = feedparser.parse(url)
+            for entry in feed.entries[:10]:
+                # re.escape: the claim is literal text, not a regex pattern
+                if re.search(re.escape(claim[:30]), entry.title + " " + entry.get("summary", ""), re.I):
+                    evidence.append(f"[{name}] {entry.title} - {entry.link}")
+        except Exception:
+            continue
+    return evidence
+
+
+def fetch_wikipedia(claim):
+    try:
+        results = wikipedia.search(claim)
+        if results:
+            page = wikipedia.page(results[0])
+            return [f"[Wikipedia] {page.title}: {page.summary[:300]}..."]
+    except Exception:
+        return []
+    return []
+
+
+def fetch_google_cse(claim):
+    try:
+        url = (
+            f"https://www.googleapis.com/customsearch/v1?q={requests.utils.quote(claim)}"
+            f"&key={GOOGLE_API_KEY}&cx={GOOGLE_CX}"
+        )
+        r = requests.get(url, timeout=10).json()
+        if "items" in r:
+            return [f"[Google] {item['title']} - {item['link']}" for item in r["items"][:3]]
+    except Exception:
+        return []
+    return []
+
+
+def gather_evidence(claim):
+    evidence = []
+    with ThreadPoolExecutor() as ex:
+        futures = [
+            ex.submit(fetch_rss_evidence, claim),
+            ex.submit(fetch_wikipedia, claim),
+            ex.submit(fetch_google_cse, claim),
+        ]
+        for f in futures:
+            try:
+                evidence.extend(f.result())
+            except Exception:
+                continue
+    return evidence[:5]
```
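gather_evidence fans the three fetchers out on a thread pool so the slow network calls overlap, then caps the merged list at five snippets. For example:

```python
# Network access required; output is a list of up to five strings like
# "[Snopes] <headline> - <url>" or "[Wikipedia] <title>: <summary>...".
evidence = gather_evidence("The Great Wall of China is visible from space")
for line in evidence:
    print(line)
```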
```diff
+
+
+# ---------------------------
+# Fact-Checking
+# ---------------------------
+def fact_check(claims):
     results = []
     for c in claims:
+        evidence_list = gather_evidence(c)
+        if not evidence_list:
+            results.append({"claim": c, "label": "no evidence found", "score": 0.0})
+            continue
+
+        best_ev = evidence_list[0]
+        out = nli_pipeline({"text": best_ev, "text_pair": c})
+        results.append(
+            {"claim": c, "evidence": best_ev, "label": out[0]["label"], "score": round(out[0]["score"], 3)}
+        )
     return results
```
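The text-classification pipeline takes a premise/hypothesis pair as a dict with "text" (the premise, here the evidence snippet) and "text_pair" (the hypothesis, the claim); the MNLI model then scores entailment vs. neutral vs. contradiction. A minimal example (output values illustrative):

```python
# Premise/hypothesis scoring with an MNLI model via the
# text-classification pipeline: "text" = evidence, "text_pair" = claim.
out = nli_pipeline({
    "text": "The Eiffel Tower is a wrought-iron tower in Paris, France.",
    "text_pair": "The Eiffel Tower is in Paris.",
})
print(out)  # e.g. [{"label": "entailment", "score": 0.98}]
```

Note that the evidence strings gathered above are headlines and links rather than article bodies, so the entailment score is best read as a rough signal, not a verdict.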
```diff
 
 
 # ---------------------------
+# Unified Predict
 # ---------------------------
+def predict(page_text=""):
     claims = extract_claims(page_text) if page_text else []
+    ai_results = detect_ai(claims) if claims else []
+    fc_results = fact_check(claims) if claims else []
 
+    return {"claims": claims, "ai_detection": ai_results, "fact_checking": fc_results}
```
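predict now takes only the page text; the separate "selected text" and "evidence" inputs are gone because evidence is gathered per claim automatically. The returned dict is JSON-serializable (the exact ai_detection entry shape depends on the hidden detect_ai body):

```python
# Illustrative call and result shape:
result = predict("The Eiffel Tower is in Paris. I think cats are the best pets.")
# {
#     "claims": ["The Eiffel Tower is in Paris"],
#     "ai_detection": [{"text": "...", "label": "Real", "score": 0.93}],
#     "fact_checking": [{"claim": "...", "evidence": "[Wikipedia] Eiffel Tower: ...",
#                        "label": "entailment", "score": 0.97}],
# }
```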
```diff
 
 
 # ---------------------------
 # Gradio UI
 # ---------------------------
 with gr.Blocks() as demo:
+    gr.Markdown("## EduShield AI Backend - Predict API & UI (with Fact-Check Sources)")
 
     page_text_input = gr.Textbox(label="Full Page Text", lines=10, placeholder="Paste page text here...")
     predict_btn = gr.Button("Run Predict")
     output_json = gr.JSON(label="Predict Results")
 
+    predict_btn.click(predict, inputs=[page_text_input], outputs=output_json)
 
 # ---------------------------
 # Launch
```