Karthik Barma commited on
Commit
4ec030d
·
1 Parent(s): 678eb81

Add session persistence, sharing, web search grounding

Browse files
Files changed (2) hide show
  1. agents/research.py +102 -12
  2. api/routes.py +139 -9
agents/research.py CHANGED
@@ -2,12 +2,19 @@
2
 
3
  Constitutional role: Strictly neutral. Authorship is hidden from adversarial agents.
4
  The research package is the ONLY shared context between Prosecutor and Defense.
 
 
 
 
5
  """
6
 
 
7
  import json
8
  import logging
 
9
  from typing import Callable, Optional
10
 
 
11
  from langchain_groq import ChatGroq
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
 
@@ -16,6 +23,71 @@ from models.schemas import StreamEvent
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  RESEARCH_SYSTEM_PROMPT = """You are a neutral research analyst producing anonymous briefing material for an adversarial review process.
20
 
21
  CONSTITUTIONAL DIRECTIVE: Remain strictly neutral. Do not advocate for or against any outcome.
@@ -98,24 +170,42 @@ class ResearchAgent:
98
  ]
99
 
100
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  thinking_phases = [
102
  "Scanning market landscape and competitive environment...",
103
- "Gathering relevant data points and statistics...",
104
- "Analyzing historical precedents and case studies...",
105
- "Identifying key stakeholders and risk factors...",
106
- "Compiling research synthesis...",
107
  ]
108
 
109
  for phase in thinking_phases:
110
  if stream_callback:
111
- await stream_callback(
112
- StreamEvent(
113
- event_type="research_start",
114
- agent="research",
115
- content=phase + "\n",
116
- )
117
- )
118
- import asyncio
119
  await asyncio.sleep(0.3)
120
 
121
  response = await self.llm.ainvoke(messages)
 
2
 
3
  Constitutional role: Strictly neutral. Authorship is hidden from adversarial agents.
4
  The research package is the ONLY shared context between Prosecutor and Defense.
5
+
6
+ Grounding: When available, the agent performs lightweight web retrieval via
7
+ DuckDuckGo Instant Answers and (optionally) Tavily Search API to ground
8
+ claims in current factual data rather than relying solely on LLM training data.
9
  """
10
 
11
+ import asyncio
12
  import json
13
  import logging
14
+ import os
15
  from typing import Callable, Optional
16
 
17
+ import httpx
18
  from langchain_groq import ChatGroq
19
  from langchain_core.messages import SystemMessage, HumanMessage
20
 
 
23
 
24
  logger = logging.getLogger(__name__)
25
 
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Lightweight web retrieval for factual grounding
29
+ # ---------------------------------------------------------------------------
30
+
31
+ async def _web_search_grounding(query: str, max_results: int = 3) -> list[str]:
32
+ """Retrieve web search snippets to ground LLM research in current facts.
33
+
34
+ Strategy:
35
+ 1. Try Tavily Search API if TAVILY_API_KEY is set (best quality).
36
+ 2. Fall back to DuckDuckGo Instant Answers API (no key required).
37
+ 3. Return empty list on failure — research proceeds with LLM-only.
38
+ """
39
+ snippets = []
40
+
41
+ # Try Tavily first (higher quality, structured results)
42
+ tavily_key = os.getenv("TAVILY_API_KEY")
43
+ if tavily_key:
44
+ try:
45
+ async with httpx.AsyncClient(timeout=8.0) as client:
46
+ resp = await client.post(
47
+ "https://api.tavily.com/search",
48
+ json={
49
+ "api_key": tavily_key,
50
+ "query": query,
51
+ "max_results": max_results,
52
+ "search_depth": "basic",
53
+ },
54
+ )
55
+ if resp.status_code == 200:
56
+ data = resp.json()
57
+ for r in data.get("results", [])[:max_results]:
58
+ snippet = r.get("content", "")[:300]
59
+ source = r.get("url", "")
60
+ if snippet:
61
+ snippets.append(f"{snippet} (source: {source})")
62
+ if snippets:
63
+ logger.info("Tavily grounding: %d snippets for '%s'", len(snippets), query[:50])
64
+ return snippets
65
+ except Exception as e:
66
+ logger.warning("Tavily search failed, falling back to DuckDuckGo: %s", e)
67
+
68
+ # Fallback: DuckDuckGo Instant Answers (no API key needed)
69
+ try:
70
+ async with httpx.AsyncClient(timeout=5.0) as client:
71
+ resp = await client.get(
72
+ "https://api.duckduckgo.com/",
73
+ params={"q": query, "format": "json", "no_html": "1", "skip_disambig": "1"},
74
+ )
75
+ if resp.status_code == 200:
76
+ data = resp.json()
77
+ # Abstract text (Wikipedia-sourced summary)
78
+ if data.get("AbstractText"):
79
+ snippets.append(data["AbstractText"][:400])
80
+ # Related topics
81
+ for topic in data.get("RelatedTopics", [])[:max_results]:
82
+ if isinstance(topic, dict) and topic.get("Text"):
83
+ snippets.append(topic["Text"][:200])
84
+ if snippets:
85
+ logger.info("DuckDuckGo grounding: %d snippets for '%s'", len(snippets), query[:50])
86
+ except Exception as e:
87
+ logger.warning("DuckDuckGo search failed: %s", e)
88
+
89
+ return snippets
90
+
91
  RESEARCH_SYSTEM_PROMPT = """You are a neutral research analyst producing anonymous briefing material for an adversarial review process.
