aditya-joshi-05 committed on
Commit
f0d100b
·
1 Parent(s): f6803e9

Add phase 3 & 4

Browse files
.env.example CHANGED
@@ -1,13 +1,16 @@
1
  # ── Copy this file to .env and fill in your values ─────────────
2
  # cp .env.example .env
3
 
4
- # ── Groq API ───────────────────────────────────────────────────
5
- GROQ_API_KEY=gsk_your_key_here
 
 
 
6
 
7
  # ── Milvus (defaults work with docker-compose) ─────────────────
8
- MILVUS_HOST=localhost
9
- MILVUS_PORT=19530
10
- MILVUS_COLLECTION=cortex_chunks
11
 
12
  # ── Embedding model ────────────────────────────────────────────
13
  EMBED_MODEL_NAME=BAAI/bge-small-en-v1.5
@@ -23,9 +26,35 @@ RETRIEVAL_TOP_K=15
23
  FINAL_TOP_K=5
24
 
25
  # ── LLM ────────────────────────────────────────────────────────
26
- GROQ_MODEL=llama-3.3-70b-versatile
27
  GROQ_TEMPERATURE=0.1
28
  GROQ_MAX_TOKENS=1024
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # ── Logging ────────────────────────────────────────────────────
31
  LOG_LEVEL=INFO
 
1
  # ── Copy this file to .env and fill in your values ─────────────
2
  # cp .env.example .env
3
 
4
+ # ── API KEYS ───────────────────────────────────────────────────
5
+ GROQ_API_KEY=
6
+ NVIDIA_API_KEY=
7
+ MISTRAL_API_KEY=
8
+ TAVILY_API_KEY=
9
 
10
  # ── Milvus (defaults work with docker-compose) ─────────────────
11
+ MILVUS_HOST=
12
+ MILVUS_PORT=
13
+ MILVUS_COLLECTION=
14
 
15
  # ── Embedding model ────────────────────────────────────────────
16
  EMBED_MODEL_NAME=BAAI/bge-small-en-v1.5
 
26
  FINAL_TOP_K=5
27
 
28
  # ── LLM ────────────────────────────────────────────────────────
29
+ GROQ_MODEL=openai/gpt-oss-120b
30
  GROQ_TEMPERATURE=0.1
31
  GROQ_MAX_TOKENS=1024
32
 
33
+ # ── Knowledge graph ────────────────────────────────────────────
34
+ # rebel → local REBEL model, no API calls (~1.6GB download on first run)
35
+ # llm → Groq LLM, free-form predicates (rate-limited)
36
+ # rebel-filtered → REBEL + entity density pre-filter: skips ~70% of chunks
37
+ # llm-filtered → LLM + entity density pre-filter: drastically fewer API calls
38
+ GRAPH_EXTRACTOR=llm-filtered
39
+ REBEL_BATCH_SIZE=4 # lower this if you hit OOM on CPU
40
+
41
+ # Density filter settings (only used when GRAPH_EXTRACTOR ends with -filtered)
42
+ DENSITY_TOP_FRACTION=0.30 # process top 30% most entity-rich chunks
43
+ DENSITY_MIN_ENTITIES=2 # hard floor: always skip chunks with fewer than N entities
44
+
45
+ # ── RE LLM (LLM accessible via Mistral or Ollama) ────────────────────────────────────────────────────────
46
+ LLM_SERVER=mistral # options: mistral, ollama
47
+ MISTRAL_MODEL=devstral-latest
48
+ OLLAMA_MODEL=llama3.2:3b
49
+ OLLAMA_HOST=
50
+
51
+ # ── Redis cache ────────────────────────────────────────────────
52
+ REDIS_URL=
53
+ CACHE_TTL_SECONDS=3600
54
+
55
+ # ── Evaluation ─────────────────────────────────────────────────
56
+ EVAL_DB_PATH=
57
+ EVAL_ENABLED=true # set false to skip RAGAS LLM calls
58
+
59
  # ── Logging ────────────────────────────────────────────────────
60
  LOG_LEVEL=INFO
api/main.py CHANGED
@@ -39,6 +39,10 @@ from api.schemas import (
39
  )
40
  from config import get_settings
41
  from generation.generator import Generator, GenerationRequest
 
 
 
 
42
  from ingestion.pipeline import IngestionPipeline
43
  from retrieval.dense import MilvusStore
44
  from retrieval.embedder import Embedder
@@ -46,6 +50,7 @@ from retrieval.bm25 import BM25Retriever
46
  from retrieval.orchestrator import MultiStrategyRetriever
47
 
48
  logger = logging.getLogger(__name__)
 
49
 
50
  # ── Shared singletons ──────────────────────────────────────────
51
  # Created once on startup, shared across requests
@@ -54,6 +59,9 @@ _embedder: Embedder = None
54
  _store: MilvusStore = None
55
  _bm25: BM25Retriever = None
56
  _retriever: MultiStrategyRetriever = None
 
 
 
57
  _generator: Generator = None
58
  _pipeline: IngestionPipeline = None
59
 
@@ -61,14 +69,19 @@ _pipeline: IngestionPipeline = None
61
  @asynccontextmanager
62
  async def lifespan(app: FastAPI):
63
  """Initialise shared resources on startup, clean up on shutdown."""
64
- global _embedder, _store, _bm25, _retriever, _generator, _pipeline
65
  logger.info("Cortex starting up...")
66
 
67
  _embedder = Embedder()
68
  _store = MilvusStore(embedder=_embedder)
69
  _bm25 = BM25Retriever()
70
  _retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
71
- _generator = Generator()
 
 
 
 
 
72
  _pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
73
 
74
  # Warm up: trigger model load immediately so first request is fast
@@ -126,11 +139,18 @@ async def health() -> HealthResponse:
126
 
127
  embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
128
 
 
 
 
 
 
 
129
  return HealthResponse(
130
  status="ok" if milvus_status == "ok" else "degraded",
131
  milvus=milvus_status,
132
  embedder=embedder_status,
133
  collection_stats=collection_stats,
 
134
  )
135
 
136
 
@@ -159,6 +179,25 @@ async def ingest(req: IngestRequest) -> IngestResponse:
159
 
160
  return IngestResponse(**stats)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  @app.post("/query", response_model=QueryResponse, tags=["retrieval"])
164
  async def query(req: QueryRequest) -> QueryResponse:
@@ -169,6 +208,9 @@ async def query(req: QueryRequest) -> QueryResponse:
169
  cfg = get_settings()
170
  k = req.top_k or cfg.retrieval_top_k
171
 
 
 
 
172
  try:
173
  retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
174
  except Exception as exc:
@@ -187,6 +229,14 @@ async def query(req: QueryRequest) -> QueryResponse:
187
 
188
  final_chunks = retrieval.chunks
189
 
 
 
 
 
 
 
 
 
190
  try:
