Mohammedmarzuk17 commited on
Commit
110d1f2
·
verified ·
1 Parent(s): 5f2e5ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -82
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
- from sentence_transformers import SentenceTransformer, util
4
- import requests, re, datetime
5
  from concurrent.futures import ThreadPoolExecutor
 
6
 
7
  # ---------------------------
8
  # Load Models
9
  # ---------------------------
10
 
11
- # Claim Extraction → Zero-Shot Classifier
12
  claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
13
  claim_classifier = pipeline(
14
  "zero-shot-classification",
@@ -17,7 +17,7 @@ claim_classifier = pipeline(
17
  )
18
  claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
19
 
20
- # AI Text Detection
21
  ai_detect_model_name = "roberta-base-openai-detector"
22
  ai_detector = pipeline(
23
  "text-classification",
@@ -25,9 +25,28 @@ ai_detector = pipeline(
25
  device=-1
26
  )
27
 
28
- # ✅ Semantic Model (CORRECT way for EmbeddingGemma)
 
 
29
  SEM_MODEL_NAME = "google/embeddinggemma-300m"
30
- sem_model = SentenceTransformer(SEM_MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # ---------------------------
33
  # Google Search Config
@@ -39,113 +58,190 @@ google_quota = {"count": 0, "date": datetime.date.today()}
39
  GOOGLE_DAILY_LIMIT = 100
40
 
41
  # ---------------------------
42
- # Helpers
43
  # ---------------------------
44
  def safe_split_text(text):
45
- pattern = r'(?<!\d)[.](?!\d)'
46
- return [s.strip() for s in re.split(pattern, text) if len(s.strip()) > 10]
 
 
 
 
 
 
 
 
47
 
48
  # ---------------------------
49
  # Claim Extraction
50
  # ---------------------------
51
- def extract_claims(text, max_claims=20):
52
- sentences = safe_split_text(text)
53
 
54
- def classify(s):
55
  out = claim_classifier(s, claim_labels)
56
- lbl = out["labels"][0]
57
- score = round(out["scores"][0], 3)
58
- return {"text": s, "label": lbl, "score": score}
59
-
60
- with ThreadPoolExecutor() as ex:
61
- results = list(ex.map(classify, sentences))
62
-
63
- return results[:max_claims]
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # ---------------------------
66
- # AI Detection
67
  # ---------------------------
68
  def detect_ai(texts):
69
  if isinstance(texts, str):
70
  texts = [texts]
71
- out = []
72
  for t in texts:
73
- r = ai_detector(t)[0]
74
- label = "AI-generated" if r["label"].lower() in ["fake", "ai-generated"] else "Human"
75
- out.append({"text": t, "label": label, "score": round(r["score"], 3)})
76
- return out
 
 
 
 
 
77
 
78
  # ---------------------------
79
- # Google + Semantic Fact Check
 
80
  # ---------------------------
81
- def fetch_google_search_semantic(claim, k=3):
82
  global google_quota
83
- if google_quota["count"] >= GOOGLE_DAILY_LIMIT:
84
- return {"keyword": [], "semantic": []}
85
-
86
- url = (
87
- "https://www.googleapis.com/customsearch/v1"
88
- f"?q={requests.utils.quote(claim)}"
89
- f"&key={GOOGLE_API_KEY}&cx={GOOGLE_CX}&num=10"
90
- )
91
-
92
- r = requests.get(url).json()
93
- google_quota["count"] += 1
94
- items = r.get("items", [])
95
-
96
- snippets = [f"{i['title']}: {i['snippet']}" for i in items]
97
- keyword_results = snippets[:k]
98
-
99
- if not snippets:
100
- return {"keyword": keyword_results, "semantic": []}
101
-
102
- q_emb = sem_model.encode(claim, normalize_embeddings=True)
103
- s_emb = sem_model.encode(snippets, normalize_embeddings=True)
104
- sims = util.cos_sim(q_emb, s_emb)[0]
105
 
106
- top_idx = sims.argsort(descending=True)[:k]
107
- semantic_results = [snippets[i] for i in top_idx]
108
 
109
- return {
110
- "keyword": keyword_results,
111
- "semantic": semantic_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  }
113
 
114
- # ---------------------------
115
- # Predict
116
- # ---------------------------
117
- def predict(text=""):
118
- if not text.strip():
119
- return {"error": "No input"}
120
-
121
- full_ai = detect_ai(text)
122
- sentences = safe_split_text(text)
123
- full_fc = {s: fetch_google_search_semantic(s) for s in sentences}
124
-
125
- claims = extract_claims(text)
126
- claim_ai = detect_ai([c["text"] for c in claims])
127
- claim_fc = {c["text"]: fetch_google_search_semantic(c["text"]) for c in claims}
128
 
129
  return {
130
  "full_text": {
131
- "input": text,
132
- "ai_detection": full_ai,
133
- "fact_checking": full_fc
134
  },
135
- "claims": claims,
136
- "claims_ai_detection": claim_ai,
137
- "claims_fact_checking": claim_fc
 
 
 
 
 
 
 
138
  }
139
 
140
  # ---------------------------
141
- # UI
142
  # ---------------------------
143
  with gr.Blocks() as demo:
144
- gr.Markdown("## EduShield AI Backend – Keyword + Semantic Fact Check")
145
- inp = gr.Textbox(lines=8, label="Input Text")
146
- btn = gr.Button("Run Analysis")
147
- out = gr.JSON()
148
- btn.click(predict, inp, out)
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  if __name__ == "__main__":
151
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoTokenizer, AutoModel
3
+ import requests, re, datetime, torch
 
4
  from concurrent.futures import ThreadPoolExecutor
5
+ import torch.nn.functional as F
6
 
7
  # ---------------------------
8
  # Load Models
9
  # ---------------------------
10
 
11
+ # Claim Extraction → Zero-Shot Classifier (DeBERTa MNLI)
12
  claim_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
13
  claim_classifier = pipeline(
14
  "zero-shot-classification",
 
17
  )
18
  claim_labels = ["factual claim", "opinion", "personal anecdote", "other"]
19
 
20
+ # AI Text Detection → OpenAI Detector (Roberta-based)
21
  ai_detect_model_name = "roberta-base-openai-detector"
22
  ai_detector = pipeline(
23
  "text-classification",
 
25
  device=-1
26
  )
27
 
28
+ # ---------------------------
29
+ # ✅ Semantic Model (EmbeddingGemma-300M)
30
+ # ---------------------------
31
  SEM_MODEL_NAME = "google/embeddinggemma-300m"
32
+
33
+ sem_tokenizer = AutoTokenizer.from_pretrained(SEM_MODEL_NAME)
34
+ sem_model = AutoModel.from_pretrained(SEM_MODEL_NAME)
35
+ sem_model.eval()
36
+
37
def embed_texts(texts):
    """Return L2-normalized sentence embeddings for a list of strings.

    Tokenizes with the module-level ``sem_tokenizer``, runs the
    module-level ``sem_model`` (EmbeddingGemma), and mean-pools the
    token hidden states into one vector per input text.

    Args:
        texts: list[str] of sentences/snippets to embed.

    Returns:
        torch.Tensor of shape (len(texts), hidden_dim) with unit-norm
        rows, so dot products between rows are cosine similarities.
    """
    with torch.no_grad():
        inputs = sem_tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        outputs = sem_model(**inputs)
        # Fix: mask-aware mean pooling. A plain .mean(dim=1) averages
        # padding positions too, which skews embeddings for shorter
        # texts in a padded batch; weight by the attention mask so only
        # real tokens contribute.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid divide-by-zero
        embeddings = F.normalize(summed / counts, p=2, dim=1)
    return embeddings
50
 
51
  # ---------------------------
52
  # Google Search Config
 
58
  GOOGLE_DAILY_LIMIT = 100
59
 
60
  # ---------------------------
61
+ # Safe Split Helpers
62
  # ---------------------------
63
def safe_split_text(text):
    """Break *text* into candidate sentences.

    Splits on '.', ',' or ';' delimiters, except where the delimiter
    sits inside a number (e.g. 1.41 or 1,200 stay intact), and keeps
    only fragments longer than four words.
    """
    # '.' not between digits | ',' not preceded by digit/$ and not
    # followed by a digit | any ';'
    delimiters = r'(?<!\d)[.](?!\d)|(?<![\d\$]),(?!\d)|;'
    fragments = (piece.strip() for piece in re.split(delimiters, text))
    return [frag for frag in fragments if len(frag.split()) > 4]
74
 
75
  # ---------------------------
76
  # Claim Extraction
77
  # ---------------------------
78
def extract_claims(page_text, max_claims=20):
    """Extract and classify claim-like sentences from *page_text*.

    Each sentence from ``safe_split_text`` is zero-shot classified
    against ``claim_labels``; sentences whose best label is "other"
    are dropped. Results are sorted longest-first (longer sentences
    tend to carry more checkable content) and truncated.

    Args:
        page_text: raw text to analyze.
        max_claims: maximum number of claims to return (default 20).

    Returns:
        list of {"text", "label", "score"} dicts.
    """
    sentences = safe_split_text(page_text)

    def classify_sentence(s):
        out = claim_classifier(s, claim_labels)
        # Fix: the zero-shot pipeline returns ALL candidate labels
        # ranked by score, so the old membership check
        # (`if lbl in out["labels"]`) matched "factual claim" for every
        # sentence regardless of its actual score. Use the top-ranked
        # label instead, and discard sentences classified as "other".
        top_label = out["labels"][0]
        if top_label == "other":
            return None
        return {
            "text": s,
            "label": top_label,
            "score": round(out["scores"][0], 3),
        }

    # Classification is I/O/compute-bound per sentence; fan out.
    results = []
    with ThreadPoolExecutor() as executor:
        for r in executor.map(classify_sentence, sentences):
            if r:
                results.append(r)

    return sorted(results, key=lambda c: -len(c["text"]))[:max_claims]
103
 
104
  # ---------------------------
105
+ # AI Text Detection
106
  # ---------------------------
107
  def detect_ai(texts):
108
  if isinstance(texts, str):
109
  texts = [texts]
110
+ results = []
111
  for t in texts:
112
+ out = ai_detector(t)
113
+ raw_label = out[0]["label"]
114
+ label = "AI-generated" if raw_label.lower() in ["fake", "ai-generated"] else "Human"
115
+ results.append({
116
+ "text": t,
117
+ "label": label,
118
+ "score": round(out[0]["score"], 3)
119
+ })
120
+ return results
121
 
122
  # ---------------------------
123
+ # Google Evidence Gathering
124
+ # (Keyword + Semantic Ranking)
125
  # ---------------------------
126
def fetch_google_search(claim, num_results=8):
    """Fetch Google Custom Search evidence for *claim*.

    Returns the API's top keyword-ordered snippets plus the same
    snippets re-ranked by semantic similarity to the claim, while
    enforcing a daily request quota that resets each calendar day.

    Args:
        claim: the sentence/claim to search for.
        num_results: how many results to request from the API (default 8).

    Returns:
        {"keyword_results": [...], "semantic_results": [...]}; both
        lists carry a quota message when the daily limit is reached,
        and are empty on any request/parse failure (best-effort).
    """
    global google_quota
    today = datetime.date.today()

    # Reset the counter on the first call of each new day.
    if google_quota["date"] != today:
        google_quota = {"count": 0, "date": today}

    if google_quota["count"] >= GOOGLE_DAILY_LIMIT:
        return {
            "keyword_results": ["[Google] Daily quota reached."],
            "semantic_results": ["[Google] Daily quota reached."]
        }

    try:
        # Fix: let requests build and URL-encode the query string via
        # params=, and add a timeout so a stalled API call cannot hang
        # the whole prediction request indefinitely.
        r = requests.get(
            "https://www.googleapis.com/customsearch/v1",
            params={
                "q": claim,
                "key": GOOGLE_API_KEY,
                "cx": GOOGLE_CX,
                "num": num_results,
            },
            timeout=10,
        ).json()
        google_quota["count"] += 1

        items = r.get("items", [])
        snippets = [
            f"{item['title']}: {item['snippet']}"
            for item in items
        ]

        # Keyword results: the API's own ranking, top 3 (original behavior).
        keyword_results = snippets[:3]

        # Semantic results: re-rank all snippets by similarity to the
        # claim. Embeddings are unit-norm, so the dot product below is
        # cosine similarity.
        if snippets:
            claim_emb = embed_texts([claim])
            snippet_embs = embed_texts(snippets)
            sims = torch.matmul(claim_emb, snippet_embs.T)[0]
            top_idx = torch.argsort(sims, descending=True)[:3]
            semantic_results = [snippets[i] for i in top_idx]
        else:
            semantic_results = []

        return {
            "keyword_results": keyword_results,
            "semantic_results": semantic_results
        }

    except Exception:
        # Deliberate best-effort: network/parse failures degrade to
        # "no evidence" rather than failing the whole prediction.
        return {
            "keyword_results": [],
            "semantic_results": []
        }
179
+
180
+ # ---------------------------
181
+ # Unified Predict Function
182
+ # ---------------------------
183
def predict(user_text=""):
    """Run the full analysis pipeline on *user_text*.

    Performs AI-text detection and Google fact-checking on the whole
    text (naively split on '.'), repeats both on the extracted claims,
    and reports the Google quota state in the response.
    """
    if not user_text.strip():
        return {"error": "No text provided."}

    # --- Whole-text pass ---
    whole_text_ai = detect_ai(user_text)
    sentences = [part.strip() for part in user_text.split('.') if part.strip()]
    whole_text_evidence = {}
    for sentence in sentences:
        whole_text_evidence[sentence] = fetch_google_search(sentence)

    # --- Per-claim pass ---
    claims = extract_claims(user_text)
    claim_texts = [claim["text"] for claim in claims]
    claim_ai = detect_ai(claim_texts) if claim_texts else []
    claim_evidence = {
        claim["text"]: fetch_google_search(claim["text"])
        for claim in claims
    }

    # The quota resets at midnight after the current quota date.
    reset_at = datetime.datetime.combine(
        google_quota["date"] + datetime.timedelta(days=1),
        datetime.time.min,
    )

    return {
        "full_text": {
            "input": user_text,
            "ai_detection": whole_text_ai,
            "fact_checking": whole_text_evidence
        },
        "claims": claims,
        "claims_ai_detection": claim_ai,
        "claims_fact_checking": claim_evidence,
        "google_quota_used": google_quota["count"],
        "google_quota_reset": str(reset_at)
    }
222
 
223
  # ---------------------------
224
+ # Gradio UI
225
  # ---------------------------
226
  with gr.Blocks() as demo:
227
+ gr.Markdown("## EduShield AI Backend – Keyword + Semantic Fact Checking")
 
 
 
 
228
 
229
+ page_text_input = gr.Textbox(
230
+ label="Input Text",
231
+ lines=10,
232
+ placeholder="Paste text here..."
233
+ )
234
+ predict_btn = gr.Button("Run Predict")
235
+ output_json = gr.JSON(label="Predict Results")
236
+
237
+ predict_btn.click(
238
+ predict,
239
+ inputs=[page_text_input],
240
+ outputs=output_json
241
+ )
242
+
243
+ # ---------------------------
244
+ # Launch
245
+ # ---------------------------
246
# Bind to all interfaces so the app is reachable from outside the
# container (standard for Hugging Face Spaces / Docker deployments).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")