92
 
93
  CONSTITUTIONAL DIRECTIVE: Remain strictly neutral. Do not advocate for or against any outcome.
 
170
  ]
171
 
172
  try:
173
+ # Phase 1: Web retrieval for factual grounding
174
+ if stream_callback:
175
+ await stream_callback(StreamEvent(
176
+ event_type="research_start", agent="research",
177
+ content="Retrieving current web data for factual grounding...\n",
178
+ ))
179
+
180
+ web_snippets = await _web_search_grounding(decision_question)
181
+
182
+ if web_snippets and stream_callback:
183
+ await stream_callback(StreamEvent(
184
+ event_type="research_start", agent="research",
185
+ content=f"Found {len(web_snippets)} grounding sources. Analyzing...\n",
186
+ ))
187
+
188
+ # Inject web grounding into the prompt so LLM has current facts
189
+ if web_snippets:
190
+ grounding_text = "\n".join(f"- {s}" for s in web_snippets)
191
+ messages.append(HumanMessage(
192
+ content=f"Web-retrieved grounding data (use these current facts to supplement your analysis):\n{grounding_text}"
193
+ ))
194
+
195
+ # Phase 2: LLM analysis with grounding context
196
  thinking_phases = [
197
  "Scanning market landscape and competitive environment...",
198
+ "Analyzing data points and historical precedents...",
199
+ "Identifying stakeholders and risk factors...",
200
+ "Compiling grounded research synthesis...",
 
201
  ]
202
 
203
  for phase in thinking_phases:
204
  if stream_callback:
205
+ await stream_callback(StreamEvent(
206
+ event_type="research_start", agent="research",
207
+ content=phase + "\n",
208
+ ))
 
 
 
 
209
  await asyncio.sleep(0.3)
210
 
211
  response = await self.llm.ainvoke(messages)
api/routes.py CHANGED
@@ -1,10 +1,13 @@
1
  """FastAPI routes for the Verdict API — REST + WebSocket endpoints."""
2
 
3
  import asyncio
 
4
  import json
5
  import logging
 
6
  import uuid
7
  from datetime import datetime
 
8
  from typing import Optional, Literal
9
 
10
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
@@ -20,9 +23,59 @@ logger = logging.getLogger(__name__)
20
 
21
  router = APIRouter(prefix="/api/verdict")
22
 
23
- # In-memory session store
 
 
 
 
 
 
 
24
  sessions: dict[str, dict] = {}
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Valid output formats and their descriptions
27
  OUTPUT_FORMATS = {
28
  "executive": "High-level summary for executives and decision-makers",
@@ -163,6 +216,7 @@ async def start_verdict(request: StartRequest):
163
  "created_at": datetime.utcnow().isoformat(),
164
  }