191
  result = _generator.generate(
192
  GenerationRequest(query=req.query, chunks=final_chunks)
@@ -195,6 +245,30 @@ async def query(req: QueryRequest) -> QueryResponse:
195
  logger.exception("Generation error")
196
  raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  return QueryResponse(
199
  query=req.query,
200
  answer=result.answer,
@@ -277,7 +351,25 @@ async def query_stream(req: QueryRequest):
277
  yield _sse_event({"type": "done"})
278
  return
279
 
280
- # 3. Stream answer tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  gen_request = GenerationRequest(
282
  query=req.query,
283
  chunks=final_chunks,
 
39
  )
40
  from config import get_settings
41
  from generation.generator import Generator, GenerationRequest
42
+ from generation.crag import CRAGGate
43
+ from evaluation.store import EvalStore, QueryLogEntry
44
+ from evaluation.ragas_eval import RAGASEvaluator, EvalInput
45
+ from retrieval.cache import CachedRetriever
46
  from ingestion.pipeline import IngestionPipeline
47
  from retrieval.dense import MilvusStore
48
  from retrieval.embedder import Embedder
 
50
  from retrieval.orchestrator import MultiStrategyRetriever
51
 
52
  logger = logging.getLogger(__name__)
53
+ cfg = get_settings()
54
 
55
  # ── Shared singletons ──────────────────────────────────────────
56
  # Created once on startup, shared across requests
 
59
  _store: MilvusStore = None
60
  _bm25: BM25Retriever = None
61
  _retriever: MultiStrategyRetriever = None
62
+ _crag: CRAGGate = None
63
+ _eval_store: EvalStore = None
64
+ _evaluator: RAGASEvaluator = None
65
  _generator: Generator = None
66
  _pipeline: IngestionPipeline = None
67
 
 
69
  @asynccontextmanager
70
  async def lifespan(app: FastAPI):
71
  """Initialise shared resources on startup, clean up on shutdown."""
72
+ global _embedder, _store, _bm25, _retriever, _crag, _generator, _pipeline, _eval_store, _evaluator
73
  logger.info("Cortex starting up...")
74
 
75
  _embedder = Embedder()
76
  _store = MilvusStore(embedder=_embedder)
77
  _bm25 = BM25Retriever()
78
  _retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
79
+ _crag = CRAGGate()
80
+ _eval_store = EvalStore(db_path=cfg.eval_db_path)
81
+ _evaluator = RAGASEvaluator(store=_eval_store)
82
+ _generator = Generator()
83
+ # Wrap retriever with Redis cache (degrades gracefully if Redis is absent)
84
+ _retriever = CachedRetriever(_retriever)
85
  _pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
86
 
87
  # Warm up: trigger model load immediately so first request is fast
 
139
 
140
  embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
141
 
142
+ graph_stats = {}
143
+ try:
144
+ graph_stats = _retriever.graph_builder.stats()
145
+ except Exception:
146
+ pass
147
+
148
  return HealthResponse(
149
  status="ok" if milvus_status == "ok" else "degraded",
150
  milvus=milvus_status,
151
  embedder=embedder_status,
152
  collection_stats=collection_stats,
153
+ graph_stats=graph_stats,
154
  )
155
 
156
 
 
179
 
180
  return IngestResponse(**stats)
181
 
182
+ @app.get("/metrics", tags=["evaluation"])
183
+ async def get_metrics(limit: int = 100, days: int = 7):
184
+ """
185
+ Query performance metrics and RAGAS scores for the dashboard.
186
+ Returns summary stats, recent query logs, and hourly timeseries.
187
+ """
188
+ return {
189
+ "summary": _eval_store.get_summary_stats(),
190
+ "recent": _eval_store.get_recent_queries(limit=limit),
191
+ "timeseries": _eval_store.get_metric_timeseries(days=days),
192
+ "cache": _retriever.cache_stats(),
193
+ }
194
+
195
+
196
+ @app.post("/cache/flush", tags=["system"])
197
+ async def flush_cache():
198
+ """Flush all Redis retrieval cache entries."""
199
+ deleted = _retriever.flush_all()
200
+ return {"deleted": deleted}
201
 
202
  @app.post("/query", response_model=QueryResponse, tags=["retrieval"])
203
  async def query(req: QueryRequest) -> QueryResponse:
 
208
  cfg = get_settings()
209
  k = req.top_k or cfg.retrieval_top_k
210
 
211
+ import time as _time
212
+ _t0 = _time.perf_counter()
213
+
214
  try:
215
  retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
216
  except Exception as exc:
 
229
 
230
  final_chunks = retrieval.chunks
231
 
232
+ # CRAG gate: grade, rewrite if POOR, web-search fallback if ABSENT
233
+ crag_result = _crag.evaluate(
234
+ query=req.query,
235
+ chunks=final_chunks,
236
+ retriever_fn=lambda q: _retriever.retrieve(q).chunks,
237
+ )
238
+ final_chunks = crag_result.final_chunks
239
+
240
  try:
241
  result = _generator.generate(
242
  GenerationRequest(query=req.query, chunks=final_chunks)
 
245
  logger.exception("Generation error")
246
  raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
247
 
248
+ latency_ms = (_time.perf_counter() - _t0) * 1000
249
+
250
+ log_id = _eval_store.log_query(QueryLogEntry(
251
+ query=req.query,
252
+ intent=retrieval.decision.intent.value,
253
+ strategies=retrieval.decision.strategies,
254
+ retriever_hits=retrieval.retriever_hits,
255
+ crag_grade=crag_result.grade.value,
256
+ crag_rewritten=bool(crag_result.rewritten_query),
257
+ web_search_used=crag_result.web_search_used,
258
+ num_chunks=len(final_chunks),
259
+ top_chunk_score=final_chunks[0].score if final_chunks else 0.0,
260
+ latency_ms=latency_ms,
261
+ model=result.model,
262
+ ))
263
+
264
+ if cfg.eval_enabled:
265
+ _evaluator.evaluate_async(EvalInput(
266
+ query_log_id=log_id,
267
+ query=req.query,
268
+ answer=result.answer,
269
+ chunks=final_chunks,
270
+ ))
271
+
272
  return QueryResponse(
273
  query=req.query,
274
  answer=result.answer,
 
351
  yield _sse_event({"type": "done"})
352
  return
353
 
354
+ # 3. CRAG gate — grade, optionally rewrite + re-retrieve
355
+ crag_result = _crag.evaluate(
356
+ query=req.query,
357
+ chunks=final_chunks,
358
+ retriever_fn=lambda q: _retriever.retrieve(q).chunks,
359
+ )
360
+ final_chunks = crag_result.final_chunks
361
+
362
+ # Emit CRAG event if something interesting happened
363
+ if crag_result.grade.value != "GOOD" or crag_result.web_search_used:
364
+ yield _sse_event({
365
+ "type": "crag_update",
366
+ "grade": crag_result.grade.value,
367
+ "rewritten_query": crag_result.rewritten_query,
368
+ "web_search_used": crag_result.web_search_used,
369
+ "reasoning": crag_result.reasoning,
370
+ })
371
+
372
+ # 4. Stream answer tokens
373
  gen_request = GenerationRequest(
374
  query=req.query,
375
  chunks=final_chunks,
api/schemas.py CHANGED
@@ -45,6 +45,9 @@ class QueryResponse(BaseModel):
45
  citations: list[CitationResponse]
46
  retrieved_chunks: list[ChunkResponse]
47
  routing: Optional[RoutingResponse] = None
 
 
 
48
  model: str
49
  usage: dict
50
 
@@ -60,6 +63,8 @@ class IngestResponse(BaseModel):
60
  chunks_created: int
61
  chunks_stored: int
62
  bm25_indexed: int = 0
 
 
63
  errors: list[dict] = []
64
 
65
 
@@ -68,3 +73,4 @@ class HealthResponse(BaseModel):
68
  milvus: str
69
  embedder: str
70
  collection_stats: dict
 
 
45
  citations: list[CitationResponse]
46
  retrieved_chunks: list[ChunkResponse]
47
  routing: Optional[RoutingResponse] = None
48
+ crag_grade: Optional[str] = None
49
+ crag_rewritten_query: Optional[str] = None
50
+ web_search_used: bool = False
51
  model: str
52
  usage: dict
53
 
 
63
  chunks_created: int
64
  chunks_stored: int
65
  bm25_indexed: int = 0
66
+ graph_entities: int = 0
67
+ graph_triples: int = 0
68
  errors: list[dict] = []
69
 
70
 
 
73
  milvus: str
74
  embedder: str
75
  collection_stats: dict
76
+ graph_stats: dict = {}
config.py CHANGED
@@ -40,12 +40,15 @@ class Settings(BaseSettings):
40
  retrieval_top_k: int = 15 # candidates before reranking
41
  final_top_k: int = 5 # chunks sent to LLM
42
 
43
- # ── LLM / Groq ───────────────────────────────────────────
44
  groq_api_key: str = os.getenv("GROQ_API_KEY", "") # must be set in .env for LLM classification to work
45
  groq_model: str = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
46
  groq_temperature: float = float(os.getenv("GROQ_TEMPERATURE", 0.1))
47
  groq_max_tokens: int = int(os.getenv("GROQ_MAX_TOKENS", 1024))
48
  groq_timeout: int = int(os.getenv("GROQ_TIMEOUT", 30)) # seconds before Groq client timeout
 
 
 
49
 
50
  # ── FastAPI ──────────────────────────────────────────────
51
  api_host: str = "0.0.0.0"
@@ -56,6 +59,39 @@ class Settings(BaseSettings):
56
  data_dir: str = "data/documents"
57
  log_level: str = "INFO"
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @lru_cache(maxsize=1)
61
  def get_settings() -> Settings:
 
40
  retrieval_top_k: int = 15 # candidates before reranking
41
  final_top_k: int = 5 # chunks sent to LLM
42
 
43
+ # ── LLM / TAVILY ───────────────────────────────────────────
44
  groq_api_key: str = os.getenv("GROQ_API_KEY", "") # must be set in .env for LLM classification to work
45
  groq_model: str = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
46
  groq_temperature: float = float(os.getenv("GROQ_TEMPERATURE", 0.1))
47
  groq_max_tokens: int = int(os.getenv("GROQ_MAX_TOKENS", 1024))
48
  groq_timeout: int = int(os.getenv("GROQ_TIMEOUT", 30)) # seconds before Groq client timeout
49
+ tavily_api_key: str = os.getenv("TAVILY_API_KEY", "")
50
+ mistral_api_key: str = os.getenv("MISTRAL_API_KEY", "")
51
+ mistral_model: str = os.getenv("MISTRAL_MODEL", "devstral-latest")
52
 
53
  # ── FastAPI ──────────────────────────────────────────────
54
  api_host: str = "0.0.0.0"
 
59
  data_dir: str = "data/documents"
60
  log_level: str = "INFO"
61
 
62
+ # ── CRAG ─────────────────────────────────────────────────
63
+ crag_enabled: bool = True
64
+ crag_relevance_threshold: float = 0.5 # below this → POOR grade
65
+
66
+ # ── Graph ─────────────────────────────────────────────────
67
+ graph_enabled: bool = True
68
+ graph_path: str = "data/knowledge_graph.json"
69
+ graph_max_hops: int = 2
70
+ # "rebel" → local REBEL model, no API calls
71
+ # "llm" → Groq LLM, free-form predicates
72
+ # "rebel-filtered" → REBEL + entity density pre-filter (option 4)
73
+ # "llm-filtered" → LLM + entity density pre-filter (option 4; default)
74
+ graph_extractor: str = "llm-filtered"
75
+ rebel_batch_size: int = 4 # chunks per REBEL forward pass; lower if OOM
76
+
77
+ # ── Density filter (used when graph_extractor ends with "-filtered") ──
78
+ density_top_fraction: float = 0.30 # process top 30% most entity-dense chunks
79
+ density_min_entities: int = 2 # hard floor: skip chunks with fewer entities
80
+
81
+ # ── Relation Ext LLM (LLM accessible via Mistral or Ollama) ────────────────────────────────────────────────────────
82
+ llm_server: str = os.getenv("LLM_SERVER", "mistral") # "mistral" or "ollama"
83
+ ollama_model: str = os.getenv("OLLAMA_MODEL", "llama3.2:3b")
84
+ ollama_host: str = os.getenv("OLLAMA_HOST", "") # Ollama server URL
85
+ mistral_model: str = os.getenv("MISTRAL_MODEL", "devstral-latest")
86
+
87
+ # ── Redis cache ───────────────────────────────────────────
88
+ redis_url: str = "redis://localhost:6379"
89
+ cache_ttl_seconds: int = 3600 # 1 hour
90
+
91
+ # ── Evaluation ────────────────────────────────────────────
92
+ eval_db_path: str = "data/cortex_eval.db"
93
+ eval_enabled: bool = True # set False to skip RAGAS calls entirely
94
+
95
 
96
  @lru_cache(maxsize=1)
97
  def get_settings() -> Settings:
evaluation/__init__.py ADDED
File without changes
evaluation/ragas_eval.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — RAGAS Evaluation Harness (Phase 4)
3
+
4
+ Why reference-free metrics?
5
+ ────────────────────────────
6
+ Classic RAG evaluation requires ground-truth answers (golden QA pairs).
7
+ We don't have those at runtime. RAGAS provides three metrics that need
8
+ only (question, answer, retrieved_contexts):
9
+
10
+ faithfulness — Does the answer make claims supported by the context?
11
+ Computed by asking an LLM to identify each claim in
12
+ the answer, then checking each claim against the context.
13
+ Score = supported_claims / total_claims.
14
+
15
+ answer_relevancy — Does the answer actually address the question?
16
+ Computed by generating N hypothetical questions from the
17
+ answer and measuring cosine similarity to the original
18
+ question. Low score = answer talks about something else.
19
+
20
+ context_precision — Are the retrieved chunks actually relevant to the query?
21
+ Computed by asking an LLM whether each chunk is useful
22
+ for answering the query. Score = relevant_chunks / total.
23
+
24
+ We also compute two lightweight custom metrics without any LLM calls:
25
+
26
+ context_utilisation — What fraction of the retrieved chunks are cited in the
27
+ answer? (Count [1], [2]... citation markers.) A low score
28
+ means the generator ignored most of what was retrieved.
29
+
30
+ mean_chunk_score — Average retrieval score (post-reranking) of the final
31
+ chunks. Tracks retrieval quality independently of answer
32
+ quality. Useful for spotting when CRAG rewrites help.
33
+
34
+ Running mode
35
+ ────────────
36
+ Evaluation is async — it runs in a background thread after the response
37
+ has been streamed to the user, so it never adds latency to the query path.
38
+ Results are written to the EvalStore (SQLite) and appear in the dashboard.
39
+
40
+ If RAGAS is not installed or the LLM call fails, only the two custom
41
+ metrics (context_utilisation, mean_chunk_score) are computed and stored.
42
+ This ensures the evaluation pipeline never blocks ingestion or queries.
43
+ """
44
+ from __future__ import annotations
45
+
46
+ import logging
47
+ import re
48
+ import threading
49
+ from dataclasses import dataclass, field
50
+ from typing import Optional
51
+
52
+ from evaluation.store import EvalMetricEntry, EvalStore
53
+ from retrieval.dense import RetrievedChunk
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
+ @dataclass
59
+ class EvalInput:
60
+ """Everything needed to evaluate one query-response pair."""
61
+ query_log_id: int
62
+ query: str
63
+ answer: str
64
+ chunks: list[RetrievedChunk] = field(default_factory=list)
65
+
66
+
67
+ @dataclass
68
+ class EvalResult:
69
+ faithfulness: Optional[float] = None
70
+ answer_relevancy: Optional[float] = None
71
+ context_precision: Optional[float] = None
72
+ context_utilisation: Optional[float] = None
73
+ mean_chunk_score: Optional[float] = None
74
+
75
+ def as_store_entry(self, query_log_id: int) -> EvalMetricEntry:
76
+ return EvalMetricEntry(
77
+ query_log_id=query_log_id,
78
+ faithfulness=self.faithfulness,
79
+ answer_relevancy=self.answer_relevancy,
80
+ context_precision=self.context_precision,
81
+ context_utilisation=self.context_utilisation,
82
+ mean_chunk_score=self.mean_chunk_score,
83
+ )
84
+
85
+
86
+ class RAGASEvaluator:
87
+ """
88
+ Computes RAGAS + custom metrics for a query-response pair.
89
+
90
+ Usage — fire-and-forget (non-blocking):
91
+ evaluator = RAGASEvaluator(store)
92
+ evaluator.evaluate_async(EvalInput(
93
+ query_log_id=log_id,
94
+ query="What is attention?",
95
+ answer="Attention is...",
96
+ chunks=final_chunks,
97
+ ))
98
+
99
+ Usage — blocking (for testing):
100
+ result = evaluator.evaluate(eval_input)
101
+ """
102
+
103
+ def __init__(self, store: Optional[EvalStore] = None) -> None:
104
+ self._store = store or EvalStore()
105
+ self._ragas_available = self._check_ragas()
106
+
107
+ # ── Public API ─────────────────────────────────────────────
108
+
109
+ def evaluate_async(self, inp: EvalInput) -> None:
110
+ """
111
+ Run evaluation in a daemon thread. Returns immediately.
112
+ Results are written to EvalStore when complete.
113
+ """
114
+ thread = threading.Thread(
115
+ target=self._run_and_store,
116
+ args=(inp,),
117
+ daemon=True,
118
+ name=f"ragas-eval-{inp.query_log_id}",
119
+ )
120
+ thread.start()
121
+
122
+ def evaluate(self, inp: EvalInput) -> EvalResult:
123
+ """Blocking evaluation. Returns EvalResult."""
124
+ result = EvalResult()
125
+
126
+ # ── Custom metrics (no LLM, always computed) ──────────
127
+ result.context_utilisation = self._context_utilisation(inp.answer, inp.chunks)
128
+ result.mean_chunk_score = self._mean_chunk_score(inp.chunks)
129
+
130
+ # ── RAGAS metrics (LLM-based, may be skipped) ─────────
131
+ if self._ragas_available and inp.chunks:
132
+ ragas_scores = self._run_ragas(inp)
133
+ result.faithfulness = ragas_scores.get("faithfulness")
134
+ result.answer_relevancy = ragas_scores.get("answer_relevancy")
135
+ result.context_precision = ragas_scores.get("context_precision")
136
+ else:
137
+ if not self._ragas_available:
138
+ logger.debug("RAGAS not installed β€” only custom metrics computed.")
139
+
140
+ return result
141
+
142
+ # ── Private ────────────────────────────────────────────────
143
+
144
+ def _run_and_store(self, inp: EvalInput) -> None:
145
+ try:
146
+ result = self.evaluate(inp)
147
+ self._store.log_metrics(result.as_store_entry(inp.query_log_id))
148
+ logger.debug(
149
+ "Eval stored for query %d: faith=%.2f rel=%.2f prec=%.2f util=%.2f",
150
+ inp.query_log_id,
151
+ result.faithfulness or 0,
152
+ result.answer_relevancy or 0,
153
+ result.context_precision or 0,
154
+ result.context_utilisation or 0,
155
+ )
156
+ except Exception as exc:
157
+ logger.warning("Eval failed for query %d: %s", inp.query_log_id, exc)
158
+
159
+ def _run_ragas(self, inp: EvalInput) -> dict:
160
+ """
161
+ Call RAGAS library. Returns dict of metric_name → score.
162
+ Returns empty dict on any failure.
163
+ """
164
+ try:
165
+ from datasets import Dataset # type: ignore
166
+ from ragas import evaluate as ragas_evaluate # type: ignore
167
+ from ragas.metrics import ( # type: ignore
168
+ answer_relevancy,
169
+ context_precision,
170
+ faithfulness,
171
+ )
172
+ from config import get_settings
173
+ cfg = get_settings()
174
+
175
+ # RAGAS expects a HuggingFace Dataset
176
+ data = {
177
+ "question": [inp.query],
178
+ "answer": [inp.answer],
179
+ "contexts": [[c.parent_text or c.text for c in inp.chunks]],
180
+ # reference not available at runtime — omit context_recall
181
+ }
182
+ dataset = Dataset.from_dict(data)
183
+
184
+ scores = ragas_evaluate(
185
+ dataset,
186
+ metrics=[faithfulness, answer_relevancy, context_precision],
187
+ raise_exceptions=False,
188
+ )
189
+ df = scores.to_pandas()
190
+ return {
191
+ "faithfulness": float(df["faithfulness"].iloc[0]) if "faithfulness" in df else None,
192
+ "answer_relevancy": float(df["answer_relevancy"].iloc[0]) if "answer_relevancy" in df else None,
193
+ "context_precision": float(df["context_precision"].iloc[0]) if "context_precision" in df else None,
194
+ }
195
+
196
+ except Exception as exc:
197
+ logger.warning("RAGAS evaluation failed: %s", exc)
198
+ return {}
199
+
200
+ # ── Custom metrics (no LLM required) ──────────────────────
201
+
202
+ @staticmethod
203
+ def _context_utilisation(answer: str, chunks: list[RetrievedChunk]) -> float:
204
+ """
205
+ Fraction of retrieved chunks cited in the answer.
206
+ Looks for inline [N] citation markers.
207
+ """
208
+ if not chunks:
209
+ return 0.0
210
+ cited_indices = set(int(n) for n in re.findall(r"\[(\d+)\]", answer))
211
+ cited = sum(1 for i in range(1, len(chunks) + 1) if i in cited_indices)
212
+ return round(cited / len(chunks), 3)
213
+
214
+ @staticmethod
215
+ def _mean_chunk_score(chunks: list[RetrievedChunk]) -> float:
216
+ """Average retrieval score of the final chunks."""
217
+ if not chunks:
218
+ return 0.0
219
+ return round(sum(c.score for c in chunks) / len(chunks), 3)
220
+
221
+ @staticmethod
222
+ def _check_ragas() -> bool:
223
+ try:
224
+ import ragas # type: ignore # noqa: F401
225
+ import datasets # type: ignore # noqa: F401
226
+ return True
227
+ except ImportError:
228
+ return False
evaluation/store.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Evaluation Store (SQLite)
3
+
4
+ Two tables:
5
+ query_logs — one row per query: routing, CRAG grade, latency, chunk scores
6
+ eval_metrics — one row per query: RAGAS scores (written async after generation)
7
+
8
+ SQLite is the right choice here: zero infrastructure, works on Railway/Render
9
+ out of the box, and a dashboard corpus of ~10k queries fits in <50MB.
10
+ Swap to Postgres trivially later by changing the connection string.
11
+
12
+ The store is intentionally append-only. No deletes, no updates.
13
+ This preserves the full history for trend analysis in the dashboard.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ import sqlite3
20
+ import time
21
+ from contextlib import contextmanager
22
+ from dataclasses import dataclass
23
+ from pathlib import Path
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ _DEFAULT_DB_PATH = Path("data/cortex_eval.db")
29
+
30
+ # ── Schema ─────────────────────────────────────────────────────
31
+
32
+ _DDL = """
33
+ CREATE TABLE IF NOT EXISTS query_logs (
34
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
35
+ timestamp REAL NOT NULL,
36
+ query TEXT NOT NULL,
37
+ intent TEXT,
38
+ strategies TEXT, -- JSON list
39
+ retriever_hits TEXT, -- JSON dict
40
+ crag_grade TEXT,
41
+ crag_rewritten INTEGER DEFAULT 0, -- bool
42
+ web_search_used INTEGER DEFAULT 0, -- bool
43
+ num_chunks INTEGER DEFAULT 0,
44
+ top_chunk_score REAL DEFAULT 0.0,
45
+ latency_ms REAL DEFAULT 0.0,
46
+ model TEXT,
47
+ extractor TEXT
48
+ );
49
+
50
+ CREATE TABLE IF NOT EXISTS eval_metrics (
51
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
52
+ query_log_id INTEGER NOT NULL REFERENCES query_logs(id),
53
+ timestamp REAL NOT NULL,
54
+ faithfulness REAL, -- 0-1: does answer contradict context?
55
+ answer_relevancy REAL, -- 0-1: does answer address the question?
56
+ context_precision REAL, -- 0-1: are retrieved chunks relevant?
57
+ context_utilisation REAL, -- 0-1: fraction of chunks cited in answer
58
+ mean_chunk_score REAL -- average retrieval score of final chunks
59
+ );
60
+
61
+ CREATE INDEX IF NOT EXISTS idx_query_logs_ts ON query_logs(timestamp);
62
+ CREATE INDEX IF NOT EXISTS idx_eval_metrics_id ON eval_metrics(query_log_id);
63
+ """
64
+
65
+
66
+ # ── Dataclasses ────────────────────────────────────────────────
67
+
68
+ @dataclass
69
+ class QueryLogEntry:
70
+ query: str
71
+ intent: str = ""
72
+ strategies: list[str] = None
73
+ retriever_hits: dict = None
74
+ crag_grade: str = ""
75
+ crag_rewritten: bool = False
76
+ web_search_used: bool = False
77
+ num_chunks: int = 0
78
+ top_chunk_score: float = 0.0
79
+ latency_ms: float = 0.0
80
+ model: str = ""
81
+ extractor: str = ""
82
+
83
+ def __post_init__(self):
84
+ if self.strategies is None:
85
+ self.strategies = []
86
+ if self.retriever_hits is None:
87
+ self.retriever_hits = {}
88
+
89
+
90
+ @dataclass
91
+ class EvalMetricEntry:
92
+ query_log_id: int
93
+ faithfulness: Optional[float] = None
94
+ answer_relevancy: Optional[float] = None
95
+ context_precision: Optional[float] = None
96
+ context_utilisation: Optional[float] = None
97
+ mean_chunk_score: Optional[float] = None
98
+
99
+
100
+ # ── Store ──────────────────────────────────────────────────────
101
+
102
class EvalStore:
    """
    SQLite-backed store for query logs and eval metrics.

    Every operation opens, commits and closes its own connection (see
    ``_conn``), so no connection object is ever shared between callers.

    Usage:
        store = EvalStore()
        log_id = store.log_query(entry)
        store.log_metrics(EvalMetricEntry(query_log_id=log_id, faithfulness=0.92, ...))
    """

    def __init__(self, db_path: str | Path = _DEFAULT_DB_PATH) -> None:
        """Create the database's parent directory if needed and apply the schema."""
        self._path = Path(db_path)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    # ── Write ──────────────────────────────────────────────────

    def log_query(self, entry: QueryLogEntry) -> int:
        """Insert a query log row. Returns the new row id.

        The row timestamp is taken at insert time (``time.time()``); list/dict
        fields are stored as JSON text and booleans as 0/1 integers.
        """
        with self._conn() as conn:
            cur = conn.execute(
                """INSERT INTO query_logs
                   (timestamp, query, intent, strategies, retriever_hits,
                    crag_grade, crag_rewritten, web_search_used,
                    num_chunks, top_chunk_score, latency_ms, model, extractor)
                   VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                (
                    time.time(),
                    entry.query,
                    entry.intent,
                    json.dumps(entry.strategies),
                    json.dumps(entry.retriever_hits),
                    entry.crag_grade,
                    int(entry.crag_rewritten),
                    int(entry.web_search_used),
                    entry.num_chunks,
                    entry.top_chunk_score,
                    entry.latency_ms,
                    entry.model,
                    entry.extractor,
                ),
            )
            return cur.lastrowid

    def log_metrics(self, entry: EvalMetricEntry) -> None:
        """Insert an eval_metrics row linked to ``entry.query_log_id``."""
        with self._conn() as conn:
            conn.execute(
                """INSERT INTO eval_metrics
                   (query_log_id, timestamp, faithfulness, answer_relevancy,
                    context_precision, context_utilisation, mean_chunk_score)
                   VALUES (?,?,?,?,?,?,?)""",
                (
                    entry.query_log_id,
                    time.time(),
                    entry.faithfulness,
                    entry.answer_relevancy,
                    entry.context_precision,
                    entry.context_utilisation,
                    entry.mean_chunk_score,
                ),
            )

    # ── Read ───────────────────────────────────────────────────

    def get_recent_queries(self, limit: int = 100) -> list[dict]:
        """Last N query logs joined with their eval metrics (if available).

        Metric columns are ``None`` for queries that have no eval_metrics row.
        NOTE(review): the LEFT JOIN yields one row per metrics entry, so a query
        logged with several metric rows would appear more than once — confirm
        that callers write at most one metrics row per query.
        """
        with self._conn() as conn:
            rows = conn.execute(
                """SELECT q.id, q.timestamp, q.query, q.intent, q.strategies,
                          q.crag_grade, q.web_search_used, q.num_chunks,
                          q.top_chunk_score, q.latency_ms,
                          e.faithfulness, e.answer_relevancy,
                          e.context_precision, e.context_utilisation,
                          e.mean_chunk_score
                   FROM query_logs q
                   LEFT JOIN eval_metrics e ON e.query_log_id = q.id
                   ORDER BY q.timestamp DESC
                   LIMIT ?""",
                (limit,),
            ).fetchall()
        return [self._row_to_dict(r) for r in rows]

    def get_metric_timeseries(self, days: int = 7) -> list[dict]:
        """
        Hourly-bucketed metric averages over the last N days.
        Used for the trend line chart in the dashboard.

        ``hour_bucket`` counts whole hours elapsed since the start of the
        window (now minus ``days``), so bucket 0 is the oldest hour.
        """
        since = time.time() - days * 86400
        columns = (
            "hour_bucket", "faithfulness", "answer_relevancy",
            "context_precision", "mean_chunk_score", "query_count",
        )
        with self._conn() as conn:
            rows = conn.execute(
                """SELECT
                       CAST((q.timestamp - ?) / 3600 AS INTEGER) AS hour_bucket,
                       AVG(e.faithfulness)      AS faithfulness,
                       AVG(e.answer_relevancy)  AS answer_relevancy,
                       AVG(e.context_precision) AS context_precision,
                       AVG(e.mean_chunk_score)  AS mean_chunk_score,
                       COUNT(*)                 AS query_count
                   FROM query_logs q
                   JOIN eval_metrics e ON e.query_log_id = q.id
                   WHERE q.timestamp > ?
                   GROUP BY hour_bucket
                   ORDER BY hour_bucket""",
                (since, since),
            ).fetchall()
        return [dict(zip(columns, r)) for r in rows]

    def get_summary_stats(self) -> dict:
        """Aggregate stats for the dashboard header metrics.

        Averages over an empty table come back as NULL and are reported as 0.
        """
        with self._conn() as conn:
            total = conn.execute("SELECT COUNT(*) FROM query_logs").fetchone()[0]
            with_metrics = conn.execute("SELECT COUNT(*) FROM eval_metrics").fetchone()[0]
            avgs = conn.execute(
                """SELECT AVG(faithfulness), AVG(answer_relevancy),
                          AVG(context_precision), AVG(mean_chunk_score)
                   FROM eval_metrics"""
            ).fetchone()
            grade_dist = conn.execute(
                """SELECT crag_grade, COUNT(*) as cnt
                   FROM query_logs WHERE crag_grade != ''
                   GROUP BY crag_grade"""
            ).fetchall()
            # Grouped by the stored JSON text, so the keys are JSON-encoded lists.
            strategy_dist = conn.execute(
                """SELECT strategies, COUNT(*) as cnt
                   FROM query_logs GROUP BY strategies"""
            ).fetchall()
            avg_latency = conn.execute(
                "SELECT AVG(latency_ms) FROM query_logs WHERE latency_ms > 0"
            ).fetchone()[0]

        return {
            "total_queries": total,
            "evaluated_queries": with_metrics,
            "avg_faithfulness": round(avgs[0] or 0, 3),
            "avg_answer_relevancy": round(avgs[1] or 0, 3),
            "avg_context_precision": round(avgs[2] or 0, 3),
            "avg_chunk_score": round(avgs[3] or 0, 3),
            "avg_latency_ms": round(avg_latency or 0, 1),
            "crag_grade_dist": {r[0]: r[1] for r in grade_dist},
            "strategy_dist": {r[0]: r[1] for r in strategy_dist},
        }

    # ── Init ───────────────────────────────────────────────────

    def _init_db(self) -> None:
        """Run the schema script (``_DDL``) against the database file."""
        with self._conn() as conn:
            conn.executescript(_DDL)
        logger.info("EvalStore ready at %s", self._path)

    @contextmanager
    def _conn(self):
        """Yield a fresh connection; commit on success, roll back on error, always close."""
        conn = sqlite3.connect(self._path, timeout=10, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _row_to_dict(row) -> dict:
        """Convert a row to a plain dict, decoding JSON-encoded columns."""
        d = dict(row)
        for key in ("strategies",):
            if d.get(key):
                try:
                    d[key] = json.loads(d[key])
                except (TypeError, ValueError) as exc:
                    # json.JSONDecodeError is a ValueError. Keep the raw stored
                    # value rather than failing the whole read, but leave a trace.
                    logger.debug("Could not decode %s column as JSON: %s", key, exc)
        return d
generation/crag.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Corrective RAG (CRAG) Gate (Phase 3)
3
+
4
+ The problem CRAG solves
5
+ ────────────────────────
6
+ Standard RAG always passes retrieved chunks to the LLM, even when:
7
+ - The query is ambiguous and the retrieved chunks are off-topic
8
+ - The knowledge base simply doesn't contain the answer
9
+ - The retrieved chunks contradict each other
10
+
11
+ In all three cases, the LLM will either hallucinate or produce a
12
+ confused answer. CRAG adds a grading step BEFORE generation:
13
+
14
+                     ┌─── GOOD ────► Generator (proceed normally)
14
+   Query → Retrieve ─┤
15
+                     ├─── POOR ────► Rewrite query → Re-retrieve → Generator
16
+                     └─── ABSENT ──► Web search fallback → Generator
18
+
19
+ Grading
20
+ ────────
21
+ An LLM-as-judge evaluates (query, retrieved_chunks) and returns:
22
+ {
23
+ "grade": "GOOD" | "POOR" | "ABSENT",
24
+ "relevance_score": 0.0–1.0,
25
+ "has_sufficient_context": true | false,
26
+ "reasoning": "..."
27
+ }
28
+
29
+ Grade definitions:
30
+   GOOD   — chunks are relevant and sufficient for the query
30
+   POOR   — chunks are partially relevant; try rewriting the query
31
+   ABSENT — knowledge base clearly doesn't contain the answer;
32
+            fall back to web search
34
+
35
+ Query rewriting
36
+ ────────────────
37
+ When grade == POOR, we expand the query using chain-of-thought:
38
+ the grader's `reasoning` field (why did retrieval fail?) is fed
39
+ back as context for a rewrite prompt. This makes the rewrite
40
+ semantically targeted, not just rephrased.
41
+
42
+ Web search fallback
43
+ ────────────────────
44
+ When grade == ABSENT, we call Tavily (preferred) or DuckDuckGo
45
+ (no API key needed) and package the top-3 web results as synthetic
46
+ RetrievedChunk objects with source="web_search". These flow into
47
+ the same generator unchanged.
48
+ """
49
+ from __future__ import annotations
50
+
51
+ import json
52
+ import logging
53
+ import re
54
+ from dataclasses import dataclass
55
+ from enum import Enum
56
+ from typing import Optional
57
+
58
+ from config import get_settings
59
+ from retrieval.dense import RetrievedChunk
60
+
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ # ── Grade enum ─────────────────────────────────────────────────
65
+
66
class RetrievalGrade(str, Enum):
    """Verdict of the retrieval grader; members compare equal to their string value."""

    # Chunks are usable as-is: go straight to generation.
    GOOD = "GOOD"
    # Chunks are only partly useful: rewrite the query and retrieve again.
    POOR = "POOR"
    # The knowledge base has nothing relevant: fall back to web search.
    ABSENT = "ABSENT"
70
+
71
+
72
+ # ── Result dataclass ───────────────────────────────────────────
73
+
74
@dataclass
class CRAGResult:
    """Outcome of a CRAG gate evaluation, handed on to the generator and the eval log."""

    # Grader verdict and its supporting details.
    grade: RetrievalGrade
    relevance_score: float
    has_sufficient_context: bool
    reasoning: str
    # Chunks the generator should actually use.
    final_chunks: list[RetrievedChunk]
    # Populated only when a query rewrite was attempted.
    rewritten_query: Optional[str] = None
    # True when final_chunks came from the web search fallback.
    web_search_used: bool = False
83
+
84
+
85
+ # ── Prompt templates ───────────────────────────────────────────
86
+
87
# Grader prompt. Placeholders: {query}, {passages}. The literal JSON braces are
# doubled so that str.format() leaves them intact.
_GRADER_PROMPT = """\
You are a retrieval quality judge. Given a user query and retrieved passages,
assess whether the passages contain sufficient information to answer the query.

Return ONLY a JSON object in this exact format (no markdown, no preamble):
{{
  "grade": "<GOOD|POOR|ABSENT>",
  "relevance_score": <float 0.0-1.0>,
  "has_sufficient_context": <true|false>,
  "reasoning": "<one sentence explaining your assessment>"
}}

Grades:
  GOOD   — passages are clearly relevant and contain enough information to answer
  POOR   — passages are partially relevant but incomplete or off-topic; retrieval should be retried
  ABSENT — the knowledge base clearly does not contain information about this query

User query: {query}

Retrieved passages:
{passages}
"""

# Rewrite prompt. Placeholders: {query}, {reasoning} (the grader's explanation
# of why retrieval was judged poor).
_REWRITE_PROMPT = """\
A retrieval system failed to find good results for the following query.
The grader's feedback explains why the results were poor.

Original query: {query}
Grader feedback: {reasoning}

Rewrite the query to be more specific and likely to retrieve better results.
Apply these strategies: expand acronyms, add domain context, use alternative terms.

Return ONLY the rewritten query string, no explanation.
"""
122
+
123
+
124
+ # ── CRAG Gate ──────────────────────────────────────────────────
125
+
126
class CRAGGate:
    """
    Corrective RAG gate that sits between retrieval and generation.

    Usage (in orchestrator):
        crag = CRAGGate()
        result = crag.evaluate(
            query=user_query,
            chunks=retrieved_chunks,
            retriever_fn=retriever.retrieve,   # callable for re-retrieval
        )
        # result.final_chunks → pass to generator
        # result.grade        → log for evaluation dashboard
    """

    def __init__(self) -> None:
        # Groq client; created lazily by _get_llm() so construction needs no API key.
        self._llm = None

    # ── Public API ─────────────────────────────────────────────

    def evaluate(
        self,
        query: str,
        chunks: list[RetrievedChunk],
        retriever_fn: Optional[callable] = None,
        max_retries: int = 1,
    ) -> CRAGResult:
        """
        Grade retrieved chunks and apply corrective action if needed.

        Args:
            query:        the user's original query
            chunks:       chunks returned by the retrieval pipeline
            retriever_fn: callable(query: str) returning either a list of
                          RetrievedChunk or a result object exposing ``.chunks``;
                          used for re-retrieval on POOR grade
            max_retries:  0 disables the rewrite + re-retrieve step; any
                          positive value enables it (a single cycle is run)

        Returns:
            CRAGResult whose ``final_chunks`` should be passed to the generator.
        """
        # Grade the initial retrieval
        grade_result = self._grade(query, chunks)
        logger.info(
            "CRAG grade: %s (score=%.2f, sufficient=%s) — %s",
            grade_result["grade"],
            grade_result["relevance_score"],
            grade_result["has_sufficient_context"],
            grade_result["reasoning"][:80],
        )

        grade = RetrievalGrade(grade_result["grade"])

        # ── GOOD: pass through unchanged ──────────────────────
        if grade == RetrievalGrade.GOOD:
            return CRAGResult(
                grade=grade,
                relevance_score=grade_result["relevance_score"],
                has_sufficient_context=True,
                reasoning=grade_result["reasoning"],
                final_chunks=chunks,
            )

        # ── POOR: rewrite query and re-retrieve ───────────────
        if grade == RetrievalGrade.POOR and retriever_fn and max_retries > 0:
            rewritten = self._rewrite_query(query, grade_result["reasoning"])
            logger.info("CRAG rewrite: '%s' → '%s'", query[:50], rewritten[:50])

            new_chunks: list[RetrievedChunk] = []
            new_grade = None
            try:
                new_chunks = self._unwrap_chunks(retriever_fn(rewritten))
                if new_chunks:
                    # Re-grade the new results (once — no infinite loop)
                    new_grade = self._grade(rewritten, new_chunks)
            except Exception as exc:
                logger.warning("Re-retrieval after rewrite failed: %s", exc)

            if new_grade is not None:
                return CRAGResult(
                    grade=RetrievalGrade(new_grade["grade"]),
                    relevance_score=new_grade["relevance_score"],
                    has_sufficient_context=new_grade["has_sufficient_context"],
                    reasoning=new_grade["reasoning"],
                    final_chunks=new_chunks,
                    rewritten_query=rewritten,
                )

            # Retry failed or came back empty: keep the original chunks together
            # with the grade that actually describes them.
            return CRAGResult(
                grade=grade,
                relevance_score=grade_result["relevance_score"],
                has_sufficient_context=False,
                reasoning=grade_result["reasoning"],
                final_chunks=chunks,
                rewritten_query=rewritten,
            )

        # ── ABSENT: web search fallback ────────────────────────
        if grade == RetrievalGrade.ABSENT:
            web_chunks = self._web_search_fallback(query)
            if web_chunks:
                return CRAGResult(
                    grade=grade,
                    relevance_score=0.0,
                    has_sufficient_context=True,
                    reasoning=grade_result["reasoning"],
                    final_chunks=web_chunks,
                    web_search_used=True,
                )
            # Web search also failed — return original chunks with warning
            return CRAGResult(
                grade=grade,
                relevance_score=0.0,
                has_sufficient_context=False,
                reasoning=f"Knowledge base: {grade_result['reasoning']}. Web search also returned no results.",
                final_chunks=chunks,
            )

        # Default (POOR with no retriever / retries disabled): original chunks unchanged
        return CRAGResult(
            grade=grade,
            relevance_score=grade_result["relevance_score"],
            has_sufficient_context=grade_result["has_sufficient_context"],
            reasoning=grade_result["reasoning"],
            final_chunks=chunks,
        )

    @staticmethod
    def _unwrap_chunks(retrieved) -> list:
        """Normalise a retriever return value to a plain list of chunks.

        Accepts a list of chunks, an object exposing ``.chunks``, or ``None``.
        """
        found = getattr(retrieved, "chunks", retrieved)
        return list(found) if found else []

    # ── LLM grader ────────────────────────────────────────────

    def _grade(self, query: str, chunks: list[RetrievedChunk]) -> dict:
        """Call LLM to grade retrieval quality. Returns parsed dict.

        Only the first five chunks (400 characters each) are shown to the
        grader. An empty chunk list is graded ABSENT without an LLM call; if
        the LLM call itself fails the grade defaults to GOOD so that the
        pipeline is never blocked by the grader.
        """
        if not chunks:
            return {
                "grade": "ABSENT",
                "relevance_score": 0.0,
                "has_sufficient_context": False,
                "reasoning": "No chunks were retrieved.",
            }

        passages = "\n\n".join(
            f"[{i}] {c.title}: {c.text[:400]}"
            for i, c in enumerate(chunks[:5], 1)
        )

        try:
            client = self._get_llm()
            cfg = get_settings()
            response = client.chat.completions.create(
                model=cfg.groq_model,
                messages=[{
                    "role": "user",
                    "content": _GRADER_PROMPT.format(query=query, passages=passages),
                }],
                temperature=0.0,
                max_tokens=200,
            )
            raw = response.choices[0].message.content or "{}"
            return self._parse_grade(raw)

        except Exception as exc:
            logger.warning("CRAG grader LLM call failed: %s", exc)
            # Safe default: assume GOOD to avoid blocking the pipeline
            return {
                "grade": "GOOD",
                "relevance_score": 0.5,
                "has_sufficient_context": True,
                "reasoning": f"Grader unavailable ({exc}); passing through.",
            }

    def _parse_grade(self, raw: str) -> dict:
        """Parse the grader's reply into a normalised grade dict.

        Tolerates markdown fences and prose around the JSON object. Any field
        that is missing or malformed falls back to its neutral default
        individually, so one bad field does not discard a valid grade. Never
        raises; an unparseable reply yields the GOOD pass-through default.
        """
        text = (raw or "").strip()
        if text.startswith("```"):
            text = re.sub(r"^```[a-z]*\n?", "", text)
            text = re.sub(r"\n?```$", "", text)

        data = None
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Models sometimes wrap the object in a sentence; try the outermost {...}.
            embedded = re.search(r"\{.*\}", text, re.DOTALL)
            if embedded is not None:
                try:
                    data = json.loads(embedded.group(0))
                except json.JSONDecodeError:
                    data = None

        if not isinstance(data, dict):
            return {
                "grade": "GOOD", "relevance_score": 0.5,
                "has_sufficient_context": True, "reasoning": "Parse error.",
            }

        grade_str = str(data.get("grade", "GOOD")).strip().upper()
        if grade_str not in {"GOOD", "POOR", "ABSENT"}:
            grade_str = "GOOD"

        try:
            score = float(data.get("relevance_score", 0.5))
        except (TypeError, ValueError):
            score = 0.5
        if not 0.0 <= score <= 1.0:
            # Out of range → clamp; NaN fails every comparison → neutral default.
            score = 0.0 if score < 0.0 else 1.0 if score > 1.0 else 0.5

        sufficient = data.get("has_sufficient_context", True)
        if isinstance(sufficient, str):
            # bool("false") is True, so string answers are interpreted explicitly.
            sufficient = sufficient.strip().lower() not in {"false", "no", "0", ""}

        reasoning = data.get("reasoning")

        return {
            "grade": grade_str,
            "relevance_score": score,
            "has_sufficient_context": bool(sufficient),
            "reasoning": "" if reasoning is None else str(reasoning),
        }

    # ── Query rewriter ────────────────────────────────────────

    def _rewrite_query(self, original_query: str, reasoning: str) -> str:
        """Ask the LLM for a more targeted query; return the original on any failure."""
        try:
            client = self._get_llm()
            cfg = get_settings()
            response = client.chat.completions.create(
                model=cfg.groq_model,
                messages=[{
                    "role": "user",
                    "content": _REWRITE_PROMPT.format(
                        query=original_query, reasoning=reasoning
                    ),
                }],
                temperature=0.3,
                max_tokens=128,
            )
            rewritten = (response.choices[0].message.content or "").strip()
            return rewritten if rewritten else original_query
        except Exception as exc:
            logger.warning("Query rewrite failed: %s", exc)
            return original_query

    # ── Web search fallback ───────────────────────────────────

    def _web_search_fallback(self, query: str) -> list[RetrievedChunk]:
        """
        Try Tavily first (better quality), then DuckDuckGo (no API key).
        Returns synthetic RetrievedChunk objects from web results.
        """
        chunks = self._tavily_search(query) or self._duckduckgo_search(query)
        if chunks:
            logger.info("CRAG web fallback: %d results for '%s'", len(chunks), query[:50])
        return chunks

    def _tavily_search(self, query: str) -> list[RetrievedChunk]:
        """Top-3 Tavily results as chunks; [] if unavailable, unconfigured or failing."""
        try:
            from tavily import TavilyClient  # type: ignore
        except ImportError:
            logger.debug("tavily package not installed; skipping Tavily search.")
            return []
        try:
            api_key = get_settings().tavily_api_key
            if not api_key:
                return []
            results = TavilyClient(api_key=api_key).search(query, max_results=3)
            return [
                self._web_result_to_chunk(r.get("content", ""), r.get("url", ""), r.get("title", "Web"))
                for r in results.get("results", [])
                if r.get("content")
            ]
        except Exception as exc:
            # Best-effort fallback: never fail the request, but leave a trace.
            logger.warning("Tavily search failed: %s", exc)
            return []

    def _duckduckgo_search(self, query: str) -> list[RetrievedChunk]:
        """Top-3 DuckDuckGo results as chunks; [] if unavailable or failing."""
        try:
            from duckduckgo_search import DDGS  # type: ignore
        except ImportError:
            logger.debug("duckduckgo_search package not installed; skipping DuckDuckGo search.")
            return []
        try:
            with DDGS() as ddgs:
                return [
                    self._web_result_to_chunk(
                        r.get("body", ""), r.get("href", ""), r.get("title", "Web")
                    )
                    for r in ddgs.text(query, max_results=3)
                    if r.get("body")  # skip results with no text, as the Tavily path does
                ]
        except Exception as exc:
            logger.warning("DuckDuckGo search failed: %s", exc)
            return []

    @staticmethod
    def _web_result_to_chunk(text: str, url: str, title: str) -> RetrievedChunk:
        """Wrap one web result as a synthetic chunk (text capped at 1500 characters)."""
        import hashlib
        # Hash the URL for a stable id; fall back to the text so that results
        # without a URL do not all share one id.
        cid = hashlib.sha256((url or text).encode()).hexdigest()[:16]
        return RetrievedChunk(
            chunk_id=cid,
            doc_id="web",
            source=url,
            title=title,
            text=text[:1500],
            parent_text=text[:1500],
            chunk_index=0,
            score=0.6,  # neutral score for web results
            retriever="web_search",
        )

    # ── Groq client ───────────────────────────────────────────

    def _get_llm(self):
        """Return the cached Groq client, creating it on first use.

        Raises:
            RuntimeError: if no Groq API key is configured.
        """
        if self._llm is None:
            cfg = get_settings()
            if not cfg.groq_api_key:
                raise RuntimeError("GROQ_API_KEY not set")
            from groq import Groq  # type: ignore
            self._llm = Groq(api_key=cfg.groq_api_key)
        return self._llm
ingestion/pipeline.py CHANGED
@@ -22,6 +22,7 @@ from ingestion.document_loader import Document, DocumentLoader
22
  from retrieval.embedder import Embedder
23
  from retrieval.dense import MilvusStore
24
  from retrieval.bm25 import BM25Retriever
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
@@ -43,12 +44,15 @@ class IngestionPipeline:
43
  embedder: Optional[Embedder] = None,
44
  store: Optional[MilvusStore] = None,
45
  bm25: Optional[BM25Retriever] = None,
 
 
46
  ) -> None:
47
  self._loader = loader or DocumentLoader()
48
  self._embedder = embedder or Embedder()
49
  self._chunker = chunker or SemanticChunker(embedder=self._embedder)
50
  self._store = store or MilvusStore(embedder=self._embedder)
51
  self._bm25 = bm25 or BM25Retriever()
 
52
 
53
  # ── Public ─────────────────────────────────────────────────
54
 
@@ -133,7 +137,16 @@ class IngestionPipeline:
133
  except Exception as exc:
134
  logger.error("BM25 indexing failed: %s", exc)
135
  stats["errors"].append({"source": "bm25_index", "error": str(exc)})
136
-
 
 
 
 
 
 
 
 
 
137
  elapsed = time.perf_counter() - t0
138
  logger.info(
139
  "Ingestion complete in %.1fs β€” %d docs, %d chunks stored.",
 
22
  from retrieval.embedder import Embedder
23
  from retrieval.dense import MilvusStore
24
  from retrieval.bm25 import BM25Retriever
25
+ from retrieval.graph_builder import KnowledgeGraphBuilder
26
 
27
  logger = logging.getLogger(__name__)
28
 
 
44
  embedder: Optional[Embedder] = None,
45
  store: Optional[MilvusStore] = None,
46
  bm25: Optional[BM25Retriever] = None,
47
+ graph: Optional[KnowledgeGraphBuilder] = None,
48
+
49
  ) -> None:
50
  self._loader = loader or DocumentLoader()
51
  self._embedder = embedder or Embedder()
52
  self._chunker = chunker or SemanticChunker(embedder=self._embedder)
53
  self._store = store or MilvusStore(embedder=self._embedder)
54
  self._bm25 = bm25 or BM25Retriever()
55
+ self._graph = graph or KnowledgeGraphBuilder()
56
 
57
  # ── Public ─────────────────────────────────────────────────
58
 
 
137
  except Exception as exc:
138
  logger.error("BM25 indexing failed: %s", exc)
139
  stats["errors"].append({"source": "bm25_index", "error": str(exc)})
140
+
141
+ # ── Build knowledge graph (NER + relations) β€” Phase 3 ──
142
+ try:
143
+ graph_stats = self._graph.process_chunks(all_chunks)
144
+ stats["graph_entities"] = graph_stats.get("entities", 0)
145
+ stats["graph_triples"] = graph_stats.get("triples", 0)
146
+ except Exception as exc:
147
+ logger.error("Graph extraction failed: %s", exc)
148
+ stats["errors"].append({"source": "graph_build", "error": str(exc)})
149
+
150
  elapsed = time.perf_counter() - t0
151
  logger.info(
152
  "Ingestion complete in %.1fs β€” %d docs, %d chunks stored.",
retrieval/cache.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Retrieval Cache (Redis, Phase 4)
3
+
4
+ What gets cached
5
+ ─────────────────
6
+ The output of the full retrieval pipeline — after RRF fusion and
6
+ cross-encoder reranking — is serialised and stored in Redis with a
8
+ configurable TTL (default 1 hour).
9
+
10
+ Cache key: "cortex:retrieval:" + first 24 hex chars of SHA-256(f"{query.lower().strip()}:{top_k}")
11
+ This means the same query with different capitalisation or trailing
12
+ spaces hits the same cache entry, which is almost always correct for RAG.
13
+
14
+ What does NOT get cached
15
+ ─────────────────────────
16
+ CRAG evaluation and generation are NOT cached. The CRAG grade depends
17
+ on the current state of the knowledge base (which changes after ingestion),
18
+ and generation is fast enough (streaming) that caching it adds complexity
19
+ without meaningful latency savings.
20
+
21
+ Graceful degradation
22
+ ─────────────────────
23
+ If Redis is unreachable on startup, the cache silently disables itself
24
+ and logs a warning. Every query falls through to the live retrieval
25
+ pipeline unchanged. No exceptions surface to the user.
26
+
27
+ This means you can develop without Redis running locally and only enable
28
+ it in production (Railway, Render) where Redis add-ons are available.
29
+ """
30
+ from __future__ import annotations
31
+
32
+ import hashlib
33
+ import json
34
+ import logging
35
+ from typing import Optional
36
+
37
+ from retrieval.dense import RetrievedChunk
38
+ from retrieval.orchestrator import MultiStrategyRetriever, RetrievalResult
39
+ from retrieval.router import QueryIntent, RoutingDecision
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
def _make_cache_key(query: str, top_k: int) -> str:
    """Build the Redis key for a (query, top_k) pair.

    The query is lower-cased and stripped first, so case and surrounding
    whitespace do not produce distinct cache entries.
    """
    normalised = query.lower().strip()
    digest = hashlib.sha256(f"{normalised}:{top_k}".encode()).hexdigest()
    return f"cortex:retrieval:{digest[:24]}"
47
+
48
+
49
def _serialise_result(result: RetrievalResult) -> str:
    """Encode a RetrievalResult (chunks, routing decision, hit counts) as a JSON string."""
    chunk_fields = (
        "chunk_id", "doc_id", "source", "title", "text",
        "parent_text", "chunk_index", "score", "retriever",
    )
    chunk_dicts = [
        {name: getattr(chunk, name) for name in chunk_fields}
        for chunk in result.chunks
    ]

    routing = result.decision
    decision_dict = {
        "intent": routing.intent.value,  # enum member stored by its value
        "strategies": routing.strategies,
        "confidence": routing.confidence,
        "reasoning": routing.reasoning,
    }

    payload = {
        "chunks": chunk_dicts,
        "decision": decision_dict,
        "retriever_hits": result.retriever_hits,
    }
    return json.dumps(payload)
74
+
75
+
76
def _deserialise_result(raw: str) -> RetrievalResult:
    """Rebuild a RetrievalResult from the JSON produced by ``_serialise_result``.

    Raises ``KeyError`` if a required chunk or decision field is missing;
    ``retriever_hits`` is optional and defaults to an empty dict.
    """
    payload = json.loads(raw)

    chunk_fields = (
        "chunk_id", "doc_id", "source", "title", "text",
        "parent_text", "chunk_index", "score", "retriever",
    )
    restored_chunks = []
    for item in payload["chunks"]:
        restored_chunks.append(
            RetrievedChunk(**{name: item[name] for name in chunk_fields})
        )

    routing = payload["decision"]
    restored_decision = RoutingDecision(
        intent=QueryIntent(routing["intent"]),  # value → enum member
        strategies=routing["strategies"],
        confidence=routing["confidence"],
        reasoning=routing["reasoning"],
    )

    return RetrievalResult(
        chunks=restored_chunks,
        decision=restored_decision,
        retriever_hits=payload.get("retriever_hits", {}),
    )
108
+
109
+
110
class CachedRetriever:
    """
    Drop-in wrapper around MultiStrategyRetriever that adds Redis caching.

    If Redis cannot be reached at construction time the cache is disabled and
    every call goes straight to the wrapped retriever.

    Usage (replaces MultiStrategyRetriever in api/main.py):
        retriever = CachedRetriever(MultiStrategyRetriever(...))
        result = retriever.retrieve(query)
        print(retriever.cache_stats())  # {"hits": 3, "misses": 7, "enabled": True}
    """

    _KEY_PATTERN = "cortex:retrieval:*"
    _SCAN_BATCH = 500

    def __init__(
        self,
        inner: MultiStrategyRetriever,
        ttl_seconds: Optional[int] = None,
    ) -> None:
        self._inner = inner
        self._redis = self._connect_redis()  # None when the cache is disabled
        # None or 0 falls back to the configured default TTL.
        self._ttl = ttl_seconds or self._default_ttl()
        self._hits = 0
        self._misses = 0

    # ── Public API (matches MultiStrategyRetriever interface) ──

    def retrieve(
        self,
        query: str,
        top_k_candidates: Optional[int] = None,
        final_top_k: Optional[int] = None,
    ) -> RetrievalResult:
        """
        Retrieve with cache. Falls through to live retrieval on miss or error.

        The returned result carries ``from_cache`` (True on a cache hit). Empty
        results are never written to the cache.
        NOTE(review): the cache key covers the query and final top-k only, so
        calls that differ just in ``top_k_candidates`` share an entry — confirm
        this is intended.
        """
        from config import get_settings
        cfg = get_settings()
        k = final_top_k or cfg.final_top_k
        key = _make_cache_key(query, k)

        # ── Cache lookup ───────────────────────────────────────
        if self._redis:
            try:
                cached = self._redis.get(key)
                if cached:
                    self._hits += 1
                    logger.debug("Cache HIT for query: %s…", query[:40])
                    result = _deserialise_result(cached)
                    result.from_cache = True
                    return result
            except Exception as exc:
                logger.warning("Redis GET failed: %s — falling through.", exc)

        # ── Cache miss: live retrieval ─────────────────────────
        self._misses += 1
        logger.debug("Cache MISS for query: %s…", query[:40])
        result = self._inner.retrieve(query, top_k_candidates, final_top_k)
        result.from_cache = False

        # ── Write to cache ─────────────────────────────────────
        if self._redis and not result.empty:
            try:
                self._redis.setex(key, self._ttl, _serialise_result(result))
            except Exception as exc:
                logger.warning("Redis SET failed: %s", exc)

        return result

    def invalidate(self, query: str, top_k: int) -> bool:
        """Manually invalidate a cache entry (e.g. after re-ingestion).

        Returns True only if an entry existed and was deleted.
        """
        if not self._redis:
            return False
        try:
            return bool(self._redis.delete(_make_cache_key(query, top_k)))
        except Exception as exc:
            logger.warning("Redis DELETE failed: %s", exc)
            return False

    def flush_all(self) -> int:
        """Delete all Cortex cache keys. Returns count deleted.

        Keys are walked incrementally with SCAN and deleted in batches rather
        than fetched with a single blocking KEYS call. If Redis fails part-way
        through, the count of keys deleted so far is returned.
        """
        if not self._redis:
            return 0
        deleted = 0
        batch: list = []
        try:
            for key in self._redis.scan_iter(match=self._KEY_PATTERN, count=self._SCAN_BATCH):
                batch.append(key)
                if len(batch) >= self._SCAN_BATCH:
                    deleted += self._redis.delete(*batch)
                    batch = []
            if batch:
                deleted += self._redis.delete(*batch)
        except Exception as exc:
            logger.warning("Redis cache flush failed: %s", exc)
        return deleted

    def cache_stats(self) -> dict:
        """Hit/miss counters for this wrapper instance, plus TTL and enabled flag."""
        total = self._hits + self._misses
        return {
            "enabled": self._redis is not None,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(self._hits / total, 3) if total else 0.0,
            "ttl_s": self._ttl,
        }

    # ── Pass-through for orchestrator methods ──────────────────

    def index_chunks(self, chunks: list) -> int:
        # NOTE(review): indexing does not touch the cache, so entries written
        # before this call stay served until their TTL expires — call
        # flush_all() after ingestion if stale results matter.
        return self._inner.index_chunks(chunks)

    def build_graph(self, chunks: list) -> dict:
        return self._inner.build_graph(chunks)

    @property
    def graph_builder(self):
        return self._inner.graph_builder

    # ── Redis connection ───────────────────────────────────────

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return *url* with any ``user:password@`` part masked, for safe logging."""
        from urllib.parse import urlsplit
        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable redis url>"
        if "@" not in parts.netloc:
            return url
        host = parts.netloc.rsplit("@", 1)[1]
        return parts._replace(netloc=f"***@{host}").geturl()

    @staticmethod
    def _connect_redis():
        """Return a connected Redis client, or None if Redis is unusable.

        Never raises: a missing ``redis`` package or an unreachable server
        simply disables the cache.
        """
        from config import get_settings
        cfg = get_settings()
        url = getattr(cfg, "redis_url", "redis://localhost:6379")
        try:
            import redis  # type: ignore
            client = redis.from_url(url, socket_connect_timeout=2, decode_responses=True)
            client.ping()
            # The URL may embed credentials; never log it verbatim.
            logger.info("Redis cache connected at %s", CachedRetriever._redact_url(url))
            return client
        except ImportError:
            logger.info("redis-py not installed — cache disabled. pip install redis")
            return None
        except Exception as exc:
            logger.warning("Redis unavailable (%s) — cache disabled.", exc)
            return None

    @staticmethod
    def _default_ttl() -> int:
        """Configured cache TTL in seconds (3600 when the setting is absent)."""
        from config import get_settings
        return getattr(get_settings(), "cache_ttl_seconds", 3600)
retrieval/graph_builder.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Knowledge Graph Builder (Phase 3)
3
+
4
+ What this does
5
+ ──────────────
6
+ During ingestion, every chunk is processed to extract:
7
+ 1. Named entities (spaCy NER: PERSON, ORG, WORK_OF_ART, PRODUCT, …)
8
+ 2. Relations (few-shot LLM: subject → predicate → object triples)
9
+
10
+ These are assembled into a NetworkX undirected graph where:
11
+ - Nodes = entities (label + type + first-seen source)
12
+ - Edges = relations (predicate label + list of source chunk_ids)
13
+
14
+ Each node also carries a list of chunk_ids it appeared in, so the
15
+ graph retriever can map entity β†’ chunks without an extra lookup.
16
+
17
+ The graph is persisted as a JSON file (graphs are small — a 100-doc
18
+ corpus typically has <10k nodes). On reload the full graph is
19
+ reconstructed in seconds from the JSON.
20
+
21
+ ──────────────
22
+ (Phase 3, refactored)
23
+
24
+ The builder is now responsible ONLY for:
25
+ - spaCy NER (entities are always extracted the same way)
26
+ - Assembling triples into a NetworkX graph
27
+ - Persisting / loading the graph
28
+
29
+ Relation extraction is delegated to a RelationExtractor strategy:
30
+ - REBELExtractor (default) — local model, no API calls
31
+ - LLMExtractor — Groq, free-form predicates
32
+
33
+ Switch via .env:
34
+ GRAPH_EXTRACTOR=rebel # default, recommended
35
+ GRAPH_EXTRACTOR=llm # original method
36
+
37
+ Or pass explicitly:
38
+ builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
39
+
40
+ """
41
+ from __future__ import annotations
42
+
43
+ import json
44
+ import logging
45
+
46
+ from pathlib import Path
47
+ from typing import Optional
48
+
49
+ import networkx as nx
50
+
51
+ from ingestion.chunker import Chunk
52
+ from retrieval.relation_extractors import (
53
+ RelationExtractor,
54
+ Triple,
55
+ build_extractor,
56
+ )
57
+
58
+ logger = logging.getLogger(__name__)
59
+
60
+ _DEFAULT_GRAPH_PATH = Path("data/knowledge_graph.json")
61
+
62
+ # spaCy entity types we care about for RAG
63
+ _ENTITY_TYPES = {
64
+ "PERSON", "ORG", "GPE", "PRODUCT", "WORK_OF_ART",
65
+ "EVENT", "LAW", "NORP", "FAC", "LOC",
66
+ }
67
+
68
class KnowledgeGraphBuilder:
    """
    Builds and maintains the knowledge graph.

    Nodes are entity labels carrying ``entity_type``, ``chunk_ids`` and
    ``source`` attributes; edges are relations carrying ``predicates``,
    ``chunk_ids``, ``source`` and ``extractor``. Entities always come from
    spaCy NER; relation triples come from the injected RelationExtractor.

    Usage (at ingestion time):
        # Extractor chosen by build_extractor()
        builder = KnowledgeGraphBuilder()
        builder.process_chunks(chunks)

        # Explicit extractor
        from retrieval.relation_extractors import LLMExtractor
        builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
        builder.process_chunks(chunks)

    Usage (at query time):
        builder = KnowledgeGraphBuilder()
        G = builder.graph    # loaded from disk automatically
    """

    def __init__(
        self,
        graph_path: str | Path = _DEFAULT_GRAPH_PATH,
        extractor: Optional[RelationExtractor] = None,
    ) -> None:
        """
        Args:
            graph_path: JSON file the graph is loaded from and saved to.
            extractor:  Relation extraction strategy. When omitted,
                        build_extractor() picks one from configuration.
        """
        self._path = Path(graph_path)
        self._graph: nx.Graph = nx.Graph()
        # If no extractor is injected, build_extractor() selects one from config.
        self._extractor: RelationExtractor = (
            extractor if extractor is not None else build_extractor()
        )
        self._nlp = None  # spaCy pipeline, loaded lazily on first NER call
        self._load_if_exists()
        logger.info(
            "KnowledgeGraphBuilder ready (extractor=%s)", self._extractor.name
        )

    # ── Public API ─────────────────────────────────────────────

    @property
    def graph(self) -> nx.Graph:
        """The live graph object (not a copy)."""
        return self._graph

    @property
    def extractor_name(self) -> str:
        """Short name of the configured relation extractor."""
        return self._extractor.name

    def process_chunks(self, chunks: list[Chunk]) -> dict:
        """
        Extract entities and relations from chunks; update and save graph.

        Relations are extracted for the whole list via the extractor's
        extract_batch(); entities are extracted per chunk with spaCy.
        A failure on one chunk is counted in ``errors`` and does not stop
        the remaining chunks.

        Returns:
            Stats dict with keys ``chunks``, ``entities``, ``triples``, ``errors``.
        """
        if not chunks:
            return {"chunks": 0, "entities": 0, "triples": 0, "errors": 0}

        stats = {"chunks": len(chunks), "entities": 0, "triples": 0, "errors": 0}
        triple_map = self._extract_triples(chunks, stats)

        # ── Entity extraction + graph update ───────────────────
        for chunk in chunks:
            try:
                entities = self._extract_entities(chunk.text)
                triples = triple_map.get(chunk.chunk_id, [])

                self._add_entities_to_graph(entities, chunk)
                self._add_triples_to_graph(triples)

                stats["entities"] += len(entities)
                stats["triples"] += len(triples)

            except Exception as exc:
                logger.warning("Graph update failed for chunk %s: %s", chunk.chunk_id, exc)
                stats["errors"] += 1

        self.save()
        logger.info(
            "Graph updated via %s: +%d entities, +%d triples (nodes=%d, edges=%d)",
            self._extractor.name,
            stats["entities"], stats["triples"],
            self._graph.number_of_nodes(), self._graph.number_of_edges(),
        )
        return stats

    def save(self) -> None:
        """
        Persist the graph as node-link JSON.

        The data is written to a sibling temp file and then renamed over the
        target. Writing in place would leave a truncated file if the process
        died mid-write; the loader would then reject it and the next save
        would overwrite the whole graph with a near-empty one.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = nx.node_link_data(self._graph)
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=2)
        tmp_path.replace(self._path)
        logger.debug("Graph saved to %s", self._path)

    def stats(self) -> dict:
        """Return node/edge counts, the extractor name and the graph file path."""
        return {
            "nodes": self._graph.number_of_nodes(),
            "edges": self._graph.number_of_edges(),
            "extractor": self._extractor.name,
            "graph_path": str(self._path),
        }

    # ── Relation extraction (delegated to the extractor) ───────

    def _extract_triples(self, chunks: list[Chunk], stats: dict) -> dict:
        """Return {chunk_id: [Triple, ...]}; on batch failure retry chunk by chunk."""
        try:
            return self._extractor.extract_batch(chunks)
        except Exception as exc:
            logger.error("Batch extraction failed, falling back to sequential: %s", exc)

        triple_map: dict = {}
        for chunk in chunks:
            try:
                triple_map[chunk.chunk_id] = self._extractor.extract(chunk)
            except Exception as e:
                logger.warning("Extraction failed for %s: %s", chunk.chunk_id, e)
                triple_map[chunk.chunk_id] = []
                stats["errors"] += 1
        return triple_map

    # ── Entity extraction (always spaCy — same for both methods) ─

    def _extract_entities(self, text: str) -> list[tuple[str, str]]:
        """
        Run spaCy NER over at most the first 10,000 characters of ``text``.

        Returns unique ``(label, entity_type)`` pairs in first-seen order.
        Labels are stripped and title-cased; labels shorter than two
        characters and entity types outside _ENTITY_TYPES are dropped.
        """
        nlp = self._get_nlp()
        doc = nlp(text[:10_000])

        seen: set[str] = set()
        entities: list[tuple[str, str]] = []
        for ent in doc.ents:
            if ent.label_ not in _ENTITY_TYPES:
                continue
            normalised = ent.text.strip().title()
            if normalised in seen or len(normalised) < 2:
                continue
            seen.add(normalised)
            entities.append((normalised, ent.label_))
        return entities

    # ── Graph construction (shared by both methods) ────────────

    def _add_entities_to_graph(
        self, entities: list[tuple[str, str]], chunk: Chunk
    ) -> None:
        """Add entity nodes, or record ``chunk`` on nodes that already exist."""
        for label, etype in entities:
            if not self._graph.has_node(label):
                self._graph.add_node(
                    label,
                    entity_type=etype,
                    chunk_ids=[chunk.chunk_id],
                    source=chunk.source,
                )
                continue

            attrs = self._graph.nodes[label]
            chunk_ids = attrs.setdefault("chunk_ids", [])
            if chunk.chunk_id not in chunk_ids:
                chunk_ids.append(chunk.chunk_id)
            # A node first created from a triple is typed "UNKNOWN"; once NER
            # recognises the same label, keep the real type instead.
            if attrs.get("entity_type", "UNKNOWN") == "UNKNOWN":
                attrs["entity_type"] = etype

    def _add_triples_to_graph(self, triples: list[Triple]) -> None:
        """Add one edge per (subject, object) pair, merging predicates and chunk ids."""
        for triple in triples:
            for node in (triple.subject, triple.object):
                if not self._graph.has_node(node):
                    # Endpoint not seen by NER: placeholder node with no chunks.
                    self._graph.add_node(
                        node,
                        entity_type="UNKNOWN",
                        chunk_ids=[],
                        source=triple.source,
                        extractor=triple.extractor,
                    )

            if self._graph.has_edge(triple.subject, triple.object):
                edge = self._graph[triple.subject][triple.object]
                predicates = edge.setdefault("predicates", [])
                chunk_ids = edge.setdefault("chunk_ids", [])
                if triple.predicate not in predicates:
                    predicates.append(triple.predicate)
                if triple.chunk_id not in chunk_ids:
                    chunk_ids.append(triple.chunk_id)
            else:
                self._graph.add_edge(
                    triple.subject, triple.object,
                    predicates=[triple.predicate],
                    chunk_ids=[triple.chunk_id],
                    source=triple.source,
                    extractor=triple.extractor,
                )

    # ── Persistence ───────────────────────────────────────────

    def _load_if_exists(self) -> None:
        """Load the graph from disk if the file exists; keep the empty graph on failure."""
        if not self._path.exists():
            return
        try:
            with open(self._path, encoding="utf-8") as fh:
                data = json.load(fh)
            self._graph = nx.node_link_graph(data)
            logger.info(
                "Knowledge graph loaded: %d nodes, %d edges",
                self._graph.number_of_nodes(),
                self._graph.number_of_edges(),
            )
        except Exception as exc:
            logger.warning("Failed to load graph (%s) — starting fresh.", exc)

    # ── spaCy ─────────────────────────────────────────────────

    def _get_nlp(self):
        """
        Return the cached spaCy pipeline, loading ``en_core_web_sm`` on first use.

        Raises:
            RuntimeError: spaCy is not installed or the model is not downloaded.
        """
        if self._nlp is None:
            try:
                import spacy  # type: ignore
            except ImportError as exc:
                raise RuntimeError("Install spacy: pip install spacy") from exc
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError(
                    "Run: python -m spacy download en_core_web_sm"
                ) from exc
        return self._nlp
retrieval/graph_retriever.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Graph Retriever (Phase 3)
3
+
4
+ How multi-hop retrieval works
5
+ ──────────────────────────────
6
+ Standard dense retrieval can answer: "What is attention?"
7
+ It cannot answer: "Who wrote the attention paper, and what did they later
8
+ build that addresses memory bottlenecks in inference?"
9
+
10
+ That question requires:
11
+ Step 1: Find entity "Attention Is All You Need" in the graph
12
+ Step 2: Follow "authored_by" edges → Vaswani, Shazeer, Parmar, …
13
+ Step 3: Follow those author nodes' other edges β†’
14
+ Shazeer: "introduced" β†’ "Multi-Query Attention"
15
+ Leviathan: "developed" β†’ "Speculative Decoding"
16
+ Step 4: Collect all chunk_ids linked to visited nodes
17
+ Step 5: Fetch those chunks from Milvus β†’ return to RRF pool
18
+
19
+ The BFS depth (default: 2 hops) is the key parameter. 1 hop = only
20
+ direct neighbours; 2 hops = neighbours of neighbours. 3+ hops tends to
21
+ explode in scope and include irrelevant context.
22
+
23
+ Entity matching
24
+ ───────────────
25
+ The query "Who developed PagedAttention?" must match graph nodes like
26
+ "Paged Attention" or "PagedAttention". We do:
27
+ 1. Exact match (case-insensitive)
28
+ 2. Partial match (query entity substring of node label)
29
+ 3. spaCy NER on the query to extract candidate entity strings first
30
+ """
31
+ from __future__ import annotations
32
+
33
+ import logging
34
+ from typing import Optional
35
+
36
+ from retrieval.dense import MilvusStore, RetrievedChunk
37
+ from retrieval.graph_builder import KnowledgeGraphBuilder
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
class GraphRetriever:
    """
    Retrieves chunks via knowledge graph traversal.

    Returns RetrievedChunk objects fetched from Milvus, so they carry
    the same structure as dense/BM25 results and can flow into RRF.
    """

    def __init__(
        self,
        graph_builder: Optional[KnowledgeGraphBuilder] = None,
        store: Optional[MilvusStore] = None,
        max_hops: int = 2,
    ) -> None:
        """
        Args:
            graph_builder: Provides the graph; a default builder is created if omitted.
            store:         Milvus store used to fetch chunk rows; created if omitted.
            max_hops:      Maximum BFS depth from each anchor node.
        """
        self._builder = graph_builder or KnowledgeGraphBuilder()
        self._store = store or MilvusStore()
        self._max_hops = max_hops
        self._nlp = None  # spaCy pipeline, loaded lazily

    # ── Public API ─────────────────────────────────────────────

    def search(self, query: str, top_k: int = 15) -> list[RetrievedChunk]:
        """
        Graph traversal retrieval for a given query.

        Pipeline:
          1. Extract named entities from query (spaCy)
          2. Anchor each entity to matching graph nodes (exact, then substring)
          3. BFS up to max_hops from anchors
          4. Collect chunk_ids from all visited nodes + their edges, scored
             by hop distance from the nearest anchor
          5. Keep the top_k chunk_ids and fetch them from Milvus

        Returns an empty list when the graph is empty, the query has no
        entities, or nothing in the graph matches.
        """
        G = self._builder.graph
        if G.number_of_nodes() == 0:
            logger.debug("Graph is empty — skipping graph retrieval.")
            return []

        # 1. Extract query entities
        query_entities = self._extract_query_entities(query)
        if not query_entities:
            logger.debug("No named entities in query — skipping graph retrieval.")
            return []

        logger.debug("Graph query entities: %s", query_entities)

        # 2. Find anchor nodes
        anchor_nodes = self._find_anchor_nodes(query_entities, G)
        if not anchor_nodes:
            logger.debug("No anchor nodes found in graph.")
            return []

        logger.debug("Anchor nodes: %s", anchor_nodes)

        # 3 + 4. BFS traversal → collect chunk_ids
        chunk_id_scores: dict[str, float] = {}
        visited_nodes: set[str] = set()

        for anchor in anchor_nodes:
            self._bfs_collect(
                G, anchor, self._max_hops,
                chunk_id_scores, visited_nodes
            )

        if not chunk_id_scores:
            return []

        # 5. Sort chunk_ids by score and fetch from Milvus
        sorted_ids = sorted(
            chunk_id_scores, key=lambda cid: chunk_id_scores[cid], reverse=True
        )[:top_k]

        chunks = self._fetch_chunks_from_milvus(sorted_ids, chunk_id_scores)
        logger.info(
            "Graph retriever: %d anchors, %d nodes visited, %d chunks returned",
            len(anchor_nodes), len(visited_nodes), len(chunks)
        )
        return chunks

    # ── BFS traversal ─────────────────────────────────────────

    def _bfs_collect(
        self,
        G,
        start_node: str,
        max_hops: int,
        chunk_scores: dict[str, float],
        visited: set[str],
    ) -> None:
        """
        BFS from start_node up to max_hops, updating the two accumulators in place.

        A chunk attached to a node at depth d scores 1 / 2**d (1.0 at the
        anchor, 0.5 one hop away, 0.25 two hops away). A chunk attached to
        an edge leaving that node scores 0.8 of the node score. Each chunk
        keeps the highest score seen across all anchors.

        Args:
            chunk_scores: chunk_id -> best score so far (mutated).
            visited:      nodes reached from any anchor (mutated).
        """
        from collections import deque

        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque([(start_node, 0)])
        local_visited: set[str] = set()

        while queue:
            node, depth = queue.popleft()
            if node in local_visited or depth > max_hops:
                continue
            local_visited.add(node)
            visited.add(node)

            hop_score = 1.0 / (2 ** depth)

            # Chunks the entity itself appeared in
            for cid in G.nodes[node].get("chunk_ids", []):
                chunk_scores[cid] = max(chunk_scores.get(cid, 0.0), hop_score)

            # Chunks that stated a relation touching this entity
            edge_score = hop_score * 0.8
            for neighbour in G.neighbors(node):
                for cid in G[node][neighbour].get("chunk_ids", []):
                    chunk_scores[cid] = max(chunk_scores.get(cid, 0.0), edge_score)

                # Already-expanded nodes would be skipped on dequeue anyway.
                if depth < max_hops and neighbour not in local_visited:
                    queue.append((neighbour, depth + 1))

    # ── Entity extraction ──────────────────────────────────────

    def _extract_query_entities(self, query: str) -> list[str]:
        """
        Extract named entities from the query using spaCy NER.

        Falls back to noun chunks (longer than 3 characters) if NER finds
        nothing. Returns title-cased strings, or [] if spaCy fails.
        """
        try:
            nlp = self._get_nlp()
            doc = nlp(query)
            entities = [ent.text.strip().title() for ent in doc.ents if len(ent.text.strip()) > 1]
            if not entities:
                # Fallback: try noun chunks (catches "attention mechanism", etc.)
                entities = [
                    chunk.text.strip().title()
                    for chunk in doc.noun_chunks
                    if len(chunk.text.strip()) > 3
                ]
            return entities
        except Exception as exc:
            logger.debug("Entity extraction failed: %s", exc)
            return []

    # ── Node matching ─────────────────────────────────────────

    @staticmethod
    def _find_anchor_nodes(query_entities: list[str], G) -> list[str]:
        """
        Find graph nodes that match query entities (case-insensitive).

        Priority: exact matches first (in query order), then substring
        matches in either direction for entities with no exact match.
        Each node appears at most once, so duplicates cannot use up the
        cap. Empty labels are ignored because the empty string is a
        substring of everything.

        Returns:
            At most 10 node labels.
        """
        lower_nodes = {n.lower(): n for n in G.nodes()}

        anchors: list[str] = []
        taken: set[str] = set()

        def _add(node: str) -> None:
            if node not in taken:
                taken.add(node)
                anchors.append(node)

        needs_partial: list[str] = []
        for qe in query_entities:
            qe_lower = qe.lower()
            if not qe_lower:
                continue
            if qe_lower in lower_nodes:
                _add(lower_nodes[qe_lower])
            else:
                needs_partial.append(qe_lower)

        for qe_lower in needs_partial:
            for node_lower, node in lower_nodes.items():
                if not node_lower:
                    continue
                if qe_lower in node_lower or node_lower in qe_lower:
                    _add(node)

        return anchors[:10]   # cap to avoid explosion on generic queries

    # ── Milvus fetch ──────────────────────────────────────────

    def _fetch_chunks_from_milvus(
        self,
        chunk_ids: list[str],
        scores: dict[str, float],
    ) -> list[RetrievedChunk]:
        """
        Fetch specific chunks from Milvus by chunk_id.

        Each returned chunk is tagged retriever="graph" and carries its
        graph score (0.1 if the id is missing from ``scores``). Results are
        sorted by score, highest first. Any Milvus error is logged and
        yields an empty list.
        """
        if not chunk_ids:
            return []

        import json

        try:
            # json.dumps renders ["a", "b"] and escapes quotes/backslashes, so
            # an unusual chunk_id cannot break or alter the filter expression.
            expr = f"chunk_id in {json.dumps(list(chunk_ids), ensure_ascii=False)}"

            coll = self._store._ensure_collection()
            results = coll.query(
                expr=expr,
                output_fields=["chunk_id", "doc_id", "source", "title",
                               "text", "parent_text", "chunk_index"],
                limit=len(chunk_ids),
            )

            chunks: list[RetrievedChunk] = []
            for row in results:
                cid = row["chunk_id"]
                chunks.append(RetrievedChunk(
                    chunk_id=cid,
                    doc_id=row["doc_id"],
                    source=row["source"],
                    title=row["title"],
                    text=row["text"],
                    parent_text=row["parent_text"],
                    chunk_index=row["chunk_index"],
                    score=scores.get(cid, 0.1),
                    retriever="graph",
                ))
            return sorted(chunks, key=lambda c: c.score, reverse=True)

        except Exception as exc:
            logger.warning("Milvus fetch for graph chunks failed: %s", exc)
            return []

    # ── spaCy ─────────────────────────────────────────────────

    def _get_nlp(self):
        """
        Return the cached spaCy pipeline, loading ``en_core_web_sm`` on first use.

        Raises:
            RuntimeError: the spaCy model is not downloaded.
        """
        if self._nlp is None:
            import spacy  # type: ignore
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError(
                    "Download spaCy model: python -m spacy download en_core_web_sm"
                ) from exc
        return self._nlp
retrieval/orchestrator.py CHANGED
@@ -30,6 +30,8 @@ from retrieval.bm25 import BM25Retriever
30
  from retrieval.dense import MilvusStore, RetrievedChunk
31
  from retrieval.embedder import Embedder
32
  from retrieval.fusion import CrossEncoderReranker, RRFFusion
 
 
33
  from retrieval.router import QueryRouter, RoutingDecision
34
 
35
  logger = logging.getLogger(__name__)
@@ -51,6 +53,7 @@ class MultiStrategyRetriever:
51
  embedder: Optional[Embedder] = None,
52
  store: Optional[MilvusStore] = None,
53
  bm25: Optional[BM25Retriever] = None,
 
54
  router: Optional[QueryRouter] = None,
55
  fuser: Optional[RRFFusion] = None,
56
  reranker: Optional[CrossEncoderReranker] = None,
@@ -58,6 +61,10 @@ class MultiStrategyRetriever:
58
  self._embedder = embedder or Embedder()
59
  self._dense = store or MilvusStore(embedder=self._embedder)
60
  self._bm25 = bm25 or BM25Retriever()
 
 
 
 
61
  self._router = router or QueryRouter()
62
  self._fuser = fuser or RRFFusion()
63
  self._reranker = reranker or CrossEncoderReranker()
@@ -112,6 +119,17 @@ class MultiStrategyRetriever:
112
  Dense indexing is handled separately by MilvusStore.
113
  """
114
  return self._bm25.add_chunks(chunks)
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  # ── Private: parallel retrieval ───────────────────────────
117
 
@@ -128,7 +146,7 @@ class MultiStrategyRetriever:
128
  retriever_map = {
129
  "dense": lambda q, k: self._dense.search(q, top_k=k),
130
  "bm25": lambda q, k: self._bm25.search(q, top_k=k),
131
- # "graph" will be registered here in Phase 3
132
  }
133
 
134
  results: dict[str, list[RetrievedChunk]] = {}
 
30
  from retrieval.dense import MilvusStore, RetrievedChunk
31
  from retrieval.embedder import Embedder
32
  from retrieval.fusion import CrossEncoderReranker, RRFFusion
33
+ from retrieval.graph_builder import KnowledgeGraphBuilder
34
+ from retrieval.graph_retriever import GraphRetriever
35
  from retrieval.router import QueryRouter, RoutingDecision
36
 
37
  logger = logging.getLogger(__name__)
 
53
  embedder: Optional[Embedder] = None,
54
  store: Optional[MilvusStore] = None,
55
  bm25: Optional[BM25Retriever] = None,
56
+ graph_builder: Optional[KnowledgeGraphBuilder] = None,
57
  router: Optional[QueryRouter] = None,
58
  fuser: Optional[RRFFusion] = None,
59
  reranker: Optional[CrossEncoderReranker] = None,
 
61
  self._embedder = embedder or Embedder()
62
  self._dense = store or MilvusStore(embedder=self._embedder)
63
  self._bm25 = bm25 or BM25Retriever()
64
+ self._graph_builder = graph_builder or KnowledgeGraphBuilder()
65
+ self._graph = GraphRetriever(
66
+ graph_builder=self._graph_builder, store=self._dense
67
+ )
68
  self._router = router or QueryRouter()
69
  self._fuser = fuser or RRFFusion()
70
  self._reranker = reranker or CrossEncoderReranker()
 
119
  Dense indexing is handled separately by MilvusStore.
120
  """
121
  return self._bm25.add_chunks(chunks)
122
+
123
+ def build_graph(self, chunks: list) -> dict:
124
+ """
125
+ Extract entities + relations from chunks and update the knowledge graph.
126
+ Call from ingestion pipeline after dense + BM25 indexing.
127
+ """
128
+ return self._graph_builder.process_chunks(chunks)
129
+
130
+ @property
131
+ def graph_builder(self) -> KnowledgeGraphBuilder:
132
+ return self._graph_builder
133
 
134
  # ── Private: parallel retrieval ───────────────────────────
135
 
 
146
  retriever_map = {
147
  "dense": lambda q, k: self._dense.search(q, top_k=k),
148
  "bm25": lambda q, k: self._bm25.search(q, top_k=k),
149
+ "graph": lambda q, k: self._graph.search(q, top_k=k),
150
  }
151
 
152
  results: dict[str, list[RetrievedChunk]] = {}
retrieval/relation_extractors.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cortex RAG — Relation Extractors
3
+
4
+ Strategy pattern: both extractors share the same interface.
5
+ Switch between them with GRAPH_EXTRACTOR=rebel|llm in .env.
6
+
7
+ RelationExtractor (abstract)
8
+ ├── REBELExtractor — local model, no API calls, Wikidata predicates
9
+ └── LLMExtractor — Mistral/LLM, free-form predicates, rate-limited
10
+
11
+ KnowledgeGraphBuilder accepts either via dependency injection, or
12
+ auto-selects based on config.get_settings().graph_extractor.
13
+
14
+ Adding a new extractor in the future:
15
+ 1. Subclass RelationExtractor
16
+ 2. Implement extract(chunk) β†’ list[Triple]
17
+ 3. Register the name in _EXTRACTOR_REGISTRY at the bottom of this file
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import json
22
+ import logging
23
+ import re
24
+ import time
25
+ from abc import ABC, abstractmethod
26
+ from dataclasses import dataclass
27
+ from typing import Optional
28
+
29
+ from config import get_settings
30
+ from ingestion.chunker import Chunk
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # ── Shared dataclass ───────────────────────────────────────────
36
+
37
@dataclass
class Triple:
    """One (subject, predicate, object) relation extracted from a chunk."""

    subject: str      # head entity label
    predicate: str    # relation label linking subject to object
    object: str       # tail entity label
    chunk_id: str     # id of the chunk the triple was extracted from
    source: str       # ``source`` of the originating chunk
    extractor: str = "unknown"   # tracks which extractor produced this triple
45
+
46
+
47
+ # ── Abstract base ──────────────────────────────────────────────
48
+
49
class RelationExtractor(ABC):
    """
    Interface shared by every relation extraction strategy.

    A concrete extractor only has to provide extract() and name;
    extract_batch() has a sequential default that can be overridden.
    """

    @abstractmethod
    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Extract (subject, predicate, object) triples from a single chunk.
        Must never raise — return [] on any failure.
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier used in logging and triple.extractor field."""
        ...

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        """
        Extract triples for several chunks.

        The default simply runs extract() on each chunk in order;
        subclasses may override it with real batching (e.g. REBEL).

        Returns: {chunk_id: [Triple, ...]}
        """
        by_chunk: dict[str, list[Triple]] = {}
        for item in chunks:
            by_chunk[item.chunk_id] = self.extract(item)
        return by_chunk
78
+
79
+
80
+ # ── REBEL extractor ────────────────────────────────────────────
81
+
82
# REBEL relation types that map cleanly to RAG-useful edges.
# The full Wikidata set has 220 types; we keep the ~40 most useful.
# Entries are lower-case because the parser compares relation.lower()
# against this set; any relation not listed here is discarded.
_REBEL_KEEP_RELATIONS = {
    "author", "developer", "creator", "founded by", "owned by",
    "instance of", "subclass of", "part of", "has part",
    "country", "country of origin", "located in", "headquarters location",
    "employer", "member of", "affiliation", "educated at",
    "award received", "occupation", "field of work", "notable work",
    "based on", "followed by", "follows", "influenced by", "has edition",
    "product or material produced", "used by", "manufacturer",
    "publication date", "academic degree", "applies to jurisdiction",
    "published in", "platform", "programming language", "license",
}
95
+
96
+
97
class REBELExtractor(RelationExtractor):
    """
    Local relation extraction using REBEL (Babelscape/rebel-large).

    Model facts:
      - Seq2seq model loaded through transformers' AutoModelForSeq2SeqLM
      - Input:  raw text sentence(s)
      - Output: decoded triplet string, parsed into (head, type, tail)
      - No API calls; runs offline once the weights are cached locally

    Batching:
      extract_batch() tokenises and generates for up to ``rebel_batch_size``
      chunks per model call (config value, default 8). A batch that fails
      yields empty triple lists for its chunks; other batches still run.

    Predicate normalisation:
      REBEL outputs Wikidata relation labels (e.g. "country of origin").
      Only relations listed in _REBEL_KEEP_RELATIONS are kept; everything
      else is discarded to keep obscure predicates out of the graph.
    """

    _REBEL_MODEL       = "Babelscape/rebel-large"
    _MAX_INPUT_TOKENS  = 256   # REBEL was trained on short passages
    _MAX_OUTPUT_TOKENS = 512

    def __init__(self) -> None:
        # Tokenizer and model are loaded lazily by _load() on first use.
        self._tokenizer = None
        self._model     = None

    @property
    def name(self) -> str:
        return "rebel"

    # ── Public ──────────────────────────────────────────────────

    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Single-chunk extraction. Prefer extract_batch for speed.

        Honours the RelationExtractor contract: never raises. A model
        loading failure (e.g. transformers not installed) is logged and
        reported as no triples.
        """
        try:
            results = self.extract_batch([chunk])
        except Exception as exc:
            logger.warning("REBEL extraction failed for %s: %s", chunk.chunk_id, exc)
            return []
        return results.get(chunk.chunk_id, [])

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        """
        Batched extraction: chunks are sent to the model ``batch_size`` at a time.

        Returns:
            {chunk_id: [Triple, ...]} with an entry for every input chunk.
            Chunks belonging to a batch that failed map to [].

        Raises:
            RuntimeError: the model cannot be loaded (see _load()).
        """
        if not chunks:
            return {}

        tok, model = self._load()
        cfg = get_settings()
        batch_size = getattr(cfg, "rebel_batch_size", 8)
        # range() rejects a zero step and a negative one would skip every
        # chunk, so fall back to the default for unusable config values.
        if not isinstance(batch_size, int) or batch_size < 1:
            batch_size = 8

        # Character-level pre-truncation; the tokenizer then truncates to
        # _MAX_INPUT_TOKENS.
        texts = [c.text[:1200] for c in chunks]
        all_triples: dict[str, list[Triple]] = {c.chunk_id: [] for c in chunks}

        for batch_start in range(0, len(chunks), batch_size):
            batch_chunks = chunks[batch_start : batch_start + batch_size]
            batch_texts  = texts[batch_start : batch_start + batch_size]

            try:
                inputs = tok(
                    batch_texts,
                    max_length=self._MAX_INPUT_TOKENS,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                generated = model.generate(
                    **inputs,
                    max_length=self._MAX_OUTPUT_TOKENS,
                    num_beams=3,
                    early_stopping=True,
                )
                # Special tokens must be kept: <triplet>/<subj>/<obj> are the
                # delimiters the parser relies on.
                decoded = tok.batch_decode(generated, skip_special_tokens=False)

                for chunk, raw_output in zip(batch_chunks, decoded):
                    triples = self._parse_rebel_output(raw_output, chunk)
                    all_triples[chunk.chunk_id] = triples

            except Exception as exc:
                logger.warning("REBEL batch %d failed: %s", batch_start, exc)
                # Mark as empty rather than crashing the whole ingestion
                for chunk in batch_chunks:
                    all_triples[chunk.chunk_id] = []

        return all_triples

    # ── REBEL output parser ────────────────────────────────────

    def _parse_rebel_output(self, decoded: str, chunk: Chunk) -> list[Triple]:
        """
        Parse REBEL's special-token output format.

        REBEL outputs a string like:
          <triplet> Vaswani <subj> Attention Is All You Need <obj> author
          <triplet> Transformer <subj> NLP <obj> field of work

        One head may be followed by several tails:
          <triplet> HEAD <subj> TAIL1 <obj> REL1 <subj> TAIL2 <obj> REL2
        Every ``<subj> TAIL <obj> REL`` group after a head produces its own
        triple (not only the first group).

        Relations outside _REBEL_KEEP_RELATIONS are dropped. Subject and
        object are title-cased, the predicate lower-cased. At most 8 triples
        are returned per chunk.
        """
        triples: list[Triple] = []

        for raw in decoded.split("<triplet>"):
            # Drop sequence markers and padding left in by batch_decode.
            for tok_str in ("</s>", "<s>", "<pad>"):
                raw = raw.replace(tok_str, "")
            raw = raw.strip()
            if not raw or "<subj>" not in raw or "<obj>" not in raw:
                continue

            # Format: "SUBJECT <subj> OBJECT <obj> RELATION [<subj> OBJECT <obj> RELATION ...]"
            subject, *tail_groups = raw.split("<subj>")
            subject = subject.strip()
            if not subject:
                continue

            for group in tail_groups:
                pieces = group.split("<obj>")
                if len(pieces) < 2:
                    continue   # tail without a relation, nothing to record
                obj = pieces[0].strip()
                relation = pieces[1].strip()

                if not obj or not relation:
                    continue

                # Filter to useful relation types only
                if relation.lower() not in _REBEL_KEEP_RELATIONS:
                    continue

                triples.append(Triple(
                    subject=subject.title(),
                    predicate=relation.lower(),
                    object=obj.title(),
                    chunk_id=chunk.chunk_id,
                    source=chunk.source,
                    extractor=self.name,
                ))

        return triples[:8]   # cap per chunk

    # ── Model loading ──────────────────────────────────────────

    def _load(self):
        """
        Return ``(tokenizer, model)``, loading and caching them on first call.

        Raises:
            RuntimeError: the transformers package is not installed.
        """
        if self._tokenizer is None or self._model is None:
            try:
                from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
            except ImportError as exc:
                raise RuntimeError(
                    "Install transformers: pip install transformers"
                ) from exc

            logger.info("Loading REBEL model '%s' (first run downloads ~1.6GB)…", self._REBEL_MODEL)
            t0 = time.perf_counter()
            self._tokenizer = AutoTokenizer.from_pretrained(self._REBEL_MODEL)
            self._model     = AutoModelForSeq2SeqLM.from_pretrained(self._REBEL_MODEL)
            self._model.eval()   # inference mode: disables dropout
            logger.info("REBEL loaded in %.1fs", time.perf_counter() - t0)

        return self._tokenizer, self._model
271
+
272
+
273
+ # ── LLM extractor (original method, preserved) ─────────────────
274
+
275
# Prompt template for LLMExtractor. It is filled with str.format(text=...),
# which is why the literal JSON braces in the example are doubled ({{ }}).
_LLM_PROMPT = """\
Extract factual relationships from the passage below.
Return ONLY a JSON array of triples. Each triple is:
{{"subject": "...", "predicate": "...", "object": "..."}}

Rules:
- subject and object must be named entities (people, orgs, systems, concepts)
- predicate is a short verb phrase ("developed", "is based on", "introduced", "authored")
- Extract 0–5 triples maximum. If there are none, return []
- Return ONLY the JSON array, no explanation, no markdown

Passage:
{text}
"""
289
+
290
+
291
class LLMExtractor(RelationExtractor):
    """
    Relation extraction via a chat LLM (the original Phase 3 method).

    The backend is chosen by ``cfg.llm_server``: ``"ollama"`` talks to an
    Ollama server at ``cfg.ollama_host``; any other value uses the hosted
    Mistral API.

    Produces free-form, human-readable predicates ("introduced the concept of",
    "co-authored with") rather than the fixed Wikidata vocabulary that REBEL uses.

    Use this when:
    - You want rich, domain-specific predicate labels
    - Your corpus is small enough that rate limits aren't a problem
    - You want to fine-tune the extraction prompt for your specific domain

    Rate limiting:
    Set MISTRAL_RELATION_RPM in .env to cap requests-per-minute.
    Default is 0 (no cap). Mistral free tier allows ~30 RPM.
    NOTE(review): this class does no throttling itself; presumably the
    cap is enforced by the caller or base class — confirm.
    """

    def __init__(self) -> None:
        # Client is created lazily by _get_llm() so construction does no I/O.
        self._llm = None

    @property
    def name(self) -> str:
        return "llm"

    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Ask the configured LLM for triples in ``chunk``.

        Best-effort: any failure (missing key, network error, malformed
        reply) is logged and yields an empty list so one bad chunk never
        aborts ingestion.
        """
        try:
            client = self._get_llm()
            cfg = get_settings()
            messages = [{
                "role": "user",
                "content": _LLM_PROMPT.format(text=chunk.text[:2000]),
            }]

            if cfg.llm_server == "ollama":
                # The Ollama client exposes chat() as a plain method and
                # returns the message under response["message"] — not the
                # Mistral chat.complete()/choices[] shape.
                response = client.chat(
                    model=cfg.ollama_model,
                    messages=messages,
                    # Mirror the Mistral branch: deterministic, bounded output.
                    options={"temperature": 0.0, "num_predict": 512},
                )
                raw = response["message"]["content"] or "[]"
            else:
                response = client.chat.complete(
                    model=cfg.mistral_model,
                    messages=messages,
                    temperature=0.0,
                    max_tokens=512,
                )
                raw = response.choices[0].message.content or "[]"

            return self._parse(raw, chunk)

        except Exception as exc:
            # Warning rather than debug so a misconfigured backend is visible.
            logger.warning("LLM extractor failed for chunk %s: %s", chunk.chunk_id, exc)
            return []

    def _parse(self, raw: str, chunk: Chunk) -> list[Triple]:
        """
        Convert the model's reply into Triple records.

        Accepts a bare JSON array or one wrapped in a Markdown code fence.
        Returns [] for invalid JSON or a non-array payload; entries that are
        not objects or that lack any of subject/predicate/object are skipped.
        Only the first 5 entries are considered.
        """
        raw = raw.strip()
        if raw.startswith("```"):
            raw = re.sub(r"^```[a-z]*\n?", "", raw)
            raw = re.sub(r"\n?```$", "", raw)
        try:
            items = json.loads(raw)
        except json.JSONDecodeError:
            return []

        # The prompt asks for an array; an object or scalar cannot be sliced.
        if not isinstance(items, list):
            return []

        def _text(value) -> str:
            # JSON null must count as "missing", not become the string "None".
            return "" if value is None else str(value).strip()

        triples: list[Triple] = []
        for item in items[:5]:
            if not isinstance(item, dict):
                continue
            s = _text(item.get("subject"))
            p = _text(item.get("predicate"))
            o = _text(item.get("object"))
            if s and p and o:
                triples.append(Triple(
                    subject=s.title(),
                    predicate=p.lower(),
                    object=o.title(),
                    chunk_id=chunk.chunk_id,
                    source=chunk.source,
                    extractor=self.name,
                ))
        return triples

    def _get_llm(self):
        """
        Return the cached backend client, creating it on first use.

        Raises:
            RuntimeError: if the required client package is not installed,
                or MISTRAL_API_KEY is unset for the Mistral backend.
        """
        if self._llm is None:
            cfg = get_settings()

            if cfg.llm_server == "ollama":
                try:
                    from ollama import Client as ollama_client  # type: ignore
                except ImportError as exc:
                    raise RuntimeError(
                        "Install ollama client: pip install ollama"
                    ) from exc
                self._llm = ollama_client(host=cfg.ollama_host)
            else:
                if not cfg.mistral_api_key:
                    raise RuntimeError("MISTRAL_API_KEY not set")
                try:
                    from mistralai.client import Mistral  # type: ignore
                except ImportError as exc:
                    raise RuntimeError(
                        "Install mistralai client: pip install mistralai"
                    ) from exc
                self._llm = Mistral(api_key=cfg.mistral_api_key)
        return self._llm
392
+
393
+ # ── Entity density filter (Option 4) ──────────────────────────
394
+
395
class EntityDensityFilter(RelationExtractor):
    """
    Decorator that wraps any extractor and skips low-entity-density chunks.

    Rationale
    ─────────
    Chunks with 0–1 named entities rarely yield useful triples — a
    paragraph of methodology boilerplate has no entities to link.
    Chunks are scored by entity density (entities per 100 tokens); only
    chunks that meet the entity-count floor are considered, and of those
    only the densest ``top_fraction`` are passed to the wrapped extractor.
    The original author estimates this cuts extraction time by ~70% with
    negligible graph quality loss.

    How density is computed
    ───────────────────────
    density = (spaCy NER entity count) / (token count / 100)

    This normalises for chunk length — a 50-token chunk with 3 entities
    scores higher than a 500-token chunk with the same 3 entities.

    Usage
    ─────
    # Wrap REBEL with the configured/default fraction and floor:
    extractor = EntityDensityFilter(REBELExtractor())

    # Wrap LLM, keep top 20%, only chunks with ≥2 entities:
    extractor = EntityDensityFilter(
        LLMExtractor(),
        top_fraction=0.20,
        min_entity_count=2,
    )

    # Via config (wraps whatever GRAPH_EXTRACTOR is set to):
    GRAPH_EXTRACTOR=rebel-filtered # rebel + density filter
    GRAPH_EXTRACTOR=llm-filtered # llm + density filter
    """

    def __init__(
        self,
        inner: RelationExtractor,
        top_fraction: Optional[float] = None,
        min_entity_count: Optional[int] = None,
    ) -> None:
        """
        Args:
            inner: Extractor that receives the chunks which pass the filter.
            top_fraction: Fraction (0–1) of above-floor chunks to keep.
                None falls back to cfg.density_top_fraction, else 0.30.
            min_entity_count: Hard floor on entities per chunk. None falls
                back to cfg.density_min_entities, else 2.
        """
        cfg = get_settings()
        self._inner = inner
        # Compare against None, not truthiness: an explicit 0 / 0.0 is a
        # legitimate setting (e.g. min_entity_count=0 disables the floor).
        self._top_fraction = (
            top_fraction
            if top_fraction is not None
            else getattr(cfg, "density_top_fraction", 0.30)
        )
        self._min_entity_count = (
            min_entity_count
            if min_entity_count is not None
            else getattr(cfg, "density_min_entities", 2)
        )
        self._nlp = None  # spaCy pipeline, loaded lazily by _get_nlp()

    @property
    def name(self) -> str:
        return f"{self._inner.name}-filtered"

    # ── Public ──────────────────────────────────────────────────

    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Single-chunk extraction with density pre-check.

        Only the entity-count floor applies here; the top-fraction cut
        needs a batch to rank against.
        """
        if not self._passes_density_check([chunk]):
            logger.debug("Chunk %s skipped (low entity density)", chunk.chunk_id)
            return []
        return self._inner.extract(chunk)

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        """
        Filter chunks by density score, then delegate only the qualifying
        subset to the inner extractor's batch method.

        Steps:
        1. Score every chunk by entity density (spaCy only, no inner calls)
        2. Apply min_entity_count hard floor
        3. Keep top_fraction of remaining chunks by density score
           (at least one, if any chunk passed the floor)
        4. Pass filtered set to inner.extract_batch()
        5. Return merged result (skipped chunks → empty list)

        Returns:
            Mapping with one key per input chunk_id.
        """
        if not chunks:
            return {}

        # Score all chunks: list of (chunk, density, entity_count)
        scored = self._score_chunks(chunks)

        # Hard floor: drop chunks below minimum entity count
        above_floor = [(c, d, n) for c, d, n in scored if n >= self._min_entity_count]

        # Top-fraction cut: sort by density desc, keep top N%
        above_floor.sort(key=lambda x: x[1], reverse=True)
        cutoff = max(1, int(len(above_floor) * self._top_fraction))
        selected = [c for c, _, _ in above_floor[:cutoff]]
        skipped = len(chunks) - len(selected)

        if skipped:
            logger.info(
                "Density filter: %d/%d chunks selected (top %.0f%%, min_entities=%d)",
                len(selected), len(chunks),
                self._top_fraction * 100, self._min_entity_count,
            )

        # Nothing qualified: skip the inner extractor entirely
        if not selected:
            return {c.chunk_id: [] for c in chunks}

        inner_results = self._inner.extract_batch(selected)

        # Merge: unselected chunks get empty lists
        selected_ids = {c.chunk_id for c in selected}
        return {
            c.chunk_id: inner_results.get(c.chunk_id, []) if c.chunk_id in selected_ids else []
            for c in chunks
        }

    # ── Density scoring ────────────────────────────────────────

    def _score_chunks(
        self, chunks: list[Chunk]
    ) -> list[tuple[Chunk, float, int]]:
        """
        Returns list of (chunk, density_score, entity_count), in input order.
        density_score = entities per 100 spaCy tokens.

        Only the first 5000 characters of each chunk are analysed, and
        entities whose stripped text is a single character are ignored.
        """
        nlp = self._get_nlp()
        # nlp.pipe processes the texts as a stream and yields docs in order,
        # so zip() pairs each doc with its chunk.
        docs = nlp.pipe(chunk.text[:5000] for chunk in chunks)

        results = []
        for chunk, doc in zip(chunks, docs):
            entity_count = sum(1 for e in doc.ents if len(e.text.strip()) > 1)
            token_count = max(len(doc), 1)  # guard against empty text
            density = (entity_count / token_count) * 100
            results.append((chunk, density, entity_count))
        return results

    def _passes_density_check(self, chunks: list[Chunk]) -> bool:
        """Quick density check for extract(); only the first chunk is examined."""
        if not chunks:
            return False
        _, _, entity_count = self._score_chunks(chunks)[0]
        return entity_count >= self._min_entity_count

    # ── spaCy ──────────────────────────────────────────────────

    def _get_nlp(self):
        """
        Return the cached spaCy pipeline, loading en_core_web_sm on first use.

        Raises:
            RuntimeError: if the en_core_web_sm model is not installed.
        """
        if self._nlp is None:
            import spacy  # type: ignore
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError("Run: python -m spacy download en_core_web_sm") from exc
        return self._nlp
541
+
542
+ # ── Registry + factory ─────────────────────────────────────────
543
+
544
# Plain extractor names → classes that build_extractor() instantiates
# with no arguments.
_EXTRACTOR_REGISTRY: dict[str, type[RelationExtractor]] = {
    "rebel": REBELExtractor,
    "llm": LLMExtractor,
    # Density-filtered variants are constructed specially — see build_extractor()
}

# Names that trigger density-filter wrapping.
# Maps the filtered variant name → the _EXTRACTOR_REGISTRY key of the
# extractor that gets wrapped in EntityDensityFilter.
_FILTERED_VARIANTS = {
    "rebel-filtered": "rebel",
    "llm-filtered": "llm",
}
555
+
556
+
557
def build_extractor(name: Optional[str] = None) -> RelationExtractor:
    """
    Build the relation extractor selected by ``name`` or GRAPH_EXTRACTOR.

    Available values for GRAPH_EXTRACTOR:
    "rebel" — REBEL local model, no API calls (default)
    "llm" — chat LLM (Mistral API, or Ollama when llm_server is "ollama"),
            free-form predicates
    "rebel-filtered" — REBEL + entity density pre-filter (option 4)
    "llm-filtered" — LLM + entity density pre-filter (option 4)

    Explicit usage in code:
    extractor = build_extractor("rebel-filtered")

    # Or compose manually for full control:
    extractor = EntityDensityFilter(
        REBELExtractor(),
        top_fraction=0.25,
        min_entity_count=3,
    )

    Args:
        name: Extractor name; when omitted or empty, cfg.graph_extractor is
            used, falling back to "rebel". Matching ignores case and
            surrounding whitespace.

    Raises:
        ValueError: if the resolved name is not a known extractor.
    """
    cfg = get_settings()
    # A missing, None or empty config value falls back to the default, and
    # stray whitespace from an env file must not cause a spurious mismatch.
    configured = name or getattr(cfg, "graph_extractor", None) or "rebel"
    extractor_name = str(configured).strip().lower()

    # Density-filtered variant: build inner extractor then wrap it
    if extractor_name in _FILTERED_VARIANTS:
        inner_name = _FILTERED_VARIANTS[extractor_name]
        inner = _EXTRACTOR_REGISTRY[inner_name]()
        logger.info(
            "Using relation extractor: %s (inner=%s, top_fraction=%.0f%%, min_entities=%d)",
            extractor_name, inner_name,
            getattr(cfg, "density_top_fraction", 0.30) * 100,
            getattr(cfg, "density_min_entities", 2),
        )
        return EntityDensityFilter(inner)

    # Plain extractor
    cls = _EXTRACTOR_REGISTRY.get(extractor_name)
    if cls is None:
        available = list(_EXTRACTOR_REGISTRY.keys()) + list(_FILTERED_VARIANTS.keys())
        raise ValueError(
            f"Unknown extractor '{extractor_name}'. "
            f"Available: {available}. "
            f"Set GRAPH_EXTRACTOR in .env to one of these."
        )

    logger.info("Using relation extractor: %s", extractor_name)
    return cls()
ui/app.py CHANGED
@@ -12,12 +12,18 @@ import json
12
  import time
13
  from pathlib import Path
14
  from typing import Optional
 
 
 
 
15
 
16
  import requests
17
  import streamlit as st
18
 
19
  # ── Config ────────────────────────────────────────────────────
20
- API_BASE = "http://localhost:8000"
 
 
21
 
22
  st.set_page_config(
23
  page_title="Cortex RAG",
@@ -66,11 +72,24 @@ def _render_source_cards_raw(chunks: list[dict]):
66
  title = chunk.get("title", "Unknown")
67
  source = Path(chunk.get("source", "")).name
68
  snippet = chunk.get("text_snippet", "")[:160]
 
 
 
 
 
 
 
 
 
 
 
 
69
  st.markdown(f"""
70
  <div class="source-card">
71
  <strong>[{i+1}] {title}</strong>
72
  <span class="score-badge" style="float:right">{score_pct}%</span><br/>
73
- <small style="color:#6b7280">{source}</small>
 
74
  <div class="chunk-snippet">{snippet}…</div>
75
  </div>""", unsafe_allow_html=True)
76
 
@@ -101,7 +120,7 @@ st.markdown(
101
  )
102
  st.divider()
103
 
104
- tab_ask, tab_ingest, tab_system = st.tabs(["πŸ” Ask", "πŸ“₯ Ingest", "🩺 System"])
105
 
106
 
107
  # ─────────────────────────────────────────────────────────────
@@ -173,6 +192,19 @@ with tab_ask:
173
  sources_placeholder.markdown(payload.get("text", ""))
174
  status_placeholder.empty()
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  elif event_type == "done":
177
  answer_placeholder.markdown(answer_text)
178
  status_placeholder.empty()
@@ -262,7 +294,129 @@ with tab_ingest:
262
 
263
 
264
  # ─────────────────────────────────────────────────────────────
265
- # TAB 3 — SYSTEM HEALTH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  # ─────────────────────────────────────────────────────────────
267
  with tab_system:
268
  st.subheader("System health")
@@ -299,5 +453,13 @@ with tab_system:
299
  st.metric("Chunks indexed", stats.get("entity_count", "β€”"))
300
 
301
  st.divider()
 
 
 
 
 
 
 
 
302
  st.markdown("**Raw health response**")
303
  st.json(health)
 
12
  import time
13
  from pathlib import Path
14
  from typing import Optional
15
+ import sys
16
+
17
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
18
+ from config import get_settings
19
 
20
  import requests
21
  import streamlit as st
22
 
23
  # ── Config ────────────────────────────────────────────────────
24
+ cfg = get_settings()
25
+ API_BASE = f"http://{cfg.api_host}:{cfg.api_port}"
26
+ REDIS_URL = cfg.redis_url
27
 
28
  st.set_page_config(
29
  page_title="Cortex RAG",
 
72
  title = chunk.get("title", "Unknown")
73
  source = Path(chunk.get("source", "")).name
74
  snippet = chunk.get("text_snippet", "")[:160]
75
+ retriever = chunk.get("retriever", "dense")
76
+ retriever_colors = {
77
+ "dense": "#dbeafe:#1e40af",
78
+ "bm25": "#dcfce7:#166534",
79
+ "dense+bm25": "#f3e8ff:#6b21a8",
80
+ "bm25+dense": "#f3e8ff:#6b21a8",
81
+ "graph": "#fef9c3:#854d0e",
82
+ "web_search": "#fee2e2:#991b1b",
83
+ }
84
+ ret_style = retriever_colors.get(retriever, "#f3f4f6:#374151")
85
+ ret_bg, ret_fg = ret_style.split(":")
86
+
87
  st.markdown(f"""
88
  <div class="source-card">
89
  <strong>[{i+1}] {title}</strong>
90
  <span class="score-badge" style="float:right">{score_pct}%</span><br/>
91
+ <small style="color:#6b7280">{source}</small> &nbsp;
92
+ <span style="background:{ret_bg};color:{ret_fg};border-radius:4px;padding:1px 6px;font-size:0.72rem;font-weight:600">{retriever}</span>
93
  <div class="chunk-snippet">{snippet}…</div>
94
  </div>""", unsafe_allow_html=True)
95
 
 
120
  )
121
  st.divider()
122
 
123
+ tab_ask, tab_ingest, tab_eval, tab_system = st.tabs(["πŸ” Ask", "πŸ“₯ Ingest", "πŸ“Š Evaluation", "🩺 System"])
124
 
125
 
126
  # ─────────────────────────────────────────────────────────────
 
192
  sources_placeholder.markdown(payload.get("text", ""))
193
  status_placeholder.empty()
194
 
195
+ elif event_type == "crag_update":
196
+ grade = payload.get("grade", "")
197
+ rewritten = payload.get("rewritten_query")
198
+ web_used = payload.get("web_search_used", False)
199
+ reasoning = payload.get("reasoning", "")
200
+ icon = {"POOR": "πŸ”„", "ABSENT": "🌐"}.get(grade, "ℹ️")
201
+ msg = f"{icon} **CRAG {grade}**: {reasoning[:100]}"
202
+ if rewritten:
203
+ msg += " \n\u21a9 Rewritten: *" + rewritten + "*"
204
+ if web_used:
205
+ msg += " \n\U0001f310 Web search fallback used"
206
+ status_placeholder.info(msg)
207
+
208
  elif event_type == "done":
209
  answer_placeholder.markdown(answer_text)
210
  status_placeholder.empty()
 
294
 
295
 
296
  # ─────────────────────────────────────────────────────────────
297
+ # TAB 3 — EVALUATION DASHBOARD
298
+ # ─────────────────────────────────────────────────────────────
299
def _fmt_score(value) -> str:
    """Format a RAGAS score to two decimals; "—" when it is not available yet."""
    # Test against None, not truthiness: 0.0 is a real score and must be shown.
    return "—" if value is None else f"{value:.2f}"


with tab_eval:
    st.subheader("RAG evaluation dashboard")
    st.caption("Metrics update automatically after each query. RAGAS scores compute in the background (~5s after response).")

    # Dropping the cached payload forces the fetch below to run again.
    if st.button("🔄 Refresh metrics"):
        st.session_state.pop("metrics_data", None)

    if "metrics_data" not in st.session_state:
        try:
            resp = requests.get(f"{API_BASE}/metrics?limit=200&days=14", timeout=5)
            resp.raise_for_status()
            st.session_state.metrics_data = resp.json()
        except Exception as exc:
            # Store the failure so it is rendered below instead of crashing the tab.
            st.session_state.metrics_data = {"error": str(exc)}

    mdata = st.session_state.get("metrics_data", {})

    if "error" in mdata:
        st.error(f"Cannot reach API: {mdata['error']}")
    else:
        import pandas as pd  # deferred: only needed once there is data to chart

        summary = mdata.get("summary", {})
        cache = mdata.get("cache", {})

        # ── Header KPI row ─────────────────────────────────────
        # "or 0" guards against null aggregates (no rows yet), which would
        # otherwise break the numeric format specs.
        k1, k2, k3, k4, k5, k6 = st.columns(6)
        k1.metric("Total queries", summary.get("total_queries", 0))
        k2.metric("Faithfulness", f"{(summary.get('avg_faithfulness') or 0):.2f}")
        k3.metric("Answer relevancy", f"{(summary.get('avg_answer_relevancy') or 0):.2f}")
        k4.metric("Context precision", f"{(summary.get('avg_context_precision') or 0):.2f}")
        k5.metric("Avg latency", f"{(summary.get('avg_latency_ms') or 0):.0f} ms")
        k6.metric("Cache hit rate", f"{(cache.get('hit_rate') or 0):.0%}" if cache.get('enabled') else "off")

        st.divider()

        # ── Metric timeseries ──────────────────────────────────
        ts = mdata.get("timeseries", [])
        if ts:
            df_ts = pd.DataFrame(ts)
            df_ts["hour"] = df_ts["hour_bucket"]
            st.markdown("#### RAGAS metrics over time")
            st.line_chart(
                df_ts.set_index("hour")[["faithfulness", "answer_relevancy", "context_precision"]],
                height=220,
            )
        else:
            st.info("No evaluation data yet. Run some queries to populate the dashboard.")

        st.divider()

        col_left, col_right = st.columns(2, gap="large")

        with col_left:
            # ── CRAG grade distribution ────────────────────────
            grade_dist = summary.get("crag_grade_dist", {})
            if grade_dist:
                st.markdown("#### CRAG grade distribution")
                df_grades = pd.DataFrame(
                    list(grade_dist.items()), columns=["Grade", "Count"]
                )
                st.bar_chart(df_grades.set_index("Grade"), height=180)

            # ── Strategy distribution ──────────────────────────
            strat_dist = summary.get("strategy_dist", {})
            if strat_dist:
                st.markdown("#### Retrieval strategy mix")
                rows = []
                for strat_json, cnt in strat_dist.items():
                    # Keys are expected to be JSON-encoded lists of strategy
                    # names; anything else is shown verbatim.
                    try:
                        label = "+".join(json.loads(strat_json)).upper()
                    except (TypeError, ValueError):
                        label = strat_json
                    rows.append({"Strategy": label, "Count": cnt})
                df_strat = pd.DataFrame(rows)
                st.bar_chart(df_strat.set_index("Strategy"), height=180)

        with col_right:
            # ── Cache stats ────────────────────────────────────
            st.markdown("#### Cache")
            if cache.get("enabled"):
                c1, c2 = st.columns(2)
                c1.metric("Hits", cache.get("hits", 0))
                c2.metric("Misses", cache.get("misses", 0))
                st.caption(f"TTL: {cache.get('ttl_s', 0)//60} min")
                if st.button("🗑️ Flush cache"):
                    try:
                        # Flush goes through the HTTP API; REDIS_URL is a Redis
                        # connection URL and cannot be the target of an HTTP POST.
                        r = requests.post(f"{API_BASE}/cache/flush", timeout=5)
                        r.raise_for_status()
                        st.success(f"Flushed {r.json().get('deleted', 0)} entries.")
                        st.session_state.pop("metrics_data", None)
                    except Exception as e:
                        st.error(str(e))
            else:
                st.caption("Redis not connected. Start Redis to enable caching.")
                st.code("docker run -d -p 6379:6379 redis:7-alpine", language="bash")

        st.divider()

        # ── Recent query log table ─────────────────────────────
        recent = mdata.get("recent", [])
        if recent:
            st.markdown("#### Recent queries")
            rows = [
                {
                    "Query": r.get("query", "")[:60],
                    "Intent": r.get("intent", ""),
                    "CRAG": r.get("crag_grade", ""),
                    "Faithful": _fmt_score(r.get("faithfulness")),
                    "Relevancy": _fmt_score(r.get("answer_relevancy")),
                    "Precision": _fmt_score(r.get("context_precision")),
                    "Latency ms": f"{(r.get('latency_ms') or 0):.0f}",
                }
                for r in recent[:50]
            ]
            st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
416
+
417
+
418
+ # ─────────────────────────────────────────────────────────────
419
+ # TAB 4 — SYSTEM HEALTH
420
  # ─────────────────────────────────────────────────────────────
421
  with tab_system:
422
  st.subheader("System health")
 
453
  st.metric("Chunks indexed", stats.get("entity_count", "β€”"))
454
 
455
  st.divider()
456
+ graph_stats = health.get("graph_stats", {})
457
+ if graph_stats:
458
+ col_d, col_e = st.columns(2)
459
+ with col_d:
460
+ st.metric("Graph nodes", graph_stats.get("nodes", "β€”"))
461
+ with col_e:
462
+ st.metric("Graph edges", graph_stats.get("edges", "β€”"))
463
+ st.divider()
464
  st.markdown("**Raw health response**")
465
  st.json(health)