Vi0509 commited on
Commit
76bf8b6
·
verified ·
1 Parent(s): a79cdd9

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +293 -0
app.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Kaeva Fact-Check API — HuggingFace Space
4
+ Two-stage pipeline:
5
+ Stage 1: DeBERTa-v3-base binary classifier (local, fast, free)
6
+ Stage 2: Gemini 2.0 Flash + Google Search grounding (cited evidence)
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import time
12
+ import logging
13
+ import urllib.request
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import gradio as gr
18
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
19
+
20
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("factcheck")

# ============================================================
# CONFIG
# ============================================================
# Fine-tuned DeBERTa binary classifier hosted on the HF Hub.
MODEL_ID = "Vi0509/kaeva-factcheck-deberta"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stage-1 confidence below this threshold escalates the claim to Gemini.
CONFIDENCE_THRESHOLD = 0.65

# GCP Auth — service account JSON stored as HF Space secret
GCP_SA_JSON = os.environ.get("GCP_SERVICE_ACCOUNT_JSON", "")
GCP_PROJECT = "eastern-flight-477705-n0"
# In-process OAuth token cache: {"token": str | None, "expiry": epoch seconds}.
_cached_token = {"token": None, "expiry": 0}
34
+
35
+
36
def get_gcp_token():
    """Return a cached OAuth2 access token for the GCP service account.

    Builds credentials from the GCP_SERVICE_ACCOUNT_JSON secret, refreshes
    them via google-auth, and caches the token until ~60 s before expiry.

    Returns:
        The bearer token string, or None when the secret is missing or
        authentication fails (failures are logged, never raised).
    """
    # Reuse the cached token while it is still valid (60 s safety margin).
    if _cached_token["token"] and time.time() < _cached_token["expiry"] - 60:
        return _cached_token["token"]

    if not GCP_SA_JSON:
        return None

    try:
        from google.oauth2 import service_account
        from google.auth.transport.requests import Request

        # Build credentials directly from the in-memory JSON. The original
        # wrote the secret to a NamedTemporaryFile and unlinked it only on
        # the success path, leaking the key material to disk whenever
        # refresh() raised.
        creds = service_account.Credentials.from_service_account_info(
            json.loads(GCP_SA_JSON),
            scopes=[
                "https://www.googleapis.com/auth/cloud-platform",
                "https://www.googleapis.com/auth/generative-language",
            ],
        )
        creds.refresh(Request())

        _cached_token["token"] = creds.token
        # Fall back to ~58 minutes when the library reports no expiry.
        _cached_token["expiry"] = (
            creds.expiry.timestamp() if creds.expiry else time.time() + 3500
        )
        return creds.token
    except Exception as e:
        log.error(f"GCP auth error: {e}")
        return None