165
  sessions[decision.id] = session
 
166
 
167
  logger.info(
168
  "Session created: %s — %s [format=%s, domain=%s]",
@@ -218,7 +272,7 @@ async def get_history():
218
  @router.get("/{session_id}/status")
219
  async def get_status(session_id: str):
220
  """Get the current status of a verdict session."""
221
- session = sessions.get(session_id)
222
  if not session:
223
  raise HTTPException(status_code=404, detail="Session not found")
224
 
@@ -233,7 +287,7 @@ async def get_status(session_id: str):
233
  @router.get("/{session_id}/result")
234
  async def get_result(session_id: str):
235
  """Get the complete result of a verdict session."""
236
- session = sessions.get(session_id)
237
  if not session:
238
  raise HTTPException(status_code=404, detail="Session not found")
239
 
@@ -257,7 +311,7 @@ async def stream_verdict(websocket: WebSocket, session_id: str):
257
  await websocket.accept()
258
  logger.info("WebSocket connected: %s", session_id)
259
 
260
- session = sessions.get(session_id)
261
  if not session:
262
  await websocket.send_json({"error": "Session not found"})
263
  await websocket.close()
@@ -298,6 +352,7 @@ async def stream_verdict(websocket: WebSocket, session_id: str):
298
  }
299
 
300
  session["status"] = "complete"
 
301
  except Exception as e:
302
  logger.error("Pipeline failed: %s", str(e))
303
  session["status"] = "error"
@@ -348,7 +403,7 @@ async def stream_verdict(websocket: WebSocket, session_id: str):
348
  @router.get("/{session_id}/export/markdown")
349
  async def export_markdown(session_id: str):
350
  """Export the verdict session as a markdown report."""
351
- session = sessions.get(session_id)
352
  if not session:
353
  raise HTTPException(status_code=404, detail="Session not found")
354
  if not session.get("result"):
@@ -361,7 +416,7 @@ async def export_markdown(session_id: str):
361
  @router.get("/{session_id}/export/json")
362
  async def export_json(session_id: str):
363
  """Export the verdict session as structured JSON."""
364
- session = sessions.get(session_id)
365
  if not session:
366
  raise HTTPException(status_code=404, detail="Session not found")
367
  if not session.get("result"):
@@ -374,7 +429,7 @@ async def export_json(session_id: str):
374
  @router.get("/{session_id}/export/pdf")
375
  async def export_pdf(session_id: str):
376
  """Export the verdict session as a formatted PDF report."""
377
- session = sessions.get(session_id)
378
  if not session:
379
  raise HTTPException(status_code=404, detail="Session not found")
380
  if not session.get("result"):
@@ -391,7 +446,7 @@ async def export_pdf(session_id: str):
391
  @router.get("/{session_id}/export/docx")
392
  async def export_docx(session_id: str):
393
  """Export the verdict session as a formatted DOCX report."""
394
- session = sessions.get(session_id)
395
  if not session:
396
  raise HTTPException(status_code=404, detail="Session not found")
397
  if not session.get("result"):
@@ -417,7 +472,7 @@ class FollowUpRequest(BaseModel):
417
  @router.post("/{session_id}/followup")
418
  async def followup_question(session_id: str, request: FollowUpRequest):
419
  """Ask a follow-up question about the verdict session."""
420
- session = sessions.get(session_id)
421
  if not session:
422
  raise HTTPException(status_code=404, detail="Session not found")
423
  if not session.get("result"):
