Mohammedmarzuk17 commited on
Commit
5f2e5ba
·
verified ·
1 Parent(s): f9fb094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -178
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModel
3
- import requests, re, datetime, torch
 
4
  from concurrent.futures import ThreadPoolExecutor
5
- import torch.nn.functional as F
6
 
7
  # ---------------------------
8
  # Load Models
9
  # ---------------------------
10
 
11
- # Claim Extraction → Zero-Shot Classifier (DeBERTa MNLI)
12
  claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
13
  claim_classifier = pipeline(
14
  "zero-shot-classification",
@@ -17,7 +17,7 @@ claim_classifier = pipeline(
17
  )
18
  claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
19
 
20
- # AI Text Detection → OpenAI Detector (Roberta-based)
21
  ai_detect_model_name = "roberta-base-openai-detector"
22
  ai_detector = pipeline(
23
  "text-classification",
@@ -25,28 +25,9 @@ ai_detector = pipeline(
25
  device=-1
26
  )
27
 
28
- # ---------------------------
29
- # ✅ Semantic Model (EmbeddingGemma-300M)
30
- # ---------------------------
31
  SEM_MODEL_NAME = "google/embeddinggemma-300m"
32
-
33
- sem_tokenizer = AutoTokenizer.from_pretrained(SEM_MODEL_NAME)
34
- sem_model = AutoModel.from_pretrained(SEM_MODEL_NAME)
35
- sem_model.eval()
36
-
37
- def embed_texts(texts):
38
- """Generate normalized sentence embeddings"""
39
- with torch.no_grad():
40
- inputs = sem_tokenizer(
41
- texts,
42
- padding=True,
43
- truncation=True,
44
- return_tensors="pt"
45
- )
46
- outputs = sem_model(**inputs)
47
- embeddings = outputs.last_hidden_state.mean(dim=1)
48
- embeddings = F.normalize(embeddings, p=2, dim=1)
49
- return embeddings
50
 
51
  # ---------------------------
52
  # Google Search Config
@@ -58,190 +39,113 @@ google_quota = {"count": 0, "date": datetime.date.today()}
58
  GOOGLE_DAILY_LIMIT = 100
59
 
60
  # ---------------------------
61
- # Safe Split Helpers
62
  # ---------------------------
63
  def safe_split_text(text):
64
- """
65
- Split text safely on '.' or ',' or ';'
66
- but do NOT split when between numbers (e.g., 1.41, 1,200).
67
- """
68
- pattern = r'(?<!\d)[.](?!\d)|(?<![\d\$]),(?!\d)|;'
69
- return [
70
- s.strip()
71
- for s in re.split(pattern, text)
72
- if len(s.strip().split()) > 4
73
- ]
74
 
75
  # ---------------------------
76
  # Claim Extraction
77
  # ---------------------------
78
- def extract_claims(page_text, max_claims=20):
79
- sentences = safe_split_text(page_text)
80
 
81
- def classify_sentence(s):
82
  out = claim_classifier(s, claim_labels)
83
- label_priority = ["factual claim", "opinion", "personal anecdote"]
84
- for lbl in label_priority:
85
- if lbl in out["labels"]:
86
- return {
87
- "text": s,
88
- "label": lbl,
89
- "score": round(
90
- out["scores"][out["labels"].index(lbl)], 3
91
- )
92
- }
93
- return None
94
-
95
- results = []
96
- with ThreadPoolExecutor() as executor:
97
- for r in executor.map(classify_sentence, sentences):
98
- if r:
99
- results.append(r)
100
-
101
- results = sorted(results, key=lambda x: -len(x["text"]))[:max_claims]
102
- return results
103
 
104
  # ---------------------------
105
- # AI Text Detection
106
  # ---------------------------
107
  def detect_ai(texts):
108
  if isinstance(texts, str):
109
  texts = [texts]
110
- results = []
111
  for t in texts:
112
- out = ai_detector(t)
113
- raw_label = out[0]["label"]
114
- label = "AI-generated" if raw_label.lower() in ["fake", "ai-generated"] else "Human"
115
- results.append({
116
- "text": t,
117
- "label": label,
118
- "score": round(out[0]["score"], 3)
119
- })
120
- return results
121
 
122
  # ---------------------------
123
- # Google Evidence Gathering
124
- # (Keyword + Semantic Ranking)
125
  # ---------------------------
126
- def fetch_google_search(claim, num_results=8):
127
  global google_quota
128
- today = datetime.date.today()
 
129
 
130
- if google_quota["date"] != today:
131
- google_quota = {"count": 0, "date": today}
 
 
 
132
 
133
- if google_quota["count"] >= GOOGLE_DAILY_LIMIT:
134
- return {
135
- "keyword_results": ["[Google] Daily quota reached."],
136
- "semantic_results": ["[Google] Daily quota reached."]
137
- }
138
-
139
- try:
140
- url = (
141
- "https://www.googleapis.com/customsearch/v1"
142
- f"?q={requests.utils.quote(claim)}"
143
- f"&key={GOOGLE_API_KEY}"
144
- f"&cx={GOOGLE_CX}"
145
- f"&num={num_results}"
146
- )
147
- r = requests.get(url).json()
148
- google_quota["count"] += 1
149
-
150
- items = r.get("items", [])
151
- snippets = [
152
- f"{item['title']}: {item['snippet']}"
153
- for item in items
154
- ]
155
-
156
- # Keyword results (original behavior)
157
- keyword_results = snippets[:3]
158
-
159
- # Semantic ranking
160
- if snippets:
161
- claim_emb = embed_texts([claim])
162
- snippet_embs = embed_texts(snippets)
163
- sims = torch.matmul(claim_emb, snippet_embs.T)[0]
164
- top_idx = torch.argsort(sims, descending=True)[:3]
165
- semantic_results = [snippets[i] for i in top_idx]
166
- else:
167
- semantic_results = []
168
-
169
- return {
170
- "keyword_results": keyword_results,
171
- "semantic_results": semantic_results
172
- }
173
-
174
- except Exception:
175
- return {
176
- "keyword_results": [],
177
- "semantic_results": []
178
- }
179
-
180
- # ---------------------------
181
- # Unified Predict Function
182
- # ---------------------------
183
- def predict(user_text=""):
184
- if not user_text.strip():
185
- return {"error": "No text provided."}
186
-
187
- # --- Full text analysis ---
188
- full_ai_result = detect_ai(user_text)
189
- dot_sentences = [
190
- s.strip() for s in user_text.split('.') if s.strip()
191
- ]
192
- full_fact_checking = {
193
- s: fetch_google_search(s) for s in dot_sentences
194
- }
195
 
196
- # --- Claim-based analysis ---
197
- claims_data = extract_claims(user_text)
198
- claims_texts = [c["text"] for c in claims_data]
199
- claims_ai_results = detect_ai(claims_texts) if claims_texts else []
200
- claims_fact_checking = {
201
- c["text"]: fetch_google_search(c["text"])
202
- for c in claims_data
203
- }
 
 
 
 
204
 
205
  return {
206
- "full_text": {
207
- "input": user_text,
208
- "ai_detection": full_ai_result,
209
- "fact_checking": full_fact_checking
210
- },
211
- "claims": claims_data,
212
- "claims_ai_detection": claims_ai_results,
213
- "claims_fact_checking": claims_fact_checking,
214
- "google_quota_used": google_quota["count"],
215
- "google_quota_reset": str(
216
- datetime.datetime.combine(
217
- google_quota["date"] + datetime.timedelta(days=1),
218
- datetime.time.min
219
- )
220
- )
221
  }
222
 
223
  # ---------------------------
224
- # Gradio UI
225
  # ---------------------------
226
- with gr.Blocks() as demo:
227
- gr.Markdown("## EduShield AI Backend – Keyword + Semantic Fact Checking")
 
228
 
229
- page_text_input = gr.Textbox(
230
- label="Input Text",
231
- lines=10,
232
- placeholder="Paste text here..."
233
- )
234
- predict_btn = gr.Button("Run Predict")
235
- output_json = gr.JSON(label="Predict Results")
236
 
237
- predict_btn.click(
238
- predict,
239
- inputs=[page_text_input],
240
- outputs=output_json
241
- )
 
 
 
 
 
 
 
 
 
242
 
243
  # ---------------------------
244
- # Launch
245
  # ---------------------------
 
 
 
 
 
 
 
246
  if __name__ == "__main__":
247
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ from sentence_transformers import SentenceTransformer, util
4
+ import requests, re, datetime
5
  from concurrent.futures import ThreadPoolExecutor
 
6
 
7
  # ---------------------------
8
  # Load Models
9
  # ---------------------------
10
 
11
+ # Claim Extraction → Zero-Shot Classifier
12
  claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
13
  claim_classifier = pipeline(
14
  "zero-shot-classification",
 
17
  )
18
  claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
19
 
20
+ # AI Text Detection
21
  ai_detect_model_name = "roberta-base-openai-detector"
22
  ai_detector = pipeline(
23
  "text-classification",
 
25
  device=-1
26
  )
27
 
28
+ # ✅ Semantic Model (CORRECT way for EmbeddingGemma)
 
 
29
  SEM_MODEL_NAME = "google/embeddinggemma-300m"
30
+ sem_model = SentenceTransformer(SEM_MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # ---------------------------
33
  # Google Search Config
 
39
  GOOGLE_DAILY_LIMIT = 100
40
 
41
  # ---------------------------
42
+ # Helpers
43
  # ---------------------------
44
def safe_split_text(text):
    """Split *text* into sentence-like fragments on periods.

    A period flanked by digits (e.g. "1.41") is not a split point, and
    fragments of 10 characters or fewer are discarded as noise.
    """
    sentence_break = re.compile(r'(?<!\d)[.](?!\d)')
    fragments = []
    for raw in sentence_break.split(text):
        cleaned = raw.strip()
        if len(cleaned) > 10:
            fragments.append(cleaned)
    return fragments
 
 
 
 
 
 
 
 
47
 
48
  # ---------------------------
49
  # Claim Extraction
50
  # ---------------------------
51
def extract_claims(text, max_claims=20):
    """Zero-shot classify each sentence of *text* into a claim category.

    Sentences come from safe_split_text and are classified concurrently.
    Returns at most *max_claims* dicts, each holding the sentence text,
    the classifier's top label, and the rounded confidence score.
    """
    sentences = safe_split_text(text)

    def classify(sentence):
        # Keep only the top-ranked label and its score.
        prediction = claim_classifier(sentence, claim_labels)
        return {
            "text": sentence,
            "label": prediction["labels"][0],
            "score": round(prediction["scores"][0], 3),
        }

    with ThreadPoolExecutor() as pool:
        classified = list(pool.map(classify, sentences))

    return classified[:max_claims]
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # ---------------------------
66
+ # AI Detection
67
  # ---------------------------
68
def detect_ai(texts):
    """Run the AI-text detector over one string or a list of strings.

    Returns a list of dicts with the original text, a normalized label
    ("AI-generated" or "Human"), and the rounded detector score.
    """
    # Accept a bare string for convenience; normalize to a list.
    if isinstance(texts, str):
        texts = [texts]

    results = []
    for snippet in texts:
        prediction = ai_detector(snippet)[0]
        is_ai = prediction["label"].lower() in ["fake", "ai-generated"]
        results.append({
            "text": snippet,
            "label": "AI-generated" if is_ai else "Human",
            "score": round(prediction["score"], 3),
        })
    return results
 
 
 
 
 
77
 
78
  # ---------------------------
79
+ # Google + Semantic Fact Check
 
80
  # ---------------------------
81
def fetch_google_search_semantic(claim, k=3):
    """Fact-check *claim* via the Google Custom Search JSON API.

    Returns a dict with:
      - "keyword":  the first *k* raw result snippets ("title: snippet"),
      - "semantic": the *k* snippets ranked by cosine similarity between
        normalized sentence embeddings of the claim and each snippet.
    Empty lists are returned when the daily quota is exhausted or the
    request fails — predict() calls this once per sentence, so failures
    must degrade gracefully rather than raise.
    """
    global google_quota

    # Reset the counter when the date rolls over; without this the quota
    # would stay exhausted forever after the first full day of use.
    today = datetime.date.today()
    if google_quota["date"] != today:
        google_quota = {"count": 0, "date": today}

    if google_quota["count"] >= GOOGLE_DAILY_LIMIT:
        return {"keyword": [], "semantic": []}

    url = (
        "https://www.googleapis.com/customsearch/v1"
        f"?q={requests.utils.quote(claim)}"
        f"&key={GOOGLE_API_KEY}&cx={GOOGLE_CX}&num=10"
    )

    try:
        # Timeout so a stalled API call cannot hang the whole request.
        r = requests.get(url, timeout=10).json()
    except Exception:
        # Network / JSON / API failure: best-effort empty result, matching
        # the quota-exhausted shape so callers need no special casing.
        return {"keyword": [], "semantic": []}
    google_quota["count"] += 1
    items = r.get("items", [])

    snippets = [f"{i['title']}: {i['snippet']}" for i in items]
    keyword_results = snippets[:k]

    if not snippets:
        return {"keyword": keyword_results, "semantic": []}

    # Rank snippets by cosine similarity between normalized embeddings.
    q_emb = sem_model.encode(claim, normalize_embeddings=True)
    s_emb = sem_model.encode(snippets, normalize_embeddings=True)
    sims = util.cos_sim(q_emb, s_emb)[0]

    top_idx = sims.argsort(descending=True)[:k]
    semantic_results = [snippets[i] for i in top_idx]

    return {
        "keyword": keyword_results,
        "semantic": semantic_results
    }
113
 
114
  # ---------------------------
115
+ # Predict
116
  # ---------------------------
117
def predict(text=""):
    """Unified analysis pipeline for the Gradio UI.

    Runs AI-text detection and per-sentence fact checking over the whole
    input, then repeats both over the extracted claims. Returns a
    JSON-serializable dict; an {"error": ...} dict on empty input.
    """
    if not text.strip():
        return {"error": "No input"}

    # Whole-text pass: one detector call plus a fact check per sentence.
    full_ai = detect_ai(text)
    full_fc = {
        sentence: fetch_google_search_semantic(sentence)
        for sentence in safe_split_text(text)
    }

    # Claim-level pass: classify sentences, then analyze each claim text.
    claims = extract_claims(text)
    claim_texts = [claim["text"] for claim in claims]
    claim_ai = detect_ai(claim_texts)
    claim_fc = {
        claim_text: fetch_google_search_semantic(claim_text)
        for claim_text in claim_texts
    }

    return {
        "full_text": {
            "input": text,
            "ai_detection": full_ai,
            "fact_checking": full_fc,
        },
        "claims": claims,
        "claims_ai_detection": claim_ai,
        "claims_fact_checking": claim_fc,
    }
139
 
140
  # ---------------------------
141
+ # UI
142
  # ---------------------------
143
with gr.Blocks() as demo:
    # Minimal Gradio front end: one textbox in, raw JSON out.
    gr.Markdown("## EduShield AI Backend – Keyword + Semantic Fact Check")
    inp = gr.Textbox(lines=8, label="Input Text")
    btn = gr.Button("Run Analysis")
    out = gr.JSON()
    # Wire the button to the unified predict() pipeline defined above.
    btn.click(predict, inp, out)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the
    # container (standard for a hosted Space deployment).
    demo.launch(server_name="0.0.0.0")