MukulRay commited on
Commit
2a79143
Β·
1 Parent(s): e62400d

Phase 2: state schema, SQLite session memory, data dir

Browse files
.env.example ADDED
File without changes
.gitignore CHANGED
@@ -205,3 +205,6 @@ cython_debug/
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
 
 
 
 
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
+
209
+ data/sessions.db
210
+ *.db
app.py ADDED
File without changes
data/.gitkeep ADDED
File without changes
eval/calibration.py ADDED
File without changes
eval/contradiction_viz.py ADDED
File without changes
eval/run_eval.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langgraph>=0.2.0
2
+ langchain>=0.3.0
3
+ langchain-groq
4
+ langchain-huggingface
5
+ sentence-transformers
6
+ semanticscholar
7
+ duckduckgo-search
8
+ tavily-python
9
+ ragas
10
+ datasets
11
+ gradio>=4.0.0
12
+ python-dotenv
13
+ pandas
14
+ tqdm
15
+ networkx
16
+ matplotlib
17
+ fastapi
18
+ uvicorn
19
+
src/__init__.py ADDED
File without changes
src/agents/__init__.py ADDED
File without changes
src/agents/critic.py ADDED
File without changes
src/agents/planner.py ADDED
File without changes
src/agents/retriever.py ADDED
File without changes
src/agents/synthesizer.py ADDED
File without changes
src/graph.py ADDED
File without changes
src/memory.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import json
3
+ import os
4
+ from datetime import datetime
5
+ from src.state import SessionContext, SessionUpdate, Claim
6
+
7
+
8
+ _DEFAULT_DB = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "sessions.db")
9
+ DB_PATH = os.environ.get("SESSION_DB_PATH", _DEFAULT_DB)
10
+
11
+
12
+ def _get_conn() -> sqlite3.Connection:
13
+ conn = sqlite3.connect(DB_PATH)
14
+ conn.row_factory = sqlite3.Row
15
+ return conn
16
+
17
+
18
+ def init_db() -> None:
19
+ """Create tables if they don't exist. Call once at app startup."""
20
+ with _get_conn() as conn:
21
+ conn.execute("""
22
+ CREATE TABLE IF NOT EXISTS sessions (
23
+ session_id TEXT PRIMARY KEY,
24
+ created_at TEXT NOT NULL,
25
+ updated_at TEXT NOT NULL
26
+ )
27
+ """)
28
+ conn.execute("""
29
+ CREATE TABLE IF NOT EXISTS session_turns (
30
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
31
+ session_id TEXT NOT NULL,
32
+ query TEXT NOT NULL,
33
+ position TEXT NOT NULL,
34
+ claim_json TEXT NOT NULL, -- JSON list of Claim dicts
35
+ contradictions TEXT NOT NULL, -- JSON list of strings
36
+ created_at TEXT NOT NULL,
37
+ FOREIGN KEY (session_id) REFERENCES sessions(session_id)
38
+ )
39
+ """)
40
+ conn.commit()
41
+
42
+
43
+ def load_session(session_id: str) -> SessionContext:
44
+ """
45
+ Load prior positions and contradictions for this session.
46
+ Returns an empty SessionContext if the session doesn't exist yet.
47
+ """
48
+ with _get_conn() as conn:
49
+ rows = conn.execute(
50
+ """
51
+ SELECT query, position, contradictions
52
+ FROM session_turns
53
+ WHERE session_id = ?
54
+ ORDER BY created_at ASC
55
+ """,
56
+ (session_id,)
57
+ ).fetchall()
58
+
59
+ if not rows:
60
+ return SessionContext()
61
+
62
+ prior_positions = [r["position"] for r in rows]
63
+ prior_queries = [r["query"] for r in rows]
64
+ flagged_contradictions = []
65
+ for r in rows:
66
+ flagged_contradictions.extend(json.loads(r["contradictions"]))
67
+
68
+ return SessionContext(
69
+ prior_positions=prior_positions,
70
+ prior_queries=prior_queries,
71
+ flagged_contradictions=flagged_contradictions,
72
+ )
73
+
74
+
75
+ def save_turn(session_id: str, update: SessionUpdate) -> None:
76
+ """
77
+ Persist one completed turn (query + synthesized position + claims).
78
+ Creates the session row if it doesn't exist.
79
+ """
80
+ now = datetime.utcnow().isoformat()
81
+
82
+ claim_json = json.dumps([
83
+ {
84
+ "text": c.text,
85
+ "source_title": c.source_title,
86
+ "source_year": c.source_year,
87
+ "confidence": c.confidence,
88
+ "flagged": c.flagged,
89
+ }
90
+ for c in update.claim_confidences
91
+ ])
92
+
93
+ contradictions_json = json.dumps(update.contradictions_found)
94
+
95
+ with _get_conn() as conn:
96
+ # Upsert the session header row
97
+ conn.execute(
98
+ """
99
+ INSERT INTO sessions (session_id, created_at, updated_at)
100
+ VALUES (?, ?, ?)
101
+ ON CONFLICT(session_id) DO UPDATE SET updated_at = excluded.updated_at
102
+ """,
103
+ (session_id, now, now),
104
+ )
105
+ # Insert the turn
106
+ conn.execute(
107
+ """
108
+ INSERT INTO session_turns
109
+ (session_id, query, position, claim_json, contradictions, created_at)
110
+ VALUES (?, ?, ?, ?, ?, ?)
111
+ """,
112
+ (session_id, update.query, update.position,
113
+ claim_json, contradictions_json, now),
114
+ )
115
+ conn.commit()
116
+
117
+
118
+ def export_session_md(session_id: str) -> str:
119
+ """
120
+ Export the full session as a markdown research note.
121
+ Returns the markdown string (not saved to disk here β€” caller decides).
122
+ """
123
+ with _get_conn() as conn:
124
+ session_row = conn.execute(
125
+ "SELECT created_at FROM sessions WHERE session_id = ?",
126
+ (session_id,)
127
+ ).fetchone()
128
+
129
+ turns = conn.execute(
130
+ """
131
+ SELECT query, position, claim_json, contradictions, created_at
132
+ FROM session_turns
133
+ WHERE session_id = ?
134
+ ORDER BY created_at ASC
135
+ """,
136
+ (session_id,)
137
+ ).fetchall()
138
+
139
+ if not session_row:
140
+ return "# Session not found\n"
141
+
142
+ lines = [
143
+ f"# RECON Research Session",
144
+ f"**Session ID:** `{session_id}` ",
145
+ f"**Started:** {session_row['created_at']} ",
146
+ f"**Turns:** {len(turns)}",
147
+ "",
148
+ "---",
149
+ "",
150
+ ]
151
+
152
+ for i, turn in enumerate(turns, 1):
153
+ claims = json.loads(turn["claim_json"])
154
+ contradictions = json.loads(turn["contradictions"])
155
+
156
+ lines += [
157
+ f"## Turn {i}: {turn['query']}",
158
+ "",
159
+ "### Position",
160
+ turn["position"],
161
+ "",
162
+ ]
163
+
164
+ if claims:
165
+ lines += ["### Claims", ""]
166
+ for c in claims:
167
+ flag = " ⚠️" if c["flagged"] else ""
168
+ lines.append(
169
+ f"- **[{c['confidence'].upper()}]** {c['text']} "
170
+ f"β€” *{c['source_title']} ({c['source_year']})*{flag}"
171
+ )
172
+ lines.append("")
173
+
174
+ if contradictions:
175
+ lines += ["### Contradictions flagged", ""]
176
+ for contradiction in contradictions:
177
+ lines.append(f"- {contradiction}")
178
+ lines.append("")
179
+
180
+ lines.append("---")
181
+ lines.append("")
182
+
183
+ return "\n".join(lines)
184
+
185
+
186
+ def delete_session(session_id: str) -> None:
187
+ """Hard delete a session and all its turns."""
188
+ with _get_conn() as conn:
189
+ conn.execute("DELETE FROM session_turns WHERE session_id = ?", (session_id,))
190
+ conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
191
+ conn.commit()
src/state.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Optional
2
+ from dataclasses import dataclass, field
3
+
4
+
5
+ # ---------------------------------------------------------------------------
6
+ # Dataclasses β€” typed objects passed through the graph
7
+ # ---------------------------------------------------------------------------
8
+
9
+ @dataclass
10
+ class Paper:
11
+ title: str
12
+ abstract: str
13
+ year: int
14
+ citation_count: int
15
+ paper_id: str
16
+ authors: list[str] = field(default_factory=list)
17
+ references: list[str] = field(default_factory=list) # list of paper_ids
18
+ hybrid_score: float = 0.0
19
+ source: str = "semantic_scholar" # or "web"
20
+
21
+
22
+ @dataclass
23
+ class WebResult:
24
+ url: str
25
+ snippet: str
26
+ title: str
27
+ inferred_year: Optional[int] = None
28
+ hybrid_score: float = 0.0
29
+ source: str = "web"
30
+
31
+
32
+ @dataclass
33
+ class Claim:
34
+ text: str
35
+ source_title: str
36
+ source_year: int
37
+ confidence: str # "high" | "medium" | "low"
38
+ flagged: bool = False # True if contested or contradicted
39
+
40
+
41
+ @dataclass
42
+ class SessionContext:
43
+ prior_positions: list[str] = field(default_factory=list)
44
+ flagged_contradictions: list[str] = field(default_factory=list)
45
+ prior_queries: list[str] = field(default_factory=list)
46
+
47
+
48
+ @dataclass
49
+ class SessionUpdate:
50
+ position: str
51
+ query: str
52
+ claim_confidences: list[Claim] = field(default_factory=list)
53
+ contradictions_found: list[str] = field(default_factory=list)
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Verdict constants β€” used by Critic agent
58
+ # ---------------------------------------------------------------------------
59
+
60
+ class Verdict:
61
+ PASS = "PASS"
62
+ STALE = "STALE"
63
+ CONTRADICTED = "CONTRADICTED"
64
+ INSUFFICIENT = "INSUFFICIENT"
65
+ FORCED_PASS = "FORCED_PASS"
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # LangGraph state β€” the single TypedDict shared across all agents
70
+ # ---------------------------------------------------------------------------
71
+
72
+ class ResearchState(TypedDict):
73
+ # --- Input ---
74
+ original_query: str
75
+ session_id: str
76
+
77
+ # --- Planner output ---
78
+ session_context: Optional[SessionContext]
79
+ sub_questions: list[str]
80
+
81
+ # --- Retriever output ---
82
+ retrieved_papers: list[Paper]
83
+ citation_graph: dict # {paper_id: [cited_paper_ids]}
84
+ web_results: list[WebResult]
85
+
86
+ # --- Critic output ---
87
+ critic_verdict: str # one of Verdict constants
88
+ critic_notes: str
89
+ rewritten_questions: list[str]
90
+ retry_count: int
91
+
92
+ # --- Synthesizer output ---
93
+ synthesized_position: str
94
+ claim_confidences: list[Claim]
95
+ session_update: Optional[SessionUpdate]
96
+ export_md: str # NEW v2 β€” full session as markdown
97
+
98
+ # --- Eval / config ---
99
+ decay_config: str # "none" | "linear" | "log"
100
+ calibration_bin: str # filled by critic for eval aggregation
101
+ latency_ms: float
test_phase2.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0, ".")
3
+
4
+ from src.state import ResearchState, Paper, Claim, Verdict, SessionContext
5
+ from src.memory import init_db, load_session, save_turn, export_session_md, delete_session
6
+ from src.state import SessionUpdate
7
+ import uuid
8
+
9
+ print("=== Phase 2: State + Memory ===")
10
+
11
+ # Test 1: State dataclasses
12
+ p = Paper(
13
+ title="Test Paper",
14
+ abstract="This is a test abstract.",
15
+ year=2024,
16
+ citation_count=100,
17
+ paper_id="abc123"
18
+ )
19
+ print(f"βœ“ Paper dataclass: {p.title} ({p.year})")
20
+
21
+ c = Claim(text="Test claim", source_title="Test Paper", source_year=2024, confidence="high")
22
+ print(f"βœ“ Claim dataclass: [{c.confidence}] {c.text}")
23
+
24
+ print(f"βœ“ Verdict constants: {Verdict.PASS} / {Verdict.STALE} / {Verdict.CONTRADICTED}")
25
+
26
+ # Test 2: SQLite memory
27
+ init_db()
28
+ print("βœ“ Database initialized")
29
+
30
+ session_id = str(uuid.uuid4())
31
+
32
+ # Load empty session
33
+ ctx = load_session(session_id)
34
+ print(f"βœ“ Empty session loaded: {len(ctx.prior_positions)} prior positions")
35
+
36
+ # Save a turn
37
+ update = SessionUpdate(
38
+ query="What is the state of KV cache compression?",
39
+ position="KV cache compression has advanced significantly with methods like H2O and StreamingLLM.",
40
+ claim_confidences=[
41
+ Claim("H2O reduces KV cache size by 20x", "H2O Paper", 2023, "high"),
42
+ Claim("StreamingLLM enables infinite context", "StreamingLLM", 2023, "medium"),
43
+ ],
44
+ contradictions_found=["StreamingLLM contradicted by later infinite attention work (2024)"]
45
+ )
46
+ save_turn(session_id, update)
47
+ print("βœ“ Turn saved to database")
48
+
49
+ # Reload and verify
50
+ ctx2 = load_session(session_id)
51
+ print(f"βœ“ Session reloaded: {len(ctx2.prior_positions)} prior position(s)")
52
+ print(f" Prior query: {ctx2.prior_queries[0][:60]}...")
53
+
54
+ # Export markdown
55
+ md = export_session_md(session_id)
56
+ print(f"βœ“ Markdown export: {len(md)} characters")
57
+ print(f" Preview: {md[:120].strip()}")
58
+
59
+ # Cleanup
60
+ delete_session(session_id)
61
+ print("βœ“ Session deleted")
62
+
63
+ print("\nβœ… Phase 2 complete")
test_setup.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+
4
+ load_dotenv()
5
+
6
+ print("=== Testing imports ===")
7
+ import langgraph; print("βœ“ langgraph")
8
+ import langchain; print(f"βœ“ langchain {langchain.__version__}")
9
+ import gradio; print(f"βœ“ gradio {gradio.__version__}")
10
+ import semanticscholar; print("βœ“ semanticscholar")
11
+ from sentence_transformers import SentenceTransformer; print("βœ“ sentence-transformers")
12
+ import networkx; print(f"βœ“ networkx {networkx.__version__}")
13
+
14
+ print("\n=== Testing API keys ===")
15
+ groq_key = os.getenv("GROQ_API_KEY")
16
+ s2_key = os.getenv("S2_API_KEY")
17
+ tavily_key = os.getenv("TAVILY_API_KEY")
18
+ print(f"βœ“ GROQ_API_KEY: {'set' if groq_key else 'MISSING'}")
19
+ print(f"βœ“ S2_API_KEY: {'set β€” will activate in 1-3 days' if s2_key else 'not set yet (pending)' }")
20
+ print(f"βœ“ TAVILY_API_KEY: {'set' if tavily_key else 'MISSING'}")
21
+
22
+ print("\n=== Testing Groq connection ===")
23
+ from langchain_groq import ChatGroq
24
+ llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_key)
25
+ response = llm.invoke("Say exactly: setup confirmed")
26
+ print(f"βœ“ Groq response: {response.content}")
27
+
28
+ print("\nβœ… Phase 1 complete β€” all systems go")