@@ -467,3 +522,78 @@ Tailor your response to the {output_format} format and {domain} domain."""),
467
  return {"answer": response.content, "session_id": session_id}
468
  except Exception as e:
469
  raise HTTPException(status_code=500, detail=f"Failed to generate follow-up: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """FastAPI routes for the Verdict API — REST + WebSocket endpoints."""
2
 
3
  import asyncio
4
+ import hashlib
5
  import json
6
  import logging
7
+ import os
8
  import uuid
9
  from datetime import datetime
10
+ from pathlib import Path
11
  from typing import Optional, Literal
12
 
13
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
 
23
 
24
  router = APIRouter(prefix="/api/verdict")
25
 
26
+ # ---------------------------------------------------------------------------
27
+ # Session persistence — JSON file store with in-memory cache
28
+ # ---------------------------------------------------------------------------
29
+
30
+ SESSION_DIR = Path(os.getenv("SESSION_DIR", "data/sessions"))
31
+ SESSION_DIR.mkdir(parents=True, exist_ok=True)
32
+
33
+ # In-memory session cache backed by JSON file persistence
34
  sessions: dict[str, dict] = {}
35
 
36
+
37
+ def _persist_session(session_id: str, session: dict) -> None:
38
+ """Write session to disk as JSON for persistence across restarts."""
39
+ try:
40
+ path = SESSION_DIR / f"{session_id}.json"
41
+ with open(path, "w") as f:
42
+ json.dump(session, f, default=str)
43
+ except Exception as e:
44
+ logger.warning("Failed to persist session %s: %s", session_id, e)
45
+
46
+
47
+ def _load_session(session_id: str) -> Optional[dict]:
48
+ """Load a session from disk if not in memory cache."""
49
+ path = SESSION_DIR / f"{session_id}.json"
50
+ if path.exists():
51
+ try:
52
+ with open(path) as f:
53
+ session = json.load(f)
54
+ sessions[session_id] = session # Warm the cache
55
+ return session
56
+ except Exception as e:
57
+ logger.warning("Failed to load session %s: %s", session_id, e)
58
+ return None
59
+
60
+
61
+ def _get_session(session_id: str) -> Optional[dict]:
62
+ """Get session from cache or disk."""
63
+ if session_id in sessions:
64
+ return sessions[session_id]
65
+ return _load_session(session_id)
66
+
67
+
68
+ def _load_all_sessions() -> None:
69
+ """Load all persisted sessions into memory on startup."""
70
+ for path in SESSION_DIR.glob("*.json"):
71
+ sid = path.stem
72
+ if sid not in sessions:
73
+ _load_session(sid)
74
+
75
+
76
+ # Load persisted sessions on module import
77
+ _load_all_sessions()
78
+
79
  # Valid output formats and their descriptions
80
  OUTPUT_FORMATS = {
81
  "executive": "High-level summary for executives and decision-makers",
 
216
  "created_at": datetime.utcnow().isoformat(),
217
  }
218
  sessions[decision.id] = session
219
+ _persist_session(decision.id, session)
220
 
221
  logger.info(
222
  "Session created: %s — %s [format=%s, domain=%s]",
 
272
  @router.get("/{session_id}/status")
273
  async def get_status(session_id: str):
274
  """Get the current status of a verdict session."""
275
+ session = _get_session(session_id)
276
  if not session:
277
  raise HTTPException(status_code=404, detail="Session not found")
278
 
 
287
  @router.get("/{session_id}/result")
288
  async def get_result(session_id: str):
289
  """Get the complete result of a verdict session."""
290
+ session = _get_session(session_id)
291
  if not session:
292
  raise HTTPException(status_code=404, detail="Session not found")
293
 
 
311
  await websocket.accept()
312
  logger.info("WebSocket connected: %s", session_id)
313
 
314
+ session = _get_session(session_id)
315
  if not session:
316
  await websocket.send_json({"error": "Session not found"})
317
  await websocket.close()
 
352
  }
353
 
354
  session["status"] = "complete"
355
+ _persist_session(session_id, session)
356
  except Exception as e:
357
  logger.error("Pipeline failed: %s", str(e))
358
  session["status"] = "error"
 
403
  @router.get("/{session_id}/export/markdown")
404
  async def export_markdown(session_id: str):
405
  """Export the verdict session as a markdown report."""
406
+ session = _get_session(session_id)
407
  if not session:
408
  raise HTTPException(status_code=404, detail="Session not found")
409
  if not session.get("result"):
 
416
  @router.get("/{session_id}/export/json")
417
  async def export_json(session_id: str):
418
  """Export the verdict session as structured JSON."""
419
+ session = _get_session(session_id)
420
  if not session:
421
  raise HTTPException(status_code=404, detail="Session not found")
422
  if not session.get("result"):
 
429
  @router.get("/{session_id}/export/pdf")
430
  async def export_pdf(session_id: str):
431
  """Export the verdict session as a formatted PDF report."""
432
+ session = _get_session(session_id)
433
  if not session:
434
  raise HTTPException(status_code=404, detail="Session not found")
435
  if not session.get("result"):
 
446
  @router.get("/{session_id}/export/docx")
447
  async def export_docx(session_id: str):
448
  """Export the verdict session as a formatted DOCX report."""
449
+ session = _get_session(session_id)
450
  if not session:
451
  raise HTTPException(status_code=404, detail="Session not found")
452
  if not session.get("result"):
 
472
  @router.post("/{session_id}/followup")
473
  async def followup_question(session_id: str, request: FollowUpRequest):
474
  """Ask a follow-up question about the verdict session."""
475
+ session = _get_session(session_id)
476
  if not session:
477
  raise HTTPException(status_code=404, detail="Session not found")
478
  if not session.get("result"):
 
522
  return {"answer": response.content, "session_id": session_id}
523
  except Exception as e:
524
  raise HTTPException(status_code=500, detail=f"Failed to generate follow-up: {str(e)}")
525
+
526
+
527
+ # ---------------------------------------------------------------------------
528
+ # Verdict sharing endpoint
529
+ # ---------------------------------------------------------------------------
530
+
531
+
532
+ def _generate_share_token(session_id: str) -> str:
533
+ """Generate a short, URL-safe share token from a session ID."""
534
+ return hashlib.sha256(session_id.encode()).hexdigest()[:12]
535
+
536
+
537
+ @router.get("/{session_id}/share")
538
+ async def create_share_link(session_id: str):
539
+ """Generate a shareable link for a completed verdict session.
540
+
541
+ Returns a short share token that can be used to retrieve the session
542
+ results without needing the full session ID.
543
+ """
544
+ session = _get_session(session_id)
545
+ if not session:
546
+ raise HTTPException(status_code=404, detail="Session not found")
547
+ if not session.get("result"):
548
+ raise HTTPException(status_code=202, detail="Session not complete")
549
+
550
+ token = _generate_share_token(session_id)
551
+
552
+ # Store mapping from share token to session ID
553
+ session["share_token"] = token
554
+ _persist_session(session_id, session)
555
+
556
+ return {
557
+ "share_token": token,
558
+ "share_url": f"/shared/{token}",
559
+ "session_id": session_id,
560
+ }
561
+
562
+
563
+ @router.get("/shared/{share_token}")
564
+ async def get_shared_verdict(share_token: str):
565
+ """Retrieve a verdict session via its share token.
566
+
567
+ This allows anyone with the share link to view the verdict results
568
+ without needing the original session ID.
569
+ """
570
+ # Search for session with matching share token
571
+ for sid, session in sessions.items():
572
+ if session.get("share_token") == share_token:
573
+ if not session.get("result"):
574
+ raise HTTPException(status_code=202, detail="Session not complete")
575
+ return {
576
+ "session_id": sid,
577
+ "question": session["decision"].get("question", ""),
578
+ "domain": session.get("domain", "business"),
579
+ "output_format": session.get("output_format", "executive"),
580
+ "result": session["result"],
581
+ }
582
+
583
+ # Also check persisted sessions on disk
584
+ for path in SESSION_DIR.glob("*.json"):
585
+ try:
586
+ with open(path) as f:
587
+ session = json.load(f)
588
+ if session.get("share_token") == share_token:
589
+ return {
590
+ "session_id": path.stem,
591
+ "question": session.get("decision", {}).get("question", ""),
592
+ "domain": session.get("domain", "business"),
593
+ "output_format": session.get("output_format", "executive"),
594
+ "result": session.get("result"),
595
+ }
596
+ except Exception:
597
+ continue
598
+
599
+ raise HTTPException(status_code=404, detail="Shared verdict not found")