67
+
68
# ============================================================
# STAGE 1: DeBERTa Classifier
# ============================================================
# Load the fine-tuned classifier once at import time so every request
# reuses the same in-memory weights.
log.info(f"Loading DeBERTa model on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()  # inference mode: disables dropout
log.info("DeBERTa loaded.")
76
+
77
+
78
def classify_claim(text: str) -> dict:
    """Stage 1: score a claim with the local DeBERTa binary classifier.

    Returns a dict with the winning label ("REAL"/"FAKE"), its confidence,
    and the raw softmax scores for both classes.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        probabilities = torch.softmax(model(**encoded).logits, dim=-1)[0]

    # Index 0 = REAL, index 1 = FAKE (matches the fine-tuned label order).
    real_score = probabilities[0].item()
    fake_score = probabilities[1].item()

    return {
        "label": "REAL" if real_score > fake_score else "FAKE",
        "confidence": max(real_score, fake_score),
        "real_score": real_score,
        "fake_score": fake_score,
    }
94
+
95
+
96
# ============================================================
# STAGE 2: Gemini + Google Search Grounding
# ============================================================
# Prompt template for gemini_verify(). Doubled braces escape the literal
# JSON skeleton for str.format; only {claim} is substituted.
GEMINI_PROMPT = """You are a fact-checker. Analyze the following claim using the search results provided.

CLAIM: "{claim}"

Instructions:
1. Determine if the claim is TRUE, FALSE, PARTIALLY TRUE, or UNVERIFIABLE
2. Cite specific sources that support or refute the claim
3. Provide a brief explanation (2-3 sentences)
4. Rate your confidence (0.0 to 1.0)

Respond in this exact JSON format:
{{
"verdict": "TRUE|FALSE|PARTIALLY TRUE|UNVERIFIABLE",
"confidence": 0.0-1.0,
"explanation": "Brief explanation with evidence",
"key_finding": "One-sentence summary"
}}"""
116
+
117
+
118
def gemini_verify(claim: str) -> dict:
    """Stage 2: verify a claim with Gemini 2.0 Flash + Google Search grounding.

    Calls the Generative Language API with the googleSearch tool enabled,
    parses the model's JSON verdict, and attaches grounding sources and the
    web search queries used.

    Returns:
        The parsed verdict dict, augmented with "sources" (max 10) and
        "search_queries"; on auth/API failure, a dict with "error" and
        verdict "UNVERIFIABLE".
    """
    token = get_gcp_token()
    if not token:
        return {"error": "GCP credentials not configured", "verdict": "UNVERIFIABLE"}

    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

    payload = {
        "contents": [{"parts": [{"text": GEMINI_PROMPT.format(claim=claim)}]}],
        "tools": [{"googleSearch": {}}],
        "generationConfig": {"temperature": 0.1, "maxOutputTokens": 1024}
    }

    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json",
                 "x-goog-user-project": GCP_PROJECT},
        method="POST")

    try:
        # Close the HTTP response promptly — the original left the socket
        # open until garbage collection.
        with urllib.request.urlopen(req, timeout=30) as resp:
            data = json.loads(resp.read())

        candidate = data["candidates"][0]
        text = candidate["content"]["parts"][0]["text"]

        # The model is asked for strict JSON but often wraps it in a
        # ```json fenced block — strip the fence before parsing.
        try:
            clean = text.strip()
            if clean.startswith("```"):
                clean = clean.split("\n", 1)[1].rsplit("```", 1)[0]
            result = json.loads(clean)
        except json.JSONDecodeError:
            # Fall back to the raw text so the caller still sees the answer.
            result = {"verdict": "UNVERIFIABLE", "explanation": text, "confidence": 0.5}

        # Extract the cited web sources from the grounding metadata.
        grounding = candidate.get("groundingMetadata", {})
        sources = []
        for chunk in grounding.get("groundingChunks", []):
            web = chunk.get("web", {})
            if web.get("uri"):
                sources.append({"title": web.get("title", ""), "url": web["uri"]})

        result["sources"] = sources[:10]
        # BUG FIX: webSearchQueries is a list of plain strings in the API
        # response; the original called .get() on each entry, raising
        # AttributeError and (via the broad except below) discarding the
        # entire grounded result. Accept both string and dict entries.
        result["search_queries"] = [
            q.get("searchQuery", "") if isinstance(q, dict) else str(q)
            for q in grounding.get("webSearchQueries", [])
        ]

        return result

    except urllib.error.HTTPError as e:
        error_body = e.read().decode()[:500]
        log.error(f"Gemini API error {e.code}: {error_body}")
        return {"error": f"Gemini API error {e.code}", "verdict": "UNVERIFIABLE"}
    except Exception as e:
        log.error(f"Gemini error: {e}")
        return {"error": str(e), "verdict": "UNVERIFIABLE"}
178
+
179
+
180
# ============================================================
# COMBINED PIPELINE
# ============================================================
def fact_check(claim: str, force_search: bool = False) -> dict:
    """Full two-stage fact-check pipeline.

    Args:
        claim: Statement to verify; must have at least 10 non-space chars.
        force_search: When True, always run Gemini regardless of the
            classifier's verdict.

    Returns:
        Dict with the claim, per-stage results, "final_verdict",
        "final_confidence" and "processing_time_ms"; or {"error": ...}
        for a too-short claim.
    """
    if not claim or len(claim.strip()) < 10:
        return {"error": "Claim too short. Provide a meaningful statement to verify."}

    start = time.time()

    # Stage 1: local DeBERTa classifier (fast path).
    stage1 = classify_claim(claim)
    result = {
        "claim": claim,
        "stage1_classifier": stage1,
        "pipeline": "classifier_only",
        "processing_time_ms": 0,
    }

    # Default verdict comes from stage 1; stage 2 may override it below.
    # (The original duplicated this assignment in two identical else arms.)
    result["final_verdict"] = stage1["label"]
    result["final_confidence"] = stage1["confidence"]

    # Stage 2: escalate when forced, flagged FAKE, or low confidence.
    needs_verification = (
        force_search
        or stage1["label"] == "FAKE"
        or stage1["confidence"] < CONFIDENCE_THRESHOLD
    )

    if needs_verification and GCP_SA_JSON:
        stage2 = gemini_verify(claim)
        result["stage2_gemini"] = stage2
        result["pipeline"] = "classifier + gemini_search"

        # Trust Gemini only when it reached a definite verdict.
        if stage2.get("verdict") and stage2["verdict"] != "UNVERIFIABLE":
            result["final_verdict"] = stage2["verdict"]
            result["final_confidence"] = stage2.get("confidence", stage1["confidence"])

    result["processing_time_ms"] = round((time.time() - start) * 1000)
    return result
224
+
225
+
226
# ============================================================
# GRADIO UI
# ============================================================
def gradio_check(claim: str, force_gemini: bool) -> str:
    """Run the pipeline and pretty-print the result as JSON for the UI."""
    outcome = fact_check(claim, force_search=force_gemini)
    return json.dumps(outcome, indent=2)
232
+
233
+
234
# Declarative Gradio UI: claim input + options on the left, the raw
# pipeline result dict rendered as JSON on the right.
with gr.Blocks(title="Kaeva Fact-Check", theme=gr.themes.Base()) as demo:
    gr.Markdown("""
# 🔍 Kaeva Fact-Check
**Two-stage AI fact-checking pipeline**
- **Stage 1:** DeBERTa classifier — instant binary detection (real vs fake)
- **Stage 2:** Gemini 2.0 Flash + Google Search — live evidence with cited sources
""")

    with gr.Row():
        with gr.Column(scale=3):
            claim_input = gr.Textbox(
                label="Enter a claim to verify",
                placeholder="e.g., The Great Wall of China is visible from space.",
                lines=3
            )
            # Override: skip the classifier gate and always call Gemini.
            force_search = gr.Checkbox(label="Force Google Search verification (bypass classifier)", value=False)
            check_btn = gr.Button("🔍 Fact-Check", variant="primary", size="lg")

        with gr.Column(scale=4):
            output = gr.JSON(label="Result")

    # Canned claims; the boolean pre-fills the force-search checkbox.
    gr.Examples(
        examples=[
            ["The Earth is flat.", False],
            ["Water boils at 100 degrees Celsius at sea level.", False],
            ["COVID-19 vaccines contain microchips.", True],
            ["The speed of light is approximately 300,000 km/s.", False],
            ["Drinking bleach cures diseases.", True],
        ],
        inputs=[claim_input, force_search],
    )

    # NOTE(review): the button is wired to fact_check (dict output suits
    # gr.JSON), so gradio_check above is unused by this UI.
    check_btn.click(fn=fact_check, inputs=[claim_input, force_search], outputs=output)
267
+
268
# ============================================================
# API ENDPOINT
# ============================================================
# One FastAPI app carries both the REST routes and the mounted Gradio UI.
# BUG FIX: the original passed gr.routes.App() (gradio's internal routes
# class, not a FastAPI instance) to mount_gradio_app, and registered the
# /api/* routes on a second FastAPI app that was never served — because
# demo.launch() started its own server, every REST endpoint was dead code.
from fastapi import FastAPI

app = FastAPI()


@app.post("/api/check")
async def api_check(request: dict):
    """Fact-check one claim. Body: {"claim": str, "force_search": bool}."""
    claim = request.get("claim", "")
    force = request.get("force_search", False)
    return fact_check(claim, force_search=force)


@app.post("/api/batch")
async def api_batch(request: dict):
    """Fact-check several claims. Body: {"claims": [str, ...]}."""
    claims = request.get("claims", [])
    results = [fact_check(c) for c in claims[:20]]  # Max 20 per batch
    return {"results": results}


@app.get("/api/health")
async def health():
    """Liveness probe exposing the loaded model id and device."""
    return {"status": "ok", "model": MODEL_ID, "device": str(DEVICE)}


# Mount the Gradio UI at the root of the same app so UI and API share one
# server on the HF-Spaces port.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn  # shipped with the gradio/fastapi stack

    uvicorn.run(app, host="0.0.0.0", port=7860)