Spaces:
Running
Running
Commit ·
f0d100b
1
Parent(s): f6803e9
Add phase 3 & 4
Browse files
- .env.example +35 -6
- api/main.py +95 -3
- api/schemas.py +6 -0
- config.py +37 -1
- evaluation/__init__.py +0 -0
- evaluation/ragas_eval.py +228 -0
- evaluation/store.py +276 -0
- generation/crag.py +402 -0
- ingestion/pipeline.py +14 -1
- retrieval/cache.py +241 -0
- retrieval/graph_builder.py +278 -0
- retrieval/graph_retriever.py +271 -0
- retrieval/orchestrator.py +19 -1
- retrieval/relation_extractors.py +602 -0
- ui/app.py +166 -4
.env.example
CHANGED
|
@@ -1,13 +1,16 @@
|
|
| 1 |
# ── Copy this file to .env and fill in your values ─────────────
|
| 2 |
# cp .env.example .env
|
| 3 |
|
| 4 |
-
# ──
|
| 5 |
-
GROQ_API_KEY=
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# ── Milvus (defaults work with docker-compose) ─────────────────
|
| 8 |
-
MILVUS_HOST=
|
| 9 |
-
MILVUS_PORT=
|
| 10 |
-
MILVUS_COLLECTION=
|
| 11 |
|
| 12 |
# ── Embedding model ────────────────────────────────────────────
|
| 13 |
EMBED_MODEL_NAME=BAAI/bge-small-en-v1.5
|
|
@@ -23,9 +26,35 @@ RETRIEVAL_TOP_K=15
|
|
| 23 |
FINAL_TOP_K=5
|
| 24 |
|
| 25 |
# ── LLM ────────────────────────────────────────────────────────
|
| 26 |
-
GROQ_MODEL=
|
| 27 |
GROQ_TEMPERATURE=0.1
|
| 28 |
GROQ_MAX_TOKENS=1024
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# ── Logging ────────────────────────────────────────────────────
|
| 31 |
LOG_LEVEL=INFO
|
|
|
|
| 1 |
# ── Copy this file to .env and fill in your values ─────────────
|
| 2 |
# cp .env.example .env
|
| 3 |
|
| 4 |
+
# ── API KEYS ───────────────────────────────────────────────────
|
| 5 |
+
GROQ_API_KEY=
|
| 6 |
+
NVIDIA_API_KEY=
|
| 7 |
+
MISTRAL_API_KEY=
|
| 8 |
+
TAVILY_API_KEY=
|
| 9 |
|
| 10 |
# ── Milvus (defaults work with docker-compose) ─────────────────
|
| 11 |
+
MILVUS_HOST=
|
| 12 |
+
MILVUS_PORT=
|
| 13 |
+
MILVUS_COLLECTION=
|
| 14 |
|
| 15 |
# ── Embedding model ────────────────────────────────────────────
|
| 16 |
EMBED_MODEL_NAME=BAAI/bge-small-en-v1.5
|
|
|
|
| 26 |
FINAL_TOP_K=5
|
| 27 |
|
| 28 |
# ── LLM ────────────────────────────────────────────────────────
|
| 29 |
+
GROQ_MODEL=openai/gpt-oss-120b
|
| 30 |
GROQ_TEMPERATURE=0.1
|
| 31 |
GROQ_MAX_TOKENS=1024
|
| 32 |
|
| 33 |
+
# ── Knowledge graph ────────────────────────────────────────────
|
| 34 |
+
# rebel → local REBEL model, no API calls (~1.6GB download on first run)
|
| 35 |
+
# llm → Groq LLM, free-form predicates (rate-limited)
|
| 36 |
+
# rebel-filtered → REBEL + entity density pre-filter: skips ~70% of chunks
|
| 37 |
+
# llm-filtered → LLM + entity density pre-filter: drastically fewer API calls
|
| 38 |
+
GRAPH_EXTRACTOR=llm-filtered
|
| 39 |
+
REBEL_BATCH_SIZE=4 # lower (e.g. to 2) if you hit OOM on CPU
|
| 40 |
+
|
| 41 |
+
# Density filter settings (only used when GRAPH_EXTRACTOR ends with -filtered)
|
| 42 |
+
DENSITY_TOP_FRACTION=0.30 # process top 30% most entity-rich chunks
|
| 43 |
+
DENSITY_MIN_ENTITIES=2 # hard floor: always skip chunks with fewer than N entities
|
| 44 |
+
|
| 45 |
+
# ── RE LLM (LLM accessible via Mistral or Ollama) ────────────────────────────────────────────────────────
|
| 46 |
+
LLM_SERVER=mistral # options: mistral, ollama
|
| 47 |
+
MISTRAL_MODEL=devstral-latest
|
| 48 |
+
OLLAMA_MODEL=llama3.2:3b
|
| 49 |
+
OLLAMA_HOST=
|
| 50 |
+
|
| 51 |
+
# ── Redis cache ────────────────────────────────────────────────
|
| 52 |
+
REDIS_URL=
|
| 53 |
+
CACHE_TTL_SECONDS=3600
|
| 54 |
+
|
| 55 |
+
# ── Evaluation ─────────────────────────────────────────────────
|
| 56 |
+
EVAL_DB_PATH=
|
| 57 |
+
EVAL_ENABLED=true # set false to skip RAGAS LLM calls
|
| 58 |
+
|
| 59 |
# ── Logging ────────────────────────────────────────────────────
|
| 60 |
LOG_LEVEL=INFO
|
api/main.py
CHANGED
|
@@ -39,6 +39,10 @@ from api.schemas import (
|
|
| 39 |
)
|
| 40 |
from config import get_settings
|
| 41 |
from generation.generator import Generator, GenerationRequest
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
from ingestion.pipeline import IngestionPipeline
|
| 43 |
from retrieval.dense import MilvusStore
|
| 44 |
from retrieval.embedder import Embedder
|
|
@@ -46,6 +50,7 @@ from retrieval.bm25 import BM25Retriever
|
|
| 46 |
from retrieval.orchestrator import MultiStrategyRetriever
|
| 47 |
|
| 48 |
logger = logging.getLogger(__name__)
|
|
|
|
| 49 |
|
| 50 |
# ── Shared singletons ──────────────────────────────────────────
|
| 51 |
# Created once on startup, shared across requests
|
|
@@ -54,6 +59,9 @@ _embedder: Embedder = None
|
|
| 54 |
_store: MilvusStore = None
|
| 55 |
_bm25: BM25Retriever = None
|
| 56 |
_retriever: MultiStrategyRetriever = None
|
|
|
|
|
|
|
|
|
|
| 57 |
_generator: Generator = None
|
| 58 |
_pipeline: IngestionPipeline = None
|
| 59 |
|
|
@@ -61,14 +69,19 @@ _pipeline: IngestionPipeline = None
|
|
| 61 |
@asynccontextmanager
|
| 62 |
async def lifespan(app: FastAPI):
|
| 63 |
"""Initialise shared resources on startup, clean up on shutdown."""
|
| 64 |
-
global _embedder, _store, _bm25, _retriever, _generator, _pipeline
|
| 65 |
logger.info("Cortex starting up...")
|
| 66 |
|
| 67 |
_embedder = Embedder()
|
| 68 |
_store = MilvusStore(embedder=_embedder)
|
| 69 |
_bm25 = BM25Retriever()
|
| 70 |
_retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
_pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
|
| 73 |
|
| 74 |
# Warm up: trigger model load immediately so first request is fast
|
|
@@ -126,11 +139,18 @@ async def health() -> HealthResponse:
|
|
| 126 |
|
| 127 |
embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
return HealthResponse(
|
| 130 |
status="ok" if milvus_status == "ok" else "degraded",
|
| 131 |
milvus=milvus_status,
|
| 132 |
embedder=embedder_status,
|
| 133 |
collection_stats=collection_stats,
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
|
|
@@ -159,6 +179,25 @@ async def ingest(req: IngestRequest) -> IngestResponse:
|
|
| 159 |
|
| 160 |
return IngestResponse(**stats)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
@app.post("/query", response_model=QueryResponse, tags=["retrieval"])
|
| 164 |
async def query(req: QueryRequest) -> QueryResponse:
|
|
@@ -169,6 +208,9 @@ async def query(req: QueryRequest) -> QueryResponse:
|
|
| 169 |
cfg = get_settings()
|
| 170 |
k = req.top_k or cfg.retrieval_top_k
|
| 171 |
|
|
|
|
|
|
|
|
|
|
| 172 |
try:
|
| 173 |
retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
|
| 174 |
except Exception as exc:
|
|
@@ -187,6 +229,14 @@ async def query(req: QueryRequest) -> QueryResponse:
|
|
| 187 |
|
| 188 |
final_chunks = retrieval.chunks
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
try:
|
| 191 |
result = _generator.generate(
|
| 192 |
GenerationRequest(query=req.query, chunks=final_chunks)
|
|
@@ -195,6 +245,30 @@ async def query(req: QueryRequest) -> QueryResponse:
|
|
| 195 |
logger.exception("Generation error")
|
| 196 |
raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
return QueryResponse(
|
| 199 |
query=req.query,
|
| 200 |
answer=result.answer,
|
|
@@ -277,7 +351,25 @@ async def query_stream(req: QueryRequest):
|
|
| 277 |
yield _sse_event({"type": "done"})
|
| 278 |
return
|
| 279 |
|
| 280 |
-
# 3.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
gen_request = GenerationRequest(
|
| 282 |
query=req.query,
|
| 283 |
chunks=final_chunks,
|
|
|
|
| 39 |
)
|
| 40 |
from config import get_settings
|
| 41 |
from generation.generator import Generator, GenerationRequest
|
| 42 |
+
from generation.crag import CRAGGate
|
| 43 |
+
from evaluation.store import EvalStore, QueryLogEntry
|
| 44 |
+
from evaluation.ragas_eval import RAGASEvaluator, EvalInput
|
| 45 |
+
from retrieval.cache import CachedRetriever
|
| 46 |
from ingestion.pipeline import IngestionPipeline
|
| 47 |
from retrieval.dense import MilvusStore
|
| 48 |
from retrieval.embedder import Embedder
|
|
|
|
| 50 |
from retrieval.orchestrator import MultiStrategyRetriever
|
| 51 |
|
| 52 |
logger = logging.getLogger(__name__)
|
| 53 |
+
cfg = get_settings()
|
| 54 |
|
| 55 |
# ── Shared singletons ──────────────────────────────────────────
|
| 56 |
# Created once on startup, shared across requests
|
|
|
|
| 59 |
_store: MilvusStore = None
|
| 60 |
_bm25: BM25Retriever = None
|
| 61 |
_retriever: MultiStrategyRetriever = None
|
| 62 |
+
_crag: CRAGGate = None
|
| 63 |
+
_eval_store: EvalStore = None
|
| 64 |
+
_evaluator: RAGASEvaluator = None
|
| 65 |
_generator: Generator = None
|
| 66 |
_pipeline: IngestionPipeline = None
|
| 67 |
|
|
|
|
| 69 |
@asynccontextmanager
|
| 70 |
async def lifespan(app: FastAPI):
|
| 71 |
"""Initialise shared resources on startup, clean up on shutdown."""
|
| 72 |
+
global _embedder, _store, _bm25, _retriever, _crag, _generator, _pipeline, _eval_store, _evaluator
|
| 73 |
logger.info("Cortex starting up...")
|
| 74 |
|
| 75 |
_embedder = Embedder()
|
| 76 |
_store = MilvusStore(embedder=_embedder)
|
| 77 |
_bm25 = BM25Retriever()
|
| 78 |
_retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
|
| 79 |
+
_crag = CRAGGate()
|
| 80 |
+
_eval_store = EvalStore(db_path=cfg.eval_db_path)
|
| 81 |
+
_evaluator = RAGASEvaluator(store=_eval_store)
|
| 82 |
+
_generator = Generator()
|
| 83 |
+
# Wrap retriever with Redis cache (degrades gracefully if Redis is absent)
|
| 84 |
+
_retriever = CachedRetriever(_retriever)
|
| 85 |
_pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
|
| 86 |
|
| 87 |
# Warm up: trigger model load immediately so first request is fast
|
|
|
|
| 139 |
|
| 140 |
embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
|
| 141 |
|
| 142 |
+
graph_stats = {}
|
| 143 |
+
try:
|
| 144 |
+
graph_stats = _retriever.graph_builder.stats()
|
| 145 |
+
except Exception:
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
return HealthResponse(
|
| 149 |
status="ok" if milvus_status == "ok" else "degraded",
|
| 150 |
milvus=milvus_status,
|
| 151 |
embedder=embedder_status,
|
| 152 |
collection_stats=collection_stats,
|
| 153 |
+
graph_stats=graph_stats,
|
| 154 |
)
|
| 155 |
|
| 156 |
|
|
|
|
| 179 |
|
| 180 |
return IngestResponse(**stats)
|
| 181 |
|
| 182 |
+
@app.get("/metrics", tags=["evaluation"])
|
| 183 |
+
async def get_metrics(limit: int = 100, days: int = 7):
|
| 184 |
+
"""
|
| 185 |
+
Query performance metrics and RAGAS scores for the dashboard.
|
| 186 |
+
Returns summary stats, recent query logs, and hourly timeseries.
|
| 187 |
+
"""
|
| 188 |
+
return {
|
| 189 |
+
"summary": _eval_store.get_summary_stats(),
|
| 190 |
+
"recent": _eval_store.get_recent_queries(limit=limit),
|
| 191 |
+
"timeseries": _eval_store.get_metric_timeseries(days=days),
|
| 192 |
+
"cache": _retriever.cache_stats(),
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@app.post("/cache/flush", tags=["system"])
|
| 197 |
+
async def flush_cache():
|
| 198 |
+
"""Flush all Redis retrieval cache entries."""
|
| 199 |
+
deleted = _retriever.flush_all()
|
| 200 |
+
return {"deleted": deleted}
|
| 201 |
|
| 202 |
@app.post("/query", response_model=QueryResponse, tags=["retrieval"])
|
| 203 |
async def query(req: QueryRequest) -> QueryResponse:
|
|
|
|
| 208 |
cfg = get_settings()
|
| 209 |
k = req.top_k or cfg.retrieval_top_k
|
| 210 |
|
| 211 |
+
import time as _time
|
| 212 |
+
_t0 = _time.perf_counter()
|
| 213 |
+
|
| 214 |
try:
|
| 215 |
retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
|
| 216 |
except Exception as exc:
|
|
|
|
| 229 |
|
| 230 |
final_chunks = retrieval.chunks
|
| 231 |
|
| 232 |
+
# CRAG gate: grade, rewrite if POOR, web-search fallback if ABSENT
|
| 233 |
+
crag_result = _crag.evaluate(
|
| 234 |
+
query=req.query,
|
| 235 |
+
chunks=final_chunks,
|
| 236 |
+
retriever_fn=lambda q: _retriever.retrieve(q).chunks,
|
| 237 |
+
)
|
| 238 |
+
final_chunks = crag_result.final_chunks
|
| 239 |
+
|
| 240 |
try:
|
| 241 |
result = _generator.generate(
|
| 242 |
GenerationRequest(query=req.query, chunks=final_chunks)
|
|
|
|
| 245 |
logger.exception("Generation error")
|
| 246 |
raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
|
| 247 |
|
| 248 |
+
latency_ms = (_time.perf_counter() - _t0) * 1000
|
| 249 |
+
|
| 250 |
+
log_id = _eval_store.log_query(QueryLogEntry(
|
| 251 |
+
query=req.query,
|
| 252 |
+
intent=retrieval.decision.intent.value,
|
| 253 |
+
strategies=retrieval.decision.strategies,
|
| 254 |
+
retriever_hits=retrieval.retriever_hits,
|
| 255 |
+
crag_grade=crag_result.grade.value,
|
| 256 |
+
crag_rewritten=bool(crag_result.rewritten_query),
|
| 257 |
+
web_search_used=crag_result.web_search_used,
|
| 258 |
+
num_chunks=len(final_chunks),
|
| 259 |
+
top_chunk_score=final_chunks[0].score if final_chunks else 0.0,
|
| 260 |
+
latency_ms=latency_ms,
|
| 261 |
+
model=result.model,
|
| 262 |
+
))
|
| 263 |
+
|
| 264 |
+
if cfg.eval_enabled:
|
| 265 |
+
_evaluator.evaluate_async(EvalInput(
|
| 266 |
+
query_log_id=log_id,
|
| 267 |
+
query=req.query,
|
| 268 |
+
answer=result.answer,
|
| 269 |
+
chunks=final_chunks,
|
| 270 |
+
))
|
| 271 |
+
|
| 272 |
return QueryResponse(
|
| 273 |
query=req.query,
|
| 274 |
answer=result.answer,
|
|
|
|
| 351 |
yield _sse_event({"type": "done"})
|
| 352 |
return
|
| 353 |
|
| 354 |
+
# 3. CRAG gate — grade, optionally rewrite + re-retrieve
|
| 355 |
+
crag_result = _crag.evaluate(
|
| 356 |
+
query=req.query,
|
| 357 |
+
chunks=final_chunks,
|
| 358 |
+
retriever_fn=lambda q: _retriever.retrieve(q).chunks,
|
| 359 |
+
)
|
| 360 |
+
final_chunks = crag_result.final_chunks
|
| 361 |
+
|
| 362 |
+
# Emit CRAG event if something interesting happened
|
| 363 |
+
if crag_result.grade.value != "GOOD" or crag_result.web_search_used:
|
| 364 |
+
yield _sse_event({
|
| 365 |
+
"type": "crag_update",
|
| 366 |
+
"grade": crag_result.grade.value,
|
| 367 |
+
"rewritten_query": crag_result.rewritten_query,
|
| 368 |
+
"web_search_used": crag_result.web_search_used,
|
| 369 |
+
"reasoning": crag_result.reasoning,
|
| 370 |
+
})
|
| 371 |
+
|
| 372 |
+
# 4. Stream answer tokens
|
| 373 |
gen_request = GenerationRequest(
|
| 374 |
query=req.query,
|
| 375 |
chunks=final_chunks,
|
api/schemas.py
CHANGED
|
@@ -45,6 +45,9 @@ class QueryResponse(BaseModel):
|
|
| 45 |
citations: list[CitationResponse]
|
| 46 |
retrieved_chunks: list[ChunkResponse]
|
| 47 |
routing: Optional[RoutingResponse] = None
|
|
|
|
|
|
|
|
|
|
| 48 |
model: str
|
| 49 |
usage: dict
|
| 50 |
|
|
@@ -60,6 +63,8 @@ class IngestResponse(BaseModel):
|
|
| 60 |
chunks_created: int
|
| 61 |
chunks_stored: int
|
| 62 |
bm25_indexed: int = 0
|
|
|
|
|
|
|
| 63 |
errors: list[dict] = []
|
| 64 |
|
| 65 |
|
|
@@ -68,3 +73,4 @@ class HealthResponse(BaseModel):
|
|
| 68 |
milvus: str
|
| 69 |
embedder: str
|
| 70 |
collection_stats: dict
|
|
|
|
|
|
| 45 |
citations: list[CitationResponse]
|
| 46 |
retrieved_chunks: list[ChunkResponse]
|
| 47 |
routing: Optional[RoutingResponse] = None
|
| 48 |
+
crag_grade: Optional[str] = None
|
| 49 |
+
crag_rewritten_query: Optional[str] = None
|
| 50 |
+
web_search_used: bool = False
|
| 51 |
model: str
|
| 52 |
usage: dict
|
| 53 |
|
|
|
|
| 63 |
chunks_created: int
|
| 64 |
chunks_stored: int
|
| 65 |
bm25_indexed: int = 0
|
| 66 |
+
graph_entities: int = 0
|
| 67 |
+
graph_triples: int = 0
|
| 68 |
errors: list[dict] = []
|
| 69 |
|
| 70 |
|
|
|
|
| 73 |
milvus: str
|
| 74 |
embedder: str
|
| 75 |
collection_stats: dict
|
| 76 |
+
graph_stats: dict = {}
|
config.py
CHANGED
|
@@ -40,12 +40,15 @@ class Settings(BaseSettings):
|
|
| 40 |
retrieval_top_k: int = 15 # candidates before reranking
|
| 41 |
final_top_k: int = 5 # chunks sent to LLM
|
| 42 |
|
| 43 |
-
# ── LLM /
|
| 44 |
groq_api_key: str = os.getenv("GROQ_API_KEY", "") # must be set in .env for LLM classification to work
|
| 45 |
groq_model: str = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
|
| 46 |
groq_temperature: float = float(os.getenv("GROQ_TEMPERATURE", 0.1))
|
| 47 |
groq_max_tokens: int = int(os.getenv("GROQ_MAX_TOKENS", 1024))
|
| 48 |
groq_timeout: int = int(os.getenv("GROQ_TIMEOUT", 30)) # seconds before Groq client timeout
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
# ── FastAPI ──────────────────────────────────────────────
|
| 51 |
api_host: str = "0.0.0.0"
|
|
@@ -56,6 +59,39 @@ class Settings(BaseSettings):
|
|
| 56 |
data_dir: str = "data/documents"
|
| 57 |
log_level: str = "INFO"
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
@lru_cache(maxsize=1)
|
| 61 |
def get_settings() -> Settings:
|
|
|
|
| 40 |
retrieval_top_k: int = 15 # candidates before reranking
|
| 41 |
final_top_k: int = 5 # chunks sent to LLM
|
| 42 |
|
| 43 |
+
# ── LLM / TAVILY ───────────────────────────────────────────
|
| 44 |
groq_api_key: str = os.getenv("GROQ_API_KEY", "") # must be set in .env for LLM classification to work
|
| 45 |
groq_model: str = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
|
| 46 |
groq_temperature: float = float(os.getenv("GROQ_TEMPERATURE", 0.1))
|
| 47 |
groq_max_tokens: int = int(os.getenv("GROQ_MAX_TOKENS", 1024))
|
| 48 |
groq_timeout: int = int(os.getenv("GROQ_TIMEOUT", 30)) # seconds before Groq client timeout
|
| 49 |
+
tavily_api_key: str = os.getenv("TAVILY_API_KEY", "")
|
| 50 |
+
mistral_api_key: str = os.getenv("MISTRAL_API_KEY", "")
|
| 51 |
+
mistral_model: str = os.getenv("MISTRAL_MODEL", "devstral-latest")
|
| 52 |
|
| 53 |
# ── FastAPI ──────────────────────────────────────────────
|
| 54 |
api_host: str = "0.0.0.0"
|
|
|
|
| 59 |
data_dir: str = "data/documents"
|
| 60 |
log_level: str = "INFO"
|
| 61 |
|
| 62 |
+
# ── CRAG ─────────────────────────────────────────────────
|
| 63 |
+
crag_enabled: bool = True
|
| 64 |
+
crag_relevance_threshold: float = 0.5 # below this → POOR grade
|
| 65 |
+
|
| 66 |
+
# ── Graph ─────────────────────────────────────────────────
|
| 67 |
+
graph_enabled: bool = True
|
| 68 |
+
graph_path: str = "data/knowledge_graph.json"
|
| 69 |
+
graph_max_hops: int = 2
|
| 70 |
+
# "rebel" → local REBEL model, no API calls
|
| 71 |
+
# "llm" → Groq LLM, free-form predicates
|
| 72 |
+
# "rebel-filtered" → REBEL + entity density pre-filter (option 4)
|
| 73 |
+
# "llm-filtered" → LLM + entity density pre-filter (option 4; default)
|
| 74 |
+
graph_extractor: str = "llm-filtered"
|
| 75 |
+
rebel_batch_size: int = 4 # chunks per REBEL forward pass; lower if OOM
|
| 76 |
+
|
| 77 |
+
# ── Density filter (used when graph_extractor ends with "-filtered") ──
|
| 78 |
+
density_top_fraction: float = 0.30 # process top 30% most entity-dense chunks
|
| 79 |
+
density_min_entities: int = 2 # hard floor: skip chunks with fewer entities
|
| 80 |
+
|
| 81 |
+
# ── Relation Ext LLM (LLM accessible via Mistral or Ollama) ────────────────────────────────────────────────────────
|
| 82 |
+
llm_server: str = os.getenv("LLM_SERVER", "mistral") # "mistral" or "ollama"
|
| 83 |
+
ollama_model: str = os.getenv("OLLAMA_MODEL", "llama3.2:3b")
|
| 84 |
+
ollama_host: str = os.getenv("OLLAMA_HOST", "") # Ollama server URL
|
| 85 |
+
mistral_model: str = os.getenv("MISTRAL_MODEL", "devstral-latest")
|
| 86 |
+
|
| 87 |
+
# ── Redis cache ───────────────────────────────────────────
|
| 88 |
+
redis_url: str = "redis://localhost:6379"
|
| 89 |
+
cache_ttl_seconds: int = 3600 # 1 hour
|
| 90 |
+
|
| 91 |
+
# ── Evaluation ────────────────────────────────────────────
|
| 92 |
+
eval_db_path: str = "data/cortex_eval.db"
|
| 93 |
+
eval_enabled: bool = True # set False to skip RAGAS calls entirely
|
| 94 |
+
|
| 95 |
|
| 96 |
@lru_cache(maxsize=1)
|
| 97 |
def get_settings() -> Settings:
|
evaluation/__init__.py
ADDED
|
File without changes
|
evaluation/ragas_eval.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — RAGAS Evaluation Harness (Phase 4)
|
| 3 |
+
|
| 4 |
+
Why reference-free metrics?
|
| 5 |
+
───────────────────────────
|
| 6 |
+
Classic RAG evaluation requires ground-truth answers (golden QA pairs).
|
| 7 |
+
We don't have those at runtime. RAGAS provides three metrics that need
|
| 8 |
+
only (question, answer, retrieved_contexts):
|
| 9 |
+
|
| 10 |
+
faithfulness → Does the answer make claims supported by the context?
|
| 11 |
+
Computed by asking an LLM to identify each claim in
|
| 12 |
+
the answer, then checking each claim against the context.
|
| 13 |
+
Score = supported_claims / total_claims.
|
| 14 |
+
|
| 15 |
+
answer_relevancy → Does the answer actually address the question?
|
| 16 |
+
Computed by generating N hypothetical questions from the
|
| 17 |
+
answer and measuring cosine similarity to the original
|
| 18 |
+
question. Low score = answer talks about something else.
|
| 19 |
+
|
| 20 |
+
context_precision → Are the retrieved chunks actually relevant to the query?
|
| 21 |
+
Computed by asking an LLM whether each chunk is useful
|
| 22 |
+
for answering the query. Score = relevant_chunks / total.
|
| 23 |
+
|
| 24 |
+
We also compute two lightweight custom metrics without any LLM calls:
|
| 25 |
+
|
| 26 |
+
context_utilisation → What fraction of the retrieved chunks are cited in the
|
| 27 |
+
answer? (Count [1], [2]... citation markers.) A low score
|
| 28 |
+
means the generator ignored most of what was retrieved.
|
| 29 |
+
|
| 30 |
+
mean_chunk_score → Average retrieval score (post-reranking) of the final
|
| 31 |
+
chunks. Tracks retrieval quality independently of answer
|
| 32 |
+
quality. Useful for spotting when CRAG rewrites help.
|
| 33 |
+
|
| 34 |
+
Running mode
|
| 35 |
+
────────────
|
| 36 |
+
Evaluation is async — it runs in a background thread after the response
|
| 37 |
+
has been streamed to the user, so it never adds latency to the query path.
|
| 38 |
+
Results are written to the EvalStore (SQLite) and appear in the dashboard.
|
| 39 |
+
|
| 40 |
+
If RAGAS is not installed or the LLM call fails, only the two custom
|
| 41 |
+
metrics (context_utilisation, mean_chunk_score) are computed and stored.
|
| 42 |
+
This ensures the evaluation pipeline never blocks ingestion or queries.
|
| 43 |
+
"""
|
| 44 |
+
from __future__ import annotations
|
| 45 |
+
|
| 46 |
+
import logging
|
| 47 |
+
import re
|
| 48 |
+
import threading
|
| 49 |
+
from dataclasses import dataclass, field
|
| 50 |
+
from typing import Optional
|
| 51 |
+
|
| 52 |
+
from evaluation.store import EvalMetricEntry, EvalStore
|
| 53 |
+
from retrieval.dense import RetrievedChunk
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class EvalInput:
|
| 60 |
+
"""Everything needed to evaluate one query-response pair."""
|
| 61 |
+
query_log_id: int
|
| 62 |
+
query: str
|
| 63 |
+
answer: str
|
| 64 |
+
chunks: list[RetrievedChunk] = field(default_factory=list)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class EvalResult:
|
| 69 |
+
faithfulness: Optional[float] = None
|
| 70 |
+
answer_relevancy: Optional[float] = None
|
| 71 |
+
context_precision: Optional[float] = None
|
| 72 |
+
context_utilisation: Optional[float] = None
|
| 73 |
+
mean_chunk_score: Optional[float] = None
|
| 74 |
+
|
| 75 |
+
def as_store_entry(self, query_log_id: int) -> EvalMetricEntry:
|
| 76 |
+
return EvalMetricEntry(
|
| 77 |
+
query_log_id=query_log_id,
|
| 78 |
+
faithfulness=self.faithfulness,
|
| 79 |
+
answer_relevancy=self.answer_relevancy,
|
| 80 |
+
context_precision=self.context_precision,
|
| 81 |
+
context_utilisation=self.context_utilisation,
|
| 82 |
+
mean_chunk_score=self.mean_chunk_score,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class RAGASEvaluator:
|
| 87 |
+
"""
|
| 88 |
+
Computes RAGAS + custom metrics for a query-response pair.
|
| 89 |
+
|
| 90 |
+
Usage — fire-and-forget (non-blocking):
|
| 91 |
+
evaluator = RAGASEvaluator(store)
|
| 92 |
+
evaluator.evaluate_async(EvalInput(
|
| 93 |
+
query_log_id=log_id,
|
| 94 |
+
query="What is attention?",
|
| 95 |
+
answer="Attention is...",
|
| 96 |
+
chunks=final_chunks,
|
| 97 |
+
))
|
| 98 |
+
|
| 99 |
+
Usage — blocking (for testing):
|
| 100 |
+
result = evaluator.evaluate(eval_input)
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(self, store: Optional[EvalStore] = None) -> None:
|
| 104 |
+
self._store = store or EvalStore()
|
| 105 |
+
self._ragas_available = self._check_ragas()
|
| 106 |
+
|
| 107 |
+
# ── Public API ─────────────────────────────────────────────
|
| 108 |
+
|
| 109 |
+
def evaluate_async(self, inp: EvalInput) -> None:
|
| 110 |
+
"""
|
| 111 |
+
Run evaluation in a daemon thread. Returns immediately.
|
| 112 |
+
Results are written to EvalStore when complete.
|
| 113 |
+
"""
|
| 114 |
+
thread = threading.Thread(
|
| 115 |
+
target=self._run_and_store,
|
| 116 |
+
args=(inp,),
|
| 117 |
+
daemon=True,
|
| 118 |
+
name=f"ragas-eval-{inp.query_log_id}",
|
| 119 |
+
)
|
| 120 |
+
thread.start()
|
| 121 |
+
|
| 122 |
+
def evaluate(self, inp: EvalInput) -> EvalResult:
|
| 123 |
+
"""Blocking evaluation. Returns EvalResult."""
|
| 124 |
+
result = EvalResult()
|
| 125 |
+
|
| 126 |
+
# ── Custom metrics (no LLM, always computed) ──────────
|
| 127 |
+
result.context_utilisation = self._context_utilisation(inp.answer, inp.chunks)
|
| 128 |
+
result.mean_chunk_score = self._mean_chunk_score(inp.chunks)
|
| 129 |
+
|
| 130 |
+
# ── RAGAS metrics (LLM-based, may be skipped) ─────────
|
| 131 |
+
if self._ragas_available and inp.chunks:
|
| 132 |
+
ragas_scores = self._run_ragas(inp)
|
| 133 |
+
result.faithfulness = ragas_scores.get("faithfulness")
|
| 134 |
+
result.answer_relevancy = ragas_scores.get("answer_relevancy")
|
| 135 |
+
result.context_precision = ragas_scores.get("context_precision")
|
| 136 |
+
else:
|
| 137 |
+
if not self._ragas_available:
|
| 138 |
+
logger.debug("RAGAS not installed β only custom metrics computed.")
|
| 139 |
+
|
| 140 |
+
return result
|
| 141 |
+
|
| 142 |
+
# ── Private ────────────────────────────────────────────────
|
| 143 |
+
|
| 144 |
+
def _run_and_store(self, inp: EvalInput) -> None:
|
| 145 |
+
try:
|
| 146 |
+
result = self.evaluate(inp)
|
| 147 |
+
self._store.log_metrics(result.as_store_entry(inp.query_log_id))
|
| 148 |
+
logger.debug(
|
| 149 |
+
"Eval stored for query %d: faith=%.2f rel=%.2f prec=%.2f util=%.2f",
|
| 150 |
+
inp.query_log_id,
|
| 151 |
+
result.faithfulness or 0,
|
| 152 |
+
result.answer_relevancy or 0,
|
| 153 |
+
result.context_precision or 0,
|
| 154 |
+
result.context_utilisation or 0,
|
| 155 |
+
)
|
| 156 |
+
except Exception as exc:
|
| 157 |
+
logger.warning("Eval failed for query %d: %s", inp.query_log_id, exc)
|
| 158 |
+
|
| 159 |
+
def _run_ragas(self, inp: EvalInput) -> dict:
|
| 160 |
+
"""
|
| 161 |
+
Call RAGAS library. Returns dict of metric_name → score.
|
| 162 |
+
Returns empty dict on any failure.
|
| 163 |
+
"""
|
| 164 |
+
try:
|
| 165 |
+
from datasets import Dataset # type: ignore
|
| 166 |
+
from ragas import evaluate as ragas_evaluate # type: ignore
|
| 167 |
+
from ragas.metrics import ( # type: ignore
|
| 168 |
+
answer_relevancy,
|
| 169 |
+
context_precision,
|
| 170 |
+
faithfulness,
|
| 171 |
+
)
|
| 172 |
+
from config import get_settings
|
| 173 |
+
cfg = get_settings()
|
| 174 |
+
|
| 175 |
+
# RAGAS expects a HuggingFace Dataset
|
| 176 |
+
data = {
|
| 177 |
+
"question": [inp.query],
|
| 178 |
+
"answer": [inp.answer],
|
| 179 |
+
"contexts": [[c.parent_text or c.text for c in inp.chunks]],
|
| 180 |
+
# reference not available at runtime — omit context_recall
|
| 181 |
+
}
|
| 182 |
+
dataset = Dataset.from_dict(data)
|
| 183 |
+
|
| 184 |
+
scores = ragas_evaluate(
|
| 185 |
+
dataset,
|
| 186 |
+
metrics=[faithfulness, answer_relevancy, context_precision],
|
| 187 |
+
raise_exceptions=False,
|
| 188 |
+
)
|
| 189 |
+
df = scores.to_pandas()
|
| 190 |
+
return {
|
| 191 |
+
"faithfulness": float(df["faithfulness"].iloc[0]) if "faithfulness" in df else None,
|
| 192 |
+
"answer_relevancy": float(df["answer_relevancy"].iloc[0]) if "answer_relevancy" in df else None,
|
| 193 |
+
"context_precision": float(df["context_precision"].iloc[0]) if "context_precision" in df else None,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
except Exception as exc:
|
| 197 |
+
logger.warning("RAGAS evaluation failed: %s", exc)
|
| 198 |
+
return {}
|
| 199 |
+
|
| 200 |
+
# ── Custom metrics (no LLM required) ──────────────────────
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def _context_utilisation(answer: str, chunks: list[RetrievedChunk]) -> float:
|
| 204 |
+
"""
|
| 205 |
+
Fraction of retrieved chunks cited in the answer.
|
| 206 |
+
Looks for inline [N] citation markers.
|
| 207 |
+
"""
|
| 208 |
+
if not chunks:
|
| 209 |
+
return 0.0
|
| 210 |
+
cited_indices = set(int(n) for n in re.findall(r"\[(\d+)\]", answer))
|
| 211 |
+
cited = sum(1 for i in range(1, len(chunks) + 1) if i in cited_indices)
|
| 212 |
+
return round(cited / len(chunks), 3)
|
| 213 |
+
|
| 214 |
+
@staticmethod
|
| 215 |
+
def _mean_chunk_score(chunks: list[RetrievedChunk]) -> float:
|
| 216 |
+
"""Average retrieval score of the final chunks."""
|
| 217 |
+
if not chunks:
|
| 218 |
+
return 0.0
|
| 219 |
+
return round(sum(c.score for c in chunks) / len(chunks), 3)
|
| 220 |
+
|
| 221 |
+
@staticmethod
|
| 222 |
+
def _check_ragas() -> bool:
|
| 223 |
+
try:
|
| 224 |
+
import ragas # type: ignore # noqa: F401
|
| 225 |
+
import datasets # type: ignore # noqa: F401
|
| 226 |
+
return True
|
| 227 |
+
except ImportError:
|
| 228 |
+
return False
|
evaluation/store.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Evaluation Store (SQLite)
|
| 3 |
+
|
| 4 |
+
Two tables:
|
| 5 |
+
    query_logs   — one row per query: routing, CRAG grade, latency, chunk scores
|
| 6 |
+
    eval_metrics — one row per query: RAGAS scores (written async after generation)
|
| 7 |
+
|
| 8 |
+
SQLite is the right choice here: zero infrastructure, works on Railway/Render
|
| 9 |
+
out of the box, and a dashboard corpus of ~10k queries fits in <50MB.
|
| 10 |
+
Swap to Postgres trivially later by changing the connection string.
|
| 11 |
+
|
| 12 |
+
The store is intentionally append-only. No deletes, no updates.
|
| 13 |
+
This preserves the full history for trend analysis in the dashboard.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import sqlite3
|
| 20 |
+
import time
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
_DEFAULT_DB_PATH = Path("data/cortex_eval.db")
|
| 29 |
+
|
| 30 |
+
# ββ Schema βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
|
| 32 |
+
_DDL = """
|
| 33 |
+
CREATE TABLE IF NOT EXISTS query_logs (
|
| 34 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 35 |
+
timestamp REAL NOT NULL,
|
| 36 |
+
query TEXT NOT NULL,
|
| 37 |
+
intent TEXT,
|
| 38 |
+
strategies TEXT, -- JSON list
|
| 39 |
+
retriever_hits TEXT, -- JSON dict
|
| 40 |
+
crag_grade TEXT,
|
| 41 |
+
crag_rewritten INTEGER DEFAULT 0, -- bool
|
| 42 |
+
web_search_used INTEGER DEFAULT 0, -- bool
|
| 43 |
+
num_chunks INTEGER DEFAULT 0,
|
| 44 |
+
top_chunk_score REAL DEFAULT 0.0,
|
| 45 |
+
latency_ms REAL DEFAULT 0.0,
|
| 46 |
+
model TEXT,
|
| 47 |
+
extractor TEXT
|
| 48 |
+
);
|
| 49 |
+
|
| 50 |
+
CREATE TABLE IF NOT EXISTS eval_metrics (
|
| 51 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 52 |
+
query_log_id INTEGER NOT NULL REFERENCES query_logs(id),
|
| 53 |
+
timestamp REAL NOT NULL,
|
| 54 |
+
faithfulness REAL, -- 0-1: does answer contradict context?
|
| 55 |
+
answer_relevancy REAL, -- 0-1: does answer address the question?
|
| 56 |
+
context_precision REAL, -- 0-1: are retrieved chunks relevant?
|
| 57 |
+
context_utilisation REAL, -- 0-1: fraction of chunks cited in answer
|
| 58 |
+
mean_chunk_score REAL -- average retrieval score of final chunks
|
| 59 |
+
);
|
| 60 |
+
|
| 61 |
+
CREATE INDEX IF NOT EXISTS idx_query_logs_ts ON query_logs(timestamp);
|
| 62 |
+
CREATE INDEX IF NOT EXISTS idx_eval_metrics_id ON eval_metrics(query_log_id);
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ββ Dataclasses ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
|
| 68 |
+
@dataclass
class QueryLogEntry:
    """
    In-memory form of one ``query_logs`` row, before it is written.

    Only ``query`` is required. ``strategies`` and ``retriever_hits`` may be
    omitted or passed as None; ``__post_init__`` then replaces them with a
    fresh empty list / dict, so instances never share a mutable default.
    """

    query: str
    intent: str = ""
    # Serialised as a JSON list when the row is stored.
    strategies: Optional[list[str]] = None
    # Serialised as a JSON dict when the row is stored.
    retriever_hits: Optional[dict] = None
    crag_grade: str = ""
    # Both flags are stored as 0/1 integers.
    crag_rewritten: bool = False
    web_search_used: bool = False
    num_chunks: int = 0
    top_chunk_score: float = 0.0
    latency_ms: float = 0.0
    model: str = ""
    extractor: str = ""

    def __post_init__(self) -> None:
        """Normalise the None placeholders into per-instance containers."""
        if self.strategies is None:
            self.strategies = []
        if self.retriever_hits is None:
            self.retriever_hits = {}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
class EvalMetricEntry:
    """
    In-memory form of one ``eval_metrics`` row, before it is written.

    ``query_log_id`` is the id of the ``query_logs`` row the scores belong
    to. Every score is optional; a field left as None is stored as NULL.
    """

    query_log_id: int
    # The schema comments describe the next three as 0-1 scores.
    faithfulness: Optional[float] = None
    answer_relevancy: Optional[float] = None
    context_precision: Optional[float] = None
    # Fraction of chunks cited in the answer (per the schema comment).
    context_utilisation: Optional[float] = None
    # Average retrieval score of the final chunks (per the schema comment).
    mean_chunk_score: Optional[float] = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ββ Store ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 101 |
+
|
| 102 |
+
class EvalStore:
    """
    SQLite-backed, append-only store for query logs and eval metrics.

    Every operation opens its own short-lived connection (see ``_conn``),
    commits on success, rolls back on error and always closes it, so no
    connection object is shared between callers.

    Usage:
        store = EvalStore()
        log_id = store.log_query(entry)
        store.log_metrics(EvalMetricEntry(query_log_id=log_id, faithfulness=0.92, ...))
    """

    def __init__(self, db_path: str | Path = _DEFAULT_DB_PATH) -> None:
        """Create the parent directory and the schema if they are missing."""
        self._path = Path(db_path)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    # -- Write --------------------------------------------------

    def log_query(self, entry: QueryLogEntry) -> int:
        """Insert a query log row. Returns the new row id."""
        with self._conn() as conn:
            cur = conn.execute(
                """INSERT INTO query_logs
                   (timestamp, query, intent, strategies, retriever_hits,
                    crag_grade, crag_rewritten, web_search_used,
                    num_chunks, top_chunk_score, latency_ms, model, extractor)
                   VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                (
                    time.time(),
                    entry.query,
                    entry.intent,
                    # Containers are stored as JSON text.
                    json.dumps(entry.strategies),
                    json.dumps(entry.retriever_hits),
                    entry.crag_grade,
                    # Booleans are stored as 0/1 integers.
                    int(entry.crag_rewritten),
                    int(entry.web_search_used),
                    entry.num_chunks,
                    entry.top_chunk_score,
                    entry.latency_ms,
                    entry.model,
                    entry.extractor,
                ),
            )
            # Returning from inside the `with` still runs the commit in _conn.
            return cur.lastrowid

    def log_metrics(self, entry: EvalMetricEntry) -> None:
        """Insert an eval_metrics row."""
        with self._conn() as conn:
            conn.execute(
                """INSERT INTO eval_metrics
                   (query_log_id, timestamp, faithfulness, answer_relevancy,
                    context_precision, context_utilisation, mean_chunk_score)
                   VALUES (?,?,?,?,?,?,?)""",
                (
                    entry.query_log_id,
                    time.time(),
                    entry.faithfulness,
                    entry.answer_relevancy,
                    entry.context_precision,
                    entry.context_utilisation,
                    entry.mean_chunk_score,
                ),
            )

    # -- Read ---------------------------------------------------

    def get_recent_queries(self, limit: int = 100) -> list[dict]:
        """
        Last N query logs joined with their eval metrics (if available).

        Rows are ordered newest first. Metric columns are None for queries
        that have no eval_metrics row. ``strategies`` is decoded from JSON
        when possible.
        """
        with self._conn() as conn:
            rows = conn.execute(
                """SELECT q.id, q.timestamp, q.query, q.intent, q.strategies,
                          q.crag_grade, q.web_search_used, q.num_chunks,
                          q.top_chunk_score, q.latency_ms,
                          e.faithfulness, e.answer_relevancy,
                          e.context_precision, e.context_utilisation,
                          e.mean_chunk_score
                   FROM query_logs q
                   LEFT JOIN eval_metrics e ON e.query_log_id = q.id
                   ORDER BY q.timestamp DESC
                   LIMIT ?""",
                (limit,),
            ).fetchall()
        return [self._row_to_dict(r) for r in rows]

    def get_metric_timeseries(self, days: int = 7) -> list[dict]:
        """
        Hourly-bucketed metric averages over the last N days.
        Used for the trend line chart in the dashboard.

        ``hour_bucket`` is the number of whole hours between the start of
        the window (now minus ``days``) and the query timestamp. Only
        queries that have an eval_metrics row are included.
        """
        since = time.time() - days * 86400
        with self._conn() as conn:
            rows = conn.execute(
                """SELECT
                       CAST((q.timestamp - ?) / 3600 AS INTEGER) AS hour_bucket,
                       AVG(e.faithfulness) AS faithfulness,
                       AVG(e.answer_relevancy) AS answer_relevancy,
                       AVG(e.context_precision) AS context_precision,
                       AVG(e.mean_chunk_score) AS mean_chunk_score,
                       COUNT(*) AS query_count
                   FROM query_logs q
                   JOIN eval_metrics e ON e.query_log_id = q.id
                   WHERE q.timestamp > ?
                   GROUP BY hour_bucket
                   ORDER BY hour_bucket""",
                (since, since),
            ).fetchall()
        # Rows are sqlite3.Row objects, so the column aliases above become
        # the dict keys directly.
        return [dict(r) for r in rows]

    def get_summary_stats(self) -> dict:
        """
        Aggregate stats for the dashboard header metrics.

        Averages fall back to 0 when there is no data. ``strategy_dist`` is
        keyed by the stored JSON text of each strategies list.
        """
        with self._conn() as conn:
            total = conn.execute("SELECT COUNT(*) FROM query_logs").fetchone()[0]
            # Count queries, not metric rows, so a query that was evaluated
            # more than once is not counted twice.
            with_metrics = conn.execute(
                "SELECT COUNT(DISTINCT query_log_id) FROM eval_metrics"
            ).fetchone()[0]
            avgs = conn.execute(
                """SELECT AVG(faithfulness), AVG(answer_relevancy),
                          AVG(context_precision), AVG(mean_chunk_score)
                   FROM eval_metrics"""
            ).fetchone()
            grade_dist = conn.execute(
                """SELECT crag_grade, COUNT(*) as cnt
                   FROM query_logs WHERE crag_grade != ''
                   GROUP BY crag_grade"""
            ).fetchall()
            strategy_dist = conn.execute(
                """SELECT strategies, COUNT(*) as cnt
                   FROM query_logs GROUP BY strategies"""
            ).fetchall()
            avg_latency = conn.execute(
                "SELECT AVG(latency_ms) FROM query_logs WHERE latency_ms > 0"
            ).fetchone()[0]

        return {
            "total_queries": total,
            "evaluated_queries": with_metrics,
            "avg_faithfulness": round(avgs[0] or 0, 3),
            "avg_answer_relevancy": round(avgs[1] or 0, 3),
            "avg_context_precision": round(avgs[2] or 0, 3),
            "avg_chunk_score": round(avgs[3] or 0, 3),
            "avg_latency_ms": round(avg_latency or 0, 1),
            "crag_grade_dist": {r[0]: r[1] for r in grade_dist},
            "strategy_dist": {r[0]: r[1] for r in strategy_dist},
        }

    # -- Init ---------------------------------------------------

    def _init_db(self) -> None:
        """Run the schema script; safe to call on an existing database."""
        with self._conn() as conn:
            conn.executescript(_DDL)
        logger.info("EvalStore ready at %s", self._path)

    @contextmanager
    def _conn(self):
        """
        Yield a fresh connection with ``sqlite3.Row`` rows.

        Commits when the ``with`` body finishes normally, rolls back and
        re-raises on any exception, and closes the connection either way.
        """
        conn = sqlite3.connect(self._path, timeout=10, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _row_to_dict(row) -> dict:
        """Convert a row to a dict, decoding the JSON ``strategies`` column."""
        d = dict(row)
        for key in ("strategies",):
            if d.get(key):
                try:
                    d[key] = json.loads(d[key])
                except (TypeError, ValueError):
                    # Best effort: keep the raw stored value, but leave a
                    # trace instead of hiding the malformed row entirely.
                    logger.debug("Could not decode %s as JSON; keeping raw value", key)
        return d
|
generation/crag.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Corrective RAG (CRAG) Gate (Phase 3)
|
| 3 |
+
|
| 4 |
+
The problem CRAG solves
|
| 5 |
+
ββββββββββββββββββββββββ
|
| 6 |
+
Standard RAG always passes retrieved chunks to the LLM, even when:
|
| 7 |
+
- The query is ambiguous and the retrieved chunks are off-topic
|
| 8 |
+
- The knowledge base simply doesn't contain the answer
|
| 9 |
+
- The retrieved chunks contradict each other
|
| 10 |
+
|
| 11 |
+
In all three cases, the LLM will either hallucinate or produce a
|
| 12 |
+
confused answer. CRAG adds a grading step BEFORE generation:
|
| 13 |
+
|
| 14 |
+
                 ┌─── GOOD ────► Generator (proceed normally)
|
| 15 |
+
Query → Retrieve ┤
|
| 16 |
+
                 ├─── POOR ────► Rewrite query → Re-retrieve → Generator
|
| 17 |
+
                 └─── ABSENT ──► Web search fallback → Generator
|
| 18 |
+
|
| 19 |
+
Grading
|
| 20 |
+
ββββββββ
|
| 21 |
+
An LLM-as-judge evaluates (query, retrieved_chunks) and returns:
|
| 22 |
+
{
|
| 23 |
+
"grade": "GOOD" | "POOR" | "ABSENT",
|
| 24 |
+
"relevance_score": 0.0–1.0,
|
| 25 |
+
"has_sufficient_context": true | false,
|
| 26 |
+
"reasoning": "..."
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
Grade definitions:
|
| 30 |
+
GOOD   — chunks are relevant and sufficient for the query
|
| 31 |
+
POOR   — chunks are partially relevant; try rewriting the query
|
| 32 |
+
ABSENT — knowledge base clearly doesn't contain the answer;
|
| 33 |
+
fall back to web search
|
| 34 |
+
|
| 35 |
+
Query rewriting
|
| 36 |
+
ββββββββββββββββ
|
| 37 |
+
When grade == POOR, we expand the query using chain-of-thought:
|
| 38 |
+
the grader's `reasoning` field (why did retrieval fail?) is fed
|
| 39 |
+
back as context for a rewrite prompt. This makes the rewrite
|
| 40 |
+
semantically targeted, not just rephrased.
|
| 41 |
+
|
| 42 |
+
Web search fallback
|
| 43 |
+
ββββββββββββββββββββ
|
| 44 |
+
When grade == ABSENT, we call Tavily (preferred) or DuckDuckGo
|
| 45 |
+
(no API key needed) and package the top-3 web results as synthetic
|
| 46 |
+
RetrievedChunk objects with source="web_search". These flow into
|
| 47 |
+
the same generator unchanged.
|
| 48 |
+
"""
|
| 49 |
+
from __future__ import annotations
|
| 50 |
+
|
| 51 |
+
import json
|
| 52 |
+
import logging
|
| 53 |
+
import re
|
| 54 |
+
from dataclasses import dataclass
|
| 55 |
+
from enum import Enum
|
| 56 |
+
from typing import Optional
|
| 57 |
+
|
| 58 |
+
from config import get_settings
|
| 59 |
+
from retrieval.dense import RetrievedChunk
|
| 60 |
+
|
| 61 |
+
logger = logging.getLogger(__name__)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ββ Grade enum βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
+
|
| 66 |
+
class RetrievalGrade(str, Enum):
    """
    Verdict of the retrieval grader, which decides what the gate does next.

    The members are also plain strings, so they compare equal to their
    names ("GOOD", "POOR", "ABSENT").
    """

    GOOD = "GOOD"  # proceed to generation
    POOR = "POOR"  # rewrite query and re-retrieve
    ABSENT = "ABSENT"  # fall back to web search
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ββ Result dataclass βββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
@dataclass
class CRAGResult:
    """
    Outcome of one CRAG gate evaluation.

    Carries the grader's verdict together with the chunks that should be
    handed to the generator, plus flags recording which corrective action
    (query rewrite or web search) was taken.
    """

    grade: RetrievalGrade
    relevance_score: float
    has_sufficient_context: bool
    reasoning: str
    final_chunks: list[RetrievedChunk]  # chunks to pass to generator
    rewritten_query: Optional[str] = None  # set if grade was POOR
    web_search_used: bool = False
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ββ Prompt templates βββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
|
| 87 |
+
_GRADER_PROMPT = """\
|
| 88 |
+
You are a retrieval quality judge. Given a user query and retrieved passages,
|
| 89 |
+
assess whether the passages contain sufficient information to answer the query.
|
| 90 |
+
|
| 91 |
+
Return ONLY a JSON object in this exact format (no markdown, no preamble):
|
| 92 |
+
{{
|
| 93 |
+
"grade": "<GOOD|POOR|ABSENT>",
|
| 94 |
+
"relevance_score": <float 0.0-1.0>,
|
| 95 |
+
"has_sufficient_context": <true|false>,
|
| 96 |
+
"reasoning": "<one sentence explaining your assessment>"
|
| 97 |
+
}}
|
| 98 |
+
|
| 99 |
+
Grades:
|
| 100 |
+
GOOD β passages are clearly relevant and contain enough information to answer
|
| 101 |
+
POOR β passages are partially relevant but incomplete or off-topic; retrieval should be retried
|
| 102 |
+
ABSENT β the knowledge base clearly does not contain information about this query
|
| 103 |
+
|
| 104 |
+
User query: {query}
|
| 105 |
+
|
| 106 |
+
Retrieved passages:
|
| 107 |
+
{passages}
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
_REWRITE_PROMPT = """\
|
| 111 |
+
A retrieval system failed to find good results for the following query.
|
| 112 |
+
The grader's feedback explains why the results were poor.
|
| 113 |
+
|
| 114 |
+
Original query: {query}
|
| 115 |
+
Grader feedback: {reasoning}
|
| 116 |
+
|
| 117 |
+
Rewrite the query to be more specific and likely to retrieve better results.
|
| 118 |
+
Apply these strategies: expand acronyms, add domain context, use alternative terms.
|
| 119 |
+
|
| 120 |
+
Return ONLY the rewritten query string, no explanation.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ββ CRAG Gate ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
+
|
| 126 |
+
class CRAGGate:
    """
    Corrective RAG gate that sits between retrieval and generation.

    The gate asks an LLM judge to grade the retrieved chunks and then
    passes them through (GOOD), rewrites the query and re-retrieves (POOR),
    or falls back to web search (ABSENT). Every internal failure degrades
    to handing the original chunks to the generator, so the gate never
    blocks the pipeline.

    Usage (in orchestrator):
        crag = CRAGGate()
        result = crag.evaluate(
            query=user_query,
            chunks=retrieved_chunks,
            retriever_fn=retriever.retrieve,  # callable for re-retrieval
        )
        # result.final_chunks -> pass to generator
        # result.grade        -> log for evaluation dashboard
    """

    def __init__(self) -> None:
        # Groq client; created on first use by _get_llm().
        self._llm = None

    # -- Public API ---------------------------------------------

    def evaluate(
        self,
        query: str,
        chunks: list[RetrievedChunk],
        retriever_fn: Optional[callable] = None,
        max_retries: int = 1,
    ) -> CRAGResult:
        """
        Grade retrieved chunks and apply corrective action if needed.

        Args:
            query: the user's original query
            chunks: chunks returned by the retrieval pipeline
            retriever_fn: callable(query: str) -> list[RetrievedChunk]
                used for re-retrieval on POOR grade
            max_retries: 0 disables the rewrite + re-retrieve step; any
                positive value allows exactly one cycle (the re-graded
                result is returned as-is and never retried again)

        Returns:
            A CRAGResult whose ``final_chunks`` are what the generator
            should use and whose grade describes those chunks.
        """
        # Grade the initial retrieval
        grade_result = self._grade(query, chunks)
        logger.info(
            "CRAG grade: %s (score=%.2f, sufficient=%s) — %s",
            grade_result["grade"],
            grade_result["relevance_score"],
            grade_result["has_sufficient_context"],
            grade_result["reasoning"][:80],
        )

        grade = RetrievalGrade(grade_result["grade"])

        # -- GOOD: pass through unchanged ----------------------
        if grade == RetrievalGrade.GOOD:
            return CRAGResult(
                grade=grade,
                relevance_score=grade_result["relevance_score"],
                has_sufficient_context=True,
                reasoning=grade_result["reasoning"],
                final_chunks=chunks,
            )

        # -- POOR: rewrite query and re-retrieve ---------------
        if grade == RetrievalGrade.POOR and retriever_fn and max_retries > 0:
            rewritten = self._rewrite_query(query, grade_result["reasoning"])
            if rewritten == query:
                # _rewrite_query returns the original query when it could
                # not produce a rewrite; retrieving and grading the same
                # query again would only repeat the work already done.
                logger.info("CRAG rewrite left the query unchanged; skipping re-retrieval")
            else:
                logger.info("CRAG rewrite: '%s' → '%s'", query[:50], rewritten[:50])
                new_chunks = None
                try:
                    new_chunks = retriever_fn(rewritten)
                except Exception as exc:
                    # Best effort: a failing retriever must not block generation.
                    logger.warning("Re-retrieval after rewrite failed: %s", exc)

                if new_chunks:
                    # Re-grade the new results (once, so there is no retry loop).
                    new_grade = self._grade(rewritten, new_chunks)
                    return CRAGResult(
                        grade=RetrievalGrade(new_grade["grade"]),
                        relevance_score=new_grade["relevance_score"],
                        has_sufficient_context=new_grade["has_sufficient_context"],
                        reasoning=new_grade["reasoning"],
                        final_chunks=new_chunks,
                        rewritten_query=rewritten,
                    )

                # The retry failed or came back empty: hand over the
                # original chunks and report the grade that describes them,
                # not a grade for an empty result set.
                return CRAGResult(
                    grade=grade,
                    relevance_score=grade_result["relevance_score"],
                    has_sufficient_context=False,
                    reasoning=grade_result["reasoning"],
                    final_chunks=chunks,
                    rewritten_query=rewritten,
                )

        # -- ABSENT: web search fallback ------------------------
        if grade == RetrievalGrade.ABSENT:
            web_chunks = self._web_search_fallback(query)
            if web_chunks:
                return CRAGResult(
                    grade=grade,
                    relevance_score=0.0,
                    has_sufficient_context=True,
                    reasoning=grade_result["reasoning"],
                    final_chunks=web_chunks,
                    web_search_used=True,
                )
            # Web search also failed: return original chunks with a warning
            return CRAGResult(
                grade=grade,
                relevance_score=0.0,
                has_sufficient_context=False,
                reasoning=f"Knowledge base: {grade_result['reasoning']}. Web search also returned no results.",
                final_chunks=chunks,
            )

        # Default: return original chunks unchanged
        return CRAGResult(
            grade=grade,
            relevance_score=grade_result["relevance_score"],
            has_sufficient_context=grade_result["has_sufficient_context"],
            reasoning=grade_result["reasoning"],
            final_chunks=chunks,
        )

    # -- LLM grader ---------------------------------------------

    def _grade(self, query: str, chunks: list[RetrievedChunk]) -> dict:
        """
        Call LLM to grade retrieval quality. Returns parsed dict.

        An empty ``chunks`` list is graded ABSENT without calling the LLM.
        Only the first five chunks, truncated to 400 characters each, are
        shown to the judge. If the LLM call fails the result defaults to
        GOOD so the pipeline is not blocked.
        """
        if not chunks:
            return {
                "grade": "ABSENT",
                "relevance_score": 0.0,
                "has_sufficient_context": False,
                "reasoning": "No chunks were retrieved.",
            }

        passages = "\n\n".join(
            f"[{i}] {c.title}: {c.text[:400]}"
            for i, c in enumerate(chunks[:5], 1)
        )

        try:
            client = self._get_llm()
            cfg = get_settings()
            response = client.chat.completions.create(
                model=cfg.groq_model,
                messages=[{
                    "role": "user",
                    "content": _GRADER_PROMPT.format(query=query, passages=passages),
                }],
                temperature=0.0,
                max_tokens=200,
            )
            raw = response.choices[0].message.content or "{}"
            return self._parse_grade(raw)

        except Exception as exc:
            logger.warning("CRAG grader LLM call failed: %s", exc)
            # Safe default: assume GOOD to avoid blocking the pipeline
            return {
                "grade": "GOOD",
                "relevance_score": 0.5,
                "has_sufficient_context": True,
                "reasoning": f"Grader unavailable ({exc}); passing through.",
            }

    def _parse_grade(self, raw: str) -> dict:
        """
        Turn the judge's raw reply into a normalised grade dict.

        Tolerates a markdown code fence, text around the JSON object, a
        non-object JSON value, a missing or non-string grade, a non-numeric
        score and a boolean given as a string. Anything that cannot be
        parsed falls back to GOOD / 0.5 / True, matching the "never block
        the pipeline" default used elsewhere in the gate.

        Returns a dict with keys ``grade`` ("GOOD", "POOR" or "ABSENT"),
        ``relevance_score`` (float clamped to 0.0-1.0),
        ``has_sufficient_context`` (bool) and ``reasoning`` (str).
        """
        fallback = {
            "grade": "GOOD", "relevance_score": 0.5,
            "has_sufficient_context": True, "reasoning": "Parse error.",
        }

        text = (raw or "").strip()
        if text.startswith("```"):
            text = re.sub(r"^```[a-z]*\n?", "", text)
            text = re.sub(r"\n?```$", "", text)

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # The model may have wrapped the object in extra prose; try the
            # span from the first "{" to the last "}".
            embedded = re.search(r"\{.*\}", text, re.DOTALL)
            if embedded is None:
                return fallback
            try:
                data = json.loads(embedded.group(0))
            except json.JSONDecodeError:
                return fallback

        if not isinstance(data, dict):
            return fallback

        grade_str = str(data.get("grade", "GOOD")).strip().upper()
        if grade_str not in {"GOOD", "POOR", "ABSENT"}:
            grade_str = "GOOD"

        try:
            score = float(data.get("relevance_score", 0.5))
        except (TypeError, ValueError):
            score = 0.5
        # The prompt asks for 0.0-1.0; keep out-of-range replies in bounds.
        score = min(1.0, max(0.0, score))

        sufficient = data.get("has_sufficient_context", True)
        if isinstance(sufficient, str):
            # bool("false") would be True, so read the text instead.
            sufficient = sufficient.strip().lower() in {"true", "yes", "1"}

        return {
            "grade": grade_str,
            "relevance_score": score,
            "has_sufficient_context": bool(sufficient),
            "reasoning": str(data.get("reasoning", "")),
        }

    # -- Query rewriter -----------------------------------------

    def _rewrite_query(self, original_query: str, reasoning: str) -> str:
        """
        Ask the LLM for a more specific version of ``original_query``.

        ``reasoning`` is the grader's explanation of why retrieval was poor.
        Returns ``original_query`` unchanged if the call fails or the reply
        is empty.
        """
        try:
            client = self._get_llm()
            cfg = get_settings()
            response = client.chat.completions.create(
                model=cfg.groq_model,
                messages=[{
                    "role": "user",
                    "content": _REWRITE_PROMPT.format(
                        query=original_query, reasoning=reasoning
                    ),
                }],
                temperature=0.3,
                max_tokens=128,
            )
            rewritten = (response.choices[0].message.content or "").strip()
            return rewritten if rewritten else original_query
        except Exception as exc:
            logger.warning("Query rewrite failed: %s", exc)
            return original_query

    # -- Web search fallback ------------------------------------

    def _web_search_fallback(self, query: str) -> list[RetrievedChunk]:
        """
        Try Tavily first (better quality), then DuckDuckGo (no API key).
        Returns synthetic RetrievedChunk objects from web results.
        """
        chunks = self._tavily_search(query) or self._duckduckgo_search(query)
        if chunks:
            logger.info("CRAG web fallback: %d results for '%s'", len(chunks), query[:50])
        return chunks

    def _tavily_search(self, query: str) -> list[RetrievedChunk]:
        """
        Search with Tavily; returns [] if unavailable, unconfigured or failing.

        Results without content are dropped.
        """
        try:
            from tavily import TavilyClient  # type: ignore
        except ImportError:
            # Optional dependency: its absence is expected, not an error.
            logger.debug("tavily is not installed; skipping Tavily search")
            return []

        try:
            cfg = get_settings()
            api_key = cfg.tavily_api_key
            if not api_key:
                return []
            client = TavilyClient(api_key=api_key)
            results = client.search(query, max_results=3)
            return [
                self._web_result_to_chunk(r.get("content", ""), r.get("url", ""), r.get("title", "Web"))
                for r in results.get("results", [])
                if r.get("content")
            ]
        except Exception as exc:
            # Best effort, but leave a trace so outages are diagnosable.
            logger.warning("Tavily search failed: %s", exc)
            return []

    def _duckduckgo_search(self, query: str) -> list[RetrievedChunk]:
        """
        Search with DuckDuckGo; returns [] if unavailable or failing.

        Results without a body are dropped, as in the Tavily path, so an
        empty snippet is never treated as usable context.
        """
        try:
            from duckduckgo_search import DDGS  # type: ignore
        except ImportError:
            # Optional dependency: its absence is expected, not an error.
            logger.debug("duckduckgo_search is not installed; skipping DuckDuckGo search")
            return []

        try:
            with DDGS() as ddgs:
                return [
                    self._web_result_to_chunk(
                        r.get("body", ""), r.get("href", ""), r.get("title", "Web")
                    )
                    for r in ddgs.text(query, max_results=3)
                    if r.get("body")
                ]
        except Exception as exc:
            # Best effort, but leave a trace so outages are diagnosable.
            logger.warning("DuckDuckGo search failed: %s", exc)
            return []

    @staticmethod
    def _web_result_to_chunk(text: str, url: str, title: str) -> RetrievedChunk:
        """
        Wrap one web result as a RetrievedChunk.

        The chunk id is the first 16 hex characters of the SHA-256 of the
        URL, the text is capped at 1500 characters, and the score is fixed.
        """
        import hashlib
        cid = hashlib.sha256(url.encode()).hexdigest()[:16]
        return RetrievedChunk(
            chunk_id=cid,
            doc_id="web",
            source=url,
            title=title,
            text=text[:1500],
            parent_text=text[:1500],
            chunk_index=0,
            score=0.6,  # neutral score for web results
            retriever="web_search",
        )

    # -- Groq client --------------------------------------------

    def _get_llm(self):
        """
        Return the shared Groq client, creating it on first use.

        Raises:
            RuntimeError: if no Groq API key is configured.
        """
        if self._llm is None:
            cfg = get_settings()
            if not cfg.groq_api_key:
                raise RuntimeError("GROQ_API_KEY not set")
            from groq import Groq  # type: ignore
            self._llm = Groq(api_key=cfg.groq_api_key)
        return self._llm
|
ingestion/pipeline.py
CHANGED
|
@@ -22,6 +22,7 @@ from ingestion.document_loader import Document, DocumentLoader
|
|
| 22 |
from retrieval.embedder import Embedder
|
| 23 |
from retrieval.dense import MilvusStore
|
| 24 |
from retrieval.bm25 import BM25Retriever
|
|
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
|
@@ -43,12 +44,15 @@ class IngestionPipeline:
|
|
| 43 |
embedder: Optional[Embedder] = None,
|
| 44 |
store: Optional[MilvusStore] = None,
|
| 45 |
bm25: Optional[BM25Retriever] = None,
|
|
|
|
|
|
|
| 46 |
) -> None:
|
| 47 |
self._loader = loader or DocumentLoader()
|
| 48 |
self._embedder = embedder or Embedder()
|
| 49 |
self._chunker = chunker or SemanticChunker(embedder=self._embedder)
|
| 50 |
self._store = store or MilvusStore(embedder=self._embedder)
|
| 51 |
self._bm25 = bm25 or BM25Retriever()
|
|
|
|
| 52 |
|
| 53 |
# ββ Public βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
|
|
@@ -133,7 +137,16 @@ class IngestionPipeline:
|
|
| 133 |
except Exception as exc:
|
| 134 |
logger.error("BM25 indexing failed: %s", exc)
|
| 135 |
stats["errors"].append({"source": "bm25_index", "error": str(exc)})
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
elapsed = time.perf_counter() - t0
|
| 138 |
logger.info(
|
| 139 |
"Ingestion complete in %.1fs β %d docs, %d chunks stored.",
|
|
|
|
| 22 |
from retrieval.embedder import Embedder
|
| 23 |
from retrieval.dense import MilvusStore
|
| 24 |
from retrieval.bm25 import BM25Retriever
|
| 25 |
+
from retrieval.graph_builder import KnowledgeGraphBuilder
|
| 26 |
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
|
|
|
| 44 |
embedder: Optional[Embedder] = None,
|
| 45 |
store: Optional[MilvusStore] = None,
|
| 46 |
bm25: Optional[BM25Retriever] = None,
|
| 47 |
+
graph: Optional[KnowledgeGraphBuilder] = None,
|
| 48 |
+
|
| 49 |
) -> None:
|
| 50 |
self._loader = loader or DocumentLoader()
|
| 51 |
self._embedder = embedder or Embedder()
|
| 52 |
self._chunker = chunker or SemanticChunker(embedder=self._embedder)
|
| 53 |
self._store = store or MilvusStore(embedder=self._embedder)
|
| 54 |
self._bm25 = bm25 or BM25Retriever()
|
| 55 |
+
self._graph = graph or KnowledgeGraphBuilder()
|
| 56 |
|
| 57 |
# ββ Public βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
|
|
|
|
| 137 |
except Exception as exc:
|
| 138 |
logger.error("BM25 indexing failed: %s", exc)
|
| 139 |
stats["errors"].append({"source": "bm25_index", "error": str(exc)})
|
| 140 |
+
|
| 141 |
+
# ββ Build knowledge graph (NER + relations) β Phase 3 ββ
|
| 142 |
+
try:
|
| 143 |
+
graph_stats = self._graph.process_chunks(all_chunks)
|
| 144 |
+
stats["graph_entities"] = graph_stats.get("entities", 0)
|
| 145 |
+
stats["graph_triples"] = graph_stats.get("triples", 0)
|
| 146 |
+
except Exception as exc:
|
| 147 |
+
logger.error("Graph extraction failed: %s", exc)
|
| 148 |
+
stats["errors"].append({"source": "graph_build", "error": str(exc)})
|
| 149 |
+
|
| 150 |
elapsed = time.perf_counter() - t0
|
| 151 |
logger.info(
|
| 152 |
"Ingestion complete in %.1fs β %d docs, %d chunks stored.",
|
retrieval/cache.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Retrieval Cache (Redis, Phase 4)
|
| 3 |
+
|
| 4 |
+
What gets cached
|
| 5 |
+
βββββββββββββββββ
|
| 6 |
+
The output of the full retrieval pipeline β after RRF fusion and
|
| 7 |
+
cross-encoder reranking β is serialised and stored in Redis with a
|
| 8 |
+
configurable TTL (default 1 hour).
|
| 9 |
+
|
| 10 |
+
Cache key: "cortex:retrieval:" + first 24 hex chars of SHA-256("<query.lower().strip()>:<top_k>")
|
| 11 |
+
This means the same query with different capitalisation or trailing
|
| 12 |
+
spaces hits the same cache entry, which is almost always correct for RAG.
|
| 13 |
+
|
| 14 |
+
What does NOT get cached
|
| 15 |
+
βββββββββββββββββββββββββ
|
| 16 |
+
CRAG evaluation and generation are NOT cached. The CRAG grade depends
|
| 17 |
+
on the current state of the knowledge base (which changes after ingestion),
|
| 18 |
+
and generation is fast enough (streaming) that caching it adds complexity
|
| 19 |
+
without meaningful latency savings.
|
| 20 |
+
|
| 21 |
+
Graceful degradation
|
| 22 |
+
βββββββββββββββββββββ
|
| 23 |
+
If Redis is unreachable on startup, the cache silently disables itself
|
| 24 |
+
and logs a warning. Every query falls through to the live retrieval
|
| 25 |
+
pipeline unchanged. No exceptions surface to the user.
|
| 26 |
+
|
| 27 |
+
This means you can develop without Redis running locally and only enable
|
| 28 |
+
it in production (Railway, Render) where Redis add-ons are available.
|
| 29 |
+
"""
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import hashlib
|
| 33 |
+
import json
|
| 34 |
+
import logging
|
| 35 |
+
from typing import Optional
|
| 36 |
+
|
| 37 |
+
from retrieval.dense import RetrievedChunk
|
| 38 |
+
from retrieval.orchestrator import MultiStrategyRetriever, RetrievalResult
|
| 39 |
+
from retrieval.router import QueryIntent, RoutingDecision
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _make_cache_key(query: str, top_k: int) -> str:
    """
    Build the Redis key for a (query, top_k) pair.

    The query is lower-cased and stripped first, so differences in case or
    surrounding whitespace map to the same key. The key is the fixed
    ``cortex:retrieval:`` prefix followed by the first 24 hex characters of
    the SHA-256 digest of ``"<normalised query>:<top_k>"``.
    """
    normalised_query = query.lower().strip()
    digest = hashlib.sha256(f"{normalised_query}:{top_k}".encode()).hexdigest()
    return f"cortex:retrieval:{digest[:24]}"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _serialise_result(result: RetrievalResult) -> str:
    """
    Encode a RetrievalResult as a JSON string for Redis storage.

    The payload has three top-level keys: ``chunks`` (one dict per chunk),
    ``decision`` (the routing decision, with the intent stored as its enum
    ``.value``) and ``retriever_hits``.
    """
    # Field order is kept stable so the emitted JSON text does not change.
    chunk_fields = (
        "chunk_id",
        "doc_id",
        "source",
        "title",
        "text",
        "parent_text",
        "chunk_index",
        "score",
        "retriever",
    )
    chunk_dicts = [
        {field: getattr(chunk, field) for field in chunk_fields}
        for chunk in result.chunks
    ]

    routing = result.decision
    payload = {
        "chunks": chunk_dicts,
        "decision": {
            "intent": routing.intent.value,
            "strategies": routing.strategies,
            "confidence": routing.confidence,
            "reasoning": routing.reasoning,
        },
        "retriever_hits": result.retriever_hits,
    }
    return json.dumps(payload)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _deserialise_result(raw: str) -> RetrievalResult:
    """
    Rebuild a RetrievalResult from the JSON text written by _serialise_result.

    Raises whatever ``json.loads`` raises on malformed text, and ``KeyError``
    when an expected field is missing. ``retriever_hits`` is optional and
    defaults to an empty dict.
    """
    payload = json.loads(raw)

    restored_chunks = []
    for entry in payload["chunks"]:
        restored_chunks.append(
            RetrievedChunk(
                chunk_id=entry["chunk_id"],
                doc_id=entry["doc_id"],
                source=entry["source"],
                title=entry["title"],
                text=entry["text"],
                parent_text=entry["parent_text"],
                chunk_index=entry["chunk_index"],
                score=entry["score"],
                retriever=entry["retriever"],
            )
        )

    routing = payload["decision"]
    restored_decision = RoutingDecision(
        # The intent was stored as the enum's value; turn it back into the enum.
        intent=QueryIntent(routing["intent"]),
        strategies=routing["strategies"],
        confidence=routing["confidence"],
        reasoning=routing["reasoning"],
    )

    hits = payload.get("retriever_hits", {})
    return RetrievalResult(
        chunks=restored_chunks,
        decision=restored_decision,
        retriever_hits=hits,
    )
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class CachedRetriever:
    """
    Drop-in wrapper around MultiStrategyRetriever that adds Redis caching.

    If Redis cannot be reached when the wrapper is constructed, caching is
    disabled and every call goes straight to the wrapped retriever.

    Usage (replaces MultiStrategyRetriever in api/main.py):
        retriever = CachedRetriever(MultiStrategyRetriever(...))
        result = retriever.retrieve(query)
        print(retriever.cache_stats())  # {"hits": 3, "misses": 7, "enabled": True}
    """

    def __init__(
        self,
        inner: MultiStrategyRetriever,
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """
        Args:
            inner: the retriever that performs live retrieval on a cache miss.
            ttl_seconds: lifetime of a cache entry; a falsy value (None or 0)
                selects the configured default.
        """
        self._inner = inner
        self._redis = self._connect_redis()  # None when the cache is disabled
        self._ttl = ttl_seconds or self._default_ttl()
        self._hits = 0
        self._misses = 0

    # -- Public API (matches MultiStrategyRetriever interface) --

    def retrieve(
        self,
        query: str,
        top_k_candidates: Optional[int] = None,
        final_top_k: Optional[int] = None,
    ) -> RetrievalResult:
        """
        Retrieve with cache. Falls through to live retrieval on miss or error.

        The cache key is derived from the query and the effective final top-k
        only; ``top_k_candidates`` is forwarded to the wrapped retriever but
        is not part of the key. The returned result carries a ``from_cache``
        flag set by this method.
        """
        from config import get_settings
        cfg = get_settings()
        k = final_top_k or cfg.final_top_k
        key = _make_cache_key(query, k)

        # -- Cache lookup --
        if self._redis:
            try:
                cached = self._redis.get(key)
                if cached:
                    result = _deserialise_result(cached)
                    result.from_cache = True
                    # Count the hit only once the entry has actually been
                    # decoded; an undecodable entry falls through below and
                    # is counted as a miss, not as both.
                    self._hits += 1
                    logger.debug("Cache HIT for query: %sβ¦", query[:40])
                    return result
            except Exception as exc:
                logger.warning("Redis GET failed: %s β falling through.", exc)

        # -- Cache miss: live retrieval --
        self._misses += 1
        logger.debug("Cache MISS for query: %sβ¦", query[:40])
        result = self._inner.retrieve(query, top_k_candidates, final_top_k)
        result.from_cache = False

        # -- Write to cache (empty results are never stored) --
        if self._redis and not result.empty:
            try:
                self._redis.setex(key, self._ttl, _serialise_result(result))
            except Exception as exc:
                logger.warning("Redis SET failed: %s", exc)

        return result

    def invalidate(self, query: str, top_k: int) -> bool:
        """
        Manually invalidate a cache entry (e.g. after re-ingestion).

        Returns True when an entry was deleted, False when nothing was
        deleted, the cache is disabled, or Redis raised.
        """
        if not self._redis:
            return False
        try:
            return bool(self._redis.delete(_make_cache_key(query, top_k)))
        except Exception as exc:
            # Best-effort, but leave a trace like the other Redis calls do.
            logger.warning("Redis DELETE failed: %s", exc)
            return False

    def flush_all(self) -> int:
        """
        Delete all Cortex cache keys. Returns count deleted.

        Keys are discovered with SCAN (``scan_iter``) rather than KEYS, so a
        large keyspace is walked incrementally instead of in one blocking
        command, and are deleted in batches. If Redis raises part-way
        through, the number deleted so far is returned.
        """
        if not self._redis:
            return 0

        batch_size = 500
        deleted = 0
        pending: list = []
        try:
            for key in self._redis.scan_iter(match="cortex:retrieval:*", count=batch_size):
                pending.append(key)
                if len(pending) >= batch_size:
                    deleted += self._redis.delete(*pending)
                    pending = []
            if pending:
                deleted += self._redis.delete(*pending)
        except Exception as exc:
            logger.warning("Redis flush failed after %d deletions: %s", deleted, exc)
        return deleted

    def cache_stats(self) -> dict:
        """
        Return hit/miss counters for this wrapper instance.

        ``hit_rate`` is hits / (hits + misses) rounded to 3 places, or 0.0
        before any lookup has happened.
        """
        total = self._hits + self._misses
        return {
            "enabled": self._redis is not None,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(self._hits / total, 3) if total else 0.0,
            "ttl_s": self._ttl,
        }

    # -- Pass-through for orchestrator methods --

    def index_chunks(self, chunks: list) -> int:
        """Forward to the wrapped retriever's ``index_chunks``."""
        return self._inner.index_chunks(chunks)

    def build_graph(self, chunks: list) -> dict:
        """Forward to the wrapped retriever's ``build_graph``."""
        return self._inner.build_graph(chunks)

    @property
    def graph_builder(self):
        """The wrapped retriever's ``graph_builder``."""
        return self._inner.graph_builder

    # -- Redis connection --

    @staticmethod
    def _connect_redis():
        """
        Connect to Redis and verify the connection with PING.

        Returns the client, or None when redis-py is not installed or the
        server cannot be reached (the cache is then disabled).
        """
        from config import get_settings
        cfg = get_settings()
        url = getattr(cfg, "redis_url", "redis://localhost:6379")
        try:
            import redis  # type: ignore
            client = redis.from_url(url, socket_connect_timeout=2, decode_responses=True)
            client.ping()
            logger.info("Redis cache connected at %s", url)
            return client
        except ImportError:
            logger.info("redis-py not installed β cache disabled. pip install redis")
            return None
        except Exception as exc:
            logger.warning("Redis unavailable (%s) β cache disabled.", exc)
            return None

    @staticmethod
    def _default_ttl() -> int:
        """Configured ``cache_ttl_seconds``, or 3600 when the setting is absent."""
        from config import get_settings
        return getattr(get_settings(), "cache_ttl_seconds", 3600)
|
retrieval/graph_builder.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Knowledge Graph Builder (Phase 3)
|
| 3 |
+
|
| 4 |
+
What this does
|
| 5 |
+
ββββββββββββββ
|
| 6 |
+
During ingestion, every chunk is processed to extract:
|
| 7 |
+
1. Named entities (spaCy NER: PERSON, ORG, WORK_OF_ART, PRODUCT, β¦)
|
| 8 |
+
2. Relations (few-shot LLM: subject → predicate → object triples)
|
| 9 |
+
|
| 10 |
+
These are assembled into a NetworkX undirected graph where:
|
| 11 |
+
- Nodes = entities (label + type + first-seen source)
|
| 12 |
+
- Edges = relations (predicate label + list of source chunk_ids)
|
| 13 |
+
|
| 14 |
+
Each node also carries a list of chunk_ids it appeared in, so the
|
| 15 |
+
graph retriever can map entity β chunks without an extra lookup.
|
| 16 |
+
|
| 17 |
+
The graph is persisted as a JSON file (graphs are small β a 100-doc
|
| 18 |
+
corpus typically has <10k nodes). On reload the full graph is
|
| 19 |
+
reconstructed in seconds from the JSON.
|
| 20 |
+
|
| 21 |
+
ββββββββββββββ
|
| 22 |
+
(Phase 3, refactored)
|
| 23 |
+
|
| 24 |
+
The builder is now responsible ONLY for:
|
| 25 |
+
- spaCy NER (entities are always extracted the same way)
|
| 26 |
+
- Assembling triples into a NetworkX graph
|
| 27 |
+
- Persisting / loading the graph
|
| 28 |
+
|
| 29 |
+
Relation extraction is delegated to a RelationExtractor strategy:
|
| 30 |
+
- REBELExtractor (default) β local model, no API calls
|
| 31 |
+
- LLMExtractor β Groq, free-form predicates
|
| 32 |
+
|
| 33 |
+
Switch via .env:
|
| 34 |
+
GRAPH_EXTRACTOR=rebel # default, recommended
|
| 35 |
+
GRAPH_EXTRACTOR=llm # original method
|
| 36 |
+
|
| 37 |
+
Or pass explicitly:
|
| 38 |
+
builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
from __future__ import annotations
|
| 42 |
+
|
| 43 |
+
import json
|
| 44 |
+
import logging
|
| 45 |
+
|
| 46 |
+
from pathlib import Path
|
| 47 |
+
from typing import Optional
|
| 48 |
+
|
| 49 |
+
import networkx as nx
|
| 50 |
+
|
| 51 |
+
from ingestion.chunker import Chunk
|
| 52 |
+
from retrieval.relation_extractors import (
|
| 53 |
+
RelationExtractor,
|
| 54 |
+
Triple,
|
| 55 |
+
build_extractor,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
logger = logging.getLogger(__name__)
|
| 59 |
+
|
| 60 |
+
_DEFAULT_GRAPH_PATH = Path("data") / "knowledge_graph.json"

# spaCy NER labels that are kept as graph nodes; entities carrying any other
# label are skipped by KnowledgeGraphBuilder._extract_entities.
_ENTITY_TYPES = {
    "PERSON",
    "NORP",
    "ORG",
    "GPE",
    "LOC",
    "FAC",
    "PRODUCT",
    "WORK_OF_ART",
    "EVENT",
    "LAW",
}
|
| 67 |
+
|
| 68 |
+
class KnowledgeGraphBuilder:
    """
    Builds and maintains the knowledge graph.

    Nodes are entity labels; edges connect the subject and object of each
    extracted triple. Both carry the list of chunk_ids they were seen in.
    The graph is loaded from ``graph_path`` on construction (if the file
    exists) and written back after every ``process_chunks`` call.

    Usage (at ingestion time):
        # Default extractor (chosen by build_extractor())
        builder = KnowledgeGraphBuilder()
        builder.process_chunks(chunks)

        # LLM method (original)
        from retrieval.relation_extractors import LLMExtractor
        builder = KnowledgeGraphBuilder(extractor=LLMExtractor())
        builder.process_chunks(chunks)

    Usage (at query time):
        builder = KnowledgeGraphBuilder()
        G = builder.graph   # loaded from disk automatically
    """

    def __init__(
        self,
        graph_path: str | Path = _DEFAULT_GRAPH_PATH,
        extractor: Optional[RelationExtractor] = None,
    ) -> None:
        """
        Args:
            graph_path: JSON file the graph is loaded from and saved to.
            extractor: relation-extraction strategy; when omitted,
                ``build_extractor()`` picks one.
        """
        self._path = Path(graph_path)
        self._graph: nx.Graph = nx.Graph()
        # If no extractor is injected, build_extractor() chooses one
        # (per the module docstring, from GRAPH_EXTRACTOR in .env).
        self._extractor: RelationExtractor = extractor or build_extractor()
        self._nlp = None  # spaCy pipeline, loaded lazily by _get_nlp()
        self._load_if_exists()
        logger.info(
            "KnowledgeGraphBuilder ready (extractor=%s)", self._extractor.name
        )

    # -- Public API --

    @property
    def graph(self) -> nx.Graph:
        """The in-memory graph (shared object, not a copy)."""
        return self._graph

    @property
    def extractor_name(self) -> str:
        """Name of the configured relation extractor."""
        return self._extractor.name

    def process_chunks(self, chunks: list[Chunk]) -> dict:
        """
        Extract entities and relations from chunks; update and save graph.

        Relations are requested from the extractor in one ``extract_batch``
        call; if that call raises, extraction is retried chunk by chunk with
        ``extract``. Entities always come from spaCy NER.

        Returns:
            Stats dict with keys ``chunks``, ``entities``, ``triples`` and
            ``errors``. ``entities`` and ``triples`` are per-chunk totals of
            what was extracted, not counts of nodes/edges newly created.
        """
        if not chunks:
            return {"chunks": 0, "entities": 0, "triples": 0, "errors": 0}

        stats = {"chunks": len(chunks), "entities": 0, "triples": 0, "errors": 0}

        # -- Batch relation extraction --
        try:
            triple_map = self._extractor.extract_batch(chunks)
        except Exception as exc:
            logger.error("Batch extraction failed, falling back to sequential: %s", exc)
            triple_map = {}
            for chunk in chunks:
                try:
                    triple_map[chunk.chunk_id] = self._extractor.extract(chunk)
                except Exception as e:
                    logger.warning("Extraction failed for %s: %s", chunk.chunk_id, e)
                    triple_map[chunk.chunk_id] = []
                    stats["errors"] += 1

        # -- Entity extraction + graph update --
        for chunk in chunks:
            try:
                entities = self._extract_entities(chunk.text)
                triples = triple_map.get(chunk.chunk_id, [])

                self._add_entities_to_graph(entities, chunk)
                self._add_triples_to_graph(triples)

                stats["entities"] += len(entities)
                stats["triples"] += len(triples)

            except Exception as exc:
                # One bad chunk must not abort the whole batch.
                logger.warning("Graph update failed for chunk %s: %s", chunk.chunk_id, exc)
                stats["errors"] += 1

        self.save()
        logger.info(
            "Graph updated via %s: +%d entities, +%d triples (nodes=%d, edges=%d)",
            self._extractor.name,
            stats["entities"], stats["triples"],
            self._graph.number_of_nodes(), self._graph.number_of_edges(),
        )
        return stats

    def save(self) -> None:
        """
        Persist the graph as node-link JSON at the configured path.

        The JSON is written to a sibling ``*.tmp`` file and then renamed over
        the target, so an interrupted write cannot leave a truncated graph
        file behind. That matters because ``_load_if_exists`` starts from an
        empty graph when the file cannot be parsed.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = nx.node_link_data(self._graph)
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as fh:
                json.dump(data, fh, indent=2)
            tmp_path.replace(self._path)
        except Exception:
            # Do not leave a half-written temp file lying around.
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        logger.debug("Graph saved to %s", self._path)

    def stats(self) -> dict:
        """Return node/edge counts, the extractor name and the graph path."""
        return {
            "nodes": self._graph.number_of_nodes(),
            "edges": self._graph.number_of_edges(),
            "extractor": self._extractor.name,
            "graph_path": str(self._path),
        }

    # -- Entity extraction (always spaCy - same for both methods) --

    def _extract_entities(self, text: str) -> list[tuple[str, str]]:
        """
        Run spaCy NER over ``text`` and return ``(label, entity_type)`` pairs.

        Only the first 10,000 characters are analysed. Entity text is
        stripped and title-cased; entities whose type is not in
        ``_ENTITY_TYPES``, that are shorter than 2 characters, or that repeat
        an earlier label in the same text are dropped.
        """
        nlp = self._get_nlp()
        doc = nlp(text[:10_000])

        seen: set[str] = set()
        entities: list[tuple[str, str]] = []
        for ent in doc.ents:
            if ent.label_ not in _ENTITY_TYPES:
                continue
            normalised = ent.text.strip().title()
            if normalised in seen or len(normalised) < 2:
                continue
            seen.add(normalised)
            entities.append((normalised, ent.label_))
        return entities

    # -- Graph construction (shared by both methods) --

    def _add_entities_to_graph(
        self, entities: list[tuple[str, str]], chunk: Chunk
    ) -> None:
        """
        Add each entity as a node, or record ``chunk`` on the existing node.

        A new node stores its entity type, the chunk's source and a
        ``chunk_ids`` list; an existing node only has the chunk id appended
        (its type and source are left as first seen).
        """
        for label, etype in entities:
            if self._graph.has_node(label):
                existing = self._graph.nodes[label].get("chunk_ids", [])
                if chunk.chunk_id not in existing:
                    existing.append(chunk.chunk_id)
                self._graph.nodes[label]["chunk_ids"] = existing
            else:
                self._graph.add_node(
                    label,
                    entity_type=etype,
                    chunk_ids=[chunk.chunk_id],
                    source=chunk.source,
                )

    def _add_triples_to_graph(self, triples: list[Triple]) -> None:
        """
        Add one edge per triple, merging into an existing edge if present.

        Subject/object nodes not already in the graph are created with
        entity type ``UNKNOWN`` and an empty ``chunk_ids`` list. An existing
        edge accumulates distinct predicates and chunk ids.
        """
        for triple in triples:
            for node in (triple.subject, triple.object):
                if not self._graph.has_node(node):
                    self._graph.add_node(
                        node,
                        entity_type="UNKNOWN",
                        chunk_ids=[],
                        source=triple.source,
                        extractor=triple.extractor,
                    )

            if self._graph.has_edge(triple.subject, triple.object):
                edge = self._graph[triple.subject][triple.object]
                predicates = edge.get("predicates", [])
                chunk_ids = edge.get("chunk_ids", [])
                if triple.predicate not in predicates:
                    predicates.append(triple.predicate)
                if triple.chunk_id not in chunk_ids:
                    chunk_ids.append(triple.chunk_id)
                edge["predicates"] = predicates
                edge["chunk_ids"] = chunk_ids
            else:
                self._graph.add_edge(
                    triple.subject, triple.object,
                    predicates=[triple.predicate],
                    chunk_ids=[triple.chunk_id],
                    source=triple.source,
                    extractor=triple.extractor,
                )

    # -- Persistence --

    def _load_if_exists(self) -> None:
        """
        Load the graph from disk when the file exists.

        A file that cannot be read or parsed is logged and ignored; the
        builder then keeps its empty graph.
        """
        if not self._path.exists():
            return
        try:
            with open(self._path, encoding="utf-8") as fh:
                data = json.load(fh)
            self._graph = nx.node_link_graph(data)
            logger.info(
                "Knowledge graph loaded: %d nodes, %d edges",
                self._graph.number_of_nodes(),
                self._graph.number_of_edges(),
            )
        except Exception as exc:
            logger.warning("Failed to load graph (%s) β starting fresh.", exc)

    # -- spaCy --

    def _get_nlp(self):
        """
        Return the spaCy pipeline, loading ``en_core_web_sm`` on first use.

        Raises:
            RuntimeError: if spaCy or the model is not installed.
        """
        if self._nlp is None:
            try:
                import spacy  # type: ignore
            except ImportError as exc:
                raise RuntimeError("Install spacy: pip install spacy") from exc
            try:
                self._nlp = spacy.load("en_core_web_sm")
            except OSError as exc:
                raise RuntimeError(
                    "Run: python -m spacy download en_core_web_sm"
                ) from exc
        return self._nlp
|
retrieval/graph_retriever.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Graph Retriever (Phase 3)
|
| 3 |
+
|
| 4 |
+
How multi-hop retrieval works
|
| 5 |
+
ββββββββββββββββββββββββββββββ
|
| 6 |
+
Standard dense retrieval can answer: "What is attention?"
|
| 7 |
+
It cannot answer: "Who wrote the attention paper, and what did they later
|
| 8 |
+
build that addresses memory bottlenecks in inference?"
|
| 9 |
+
|
| 10 |
+
That question requires:
|
| 11 |
+
Step 1: Find entity "Attention Is All You Need" in the graph
|
| 12 |
+
Step 2: Follow "authored_by" edges β Vaswani, Shazeer, Parmar, β¦
|
| 13 |
+
Step 3: Follow those author nodes' other edges β
|
| 14 |
+
Shazeer: "introduced" β "Multi-Query Attention"
|
| 15 |
+
Leviathan: "developed" β "Speculative Decoding"
|
| 16 |
+
Step 4: Collect all chunk_ids linked to visited nodes
|
| 17 |
+
Step 5: Fetch those chunks from Milvus β return to RRF pool
|
| 18 |
+
|
| 19 |
+
The BFS depth (default: 2 hops) is the key parameter. 1 hop = only
|
| 20 |
+
direct neighbours; 2 hops = neighbours of neighbours. 3+ hops tends to
|
| 21 |
+
explode in scope and include irrelevant context.
|
| 22 |
+
|
| 23 |
+
Entity matching
|
| 24 |
+
βββββββββββββββ
|
| 25 |
+
The query "Who developed PagedAttention?" must match graph nodes like
|
| 26 |
+
"Paged Attention" or "PagedAttention". We do:
|
| 27 |
+
1. Exact match (case-insensitive)
|
| 28 |
+
2. Partial match (query entity substring of node label)
|
| 29 |
+
3. spaCy NER on the query to extract candidate entity strings first
|
| 30 |
+
"""
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import logging
|
| 34 |
+
from typing import Optional
|
| 35 |
+
|
| 36 |
+
from retrieval.dense import MilvusStore, RetrievedChunk
|
| 37 |
+
from retrieval.graph_builder import KnowledgeGraphBuilder
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GraphRetriever:
|
| 43 |
+
"""
|
| 44 |
+
Retrieves chunks via knowledge graph traversal.
|
| 45 |
+
|
| 46 |
+
Returns RetrievedChunk objects fetched from Milvus, so they carry
|
| 47 |
+
the same structure as dense/BM25 results and can flow into RRF.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
    self,
    graph_builder: Optional[KnowledgeGraphBuilder] = None,
    store: Optional[MilvusStore] = None,
    max_hops: int = 2,
) -> None:
    """
    Set up the retriever.

    Args:
        graph_builder: source of the knowledge graph; a default
            ``KnowledgeGraphBuilder`` is created when omitted.
        store: Milvus store used to fetch chunks; a default ``MilvusStore``
            is created when omitted.
        max_hops: maximum BFS depth from each anchor node.
    """
    if not graph_builder:
        graph_builder = KnowledgeGraphBuilder()
    if not store:
        store = MilvusStore()

    self._builder = graph_builder
    self._store = store
    self._max_hops = max_hops
    self._nlp = None
|
| 60 |
+
|
| 61 |
+
# ββ Public API βββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
def search(self, query: str, top_k: int = 15) -> list[RetrievedChunk]:
    """
    Graph traversal retrieval for a given query.

    Pipeline:
      1. Extract candidate entities from the query
      2. Anchor each entity to matching graph nodes
      3. BFS up to ``max_hops`` from every anchor
      4. Collect chunk_ids from visited nodes and their edges, scored by
         hop distance (see ``_bfs_collect``)
      5. Keep the ``top_k`` best-scoring chunk_ids
      6. Fetch those chunks via ``_fetch_chunks_from_milvus``

    Returns an empty list when the graph is empty, the query yields no
    entities, no anchor node matches, or the traversal finds no chunks.
    """
    graph = self._builder.graph
    if graph.number_of_nodes() == 0:
        logger.debug("Graph is empty β skipping graph retrieval.")
        return []

    # Steps 1-2: entities, then the graph nodes they anchor to.
    query_entities = self._extract_query_entities(query)
    if not query_entities:
        logger.debug("No named entities in query β skipping graph retrieval.")
        return []
    logger.debug("Graph query entities: %s", query_entities)

    anchors = self._find_anchor_nodes(query_entities, graph)
    if not anchors:
        logger.debug("No anchor nodes found in graph.")
        return []
    logger.debug("Anchor nodes: %s", anchors)

    # Steps 3-4: both containers are filled in place by _bfs_collect.
    scores: dict[str, float] = {}
    seen_nodes: set[str] = set()
    for anchor in anchors:
        self._bfs_collect(graph, anchor, self._max_hops, scores, seen_nodes)

    if not scores:
        return []

    # Steps 5-6: best-scoring ids first, then fetch.
    ranked_ids = sorted(scores, key=scores.get, reverse=True)
    best_ids = ranked_ids[:top_k]

    chunks = self._fetch_chunks_from_milvus(best_ids, scores)
    logger.info(
        "Graph retriever: %d anchors, %d nodes visited, %d chunks returned",
        len(anchors), len(seen_nodes), len(chunks)
    )
    return chunks
|
| 120 |
+
|
| 121 |
+
# ββ BFS traversal βββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
|
| 123 |
+
def _bfs_collect(
|
| 124 |
+
self,
|
| 125 |
+
G,
|
| 126 |
+
start_node: str,
|
| 127 |
+
max_hops: int,
|
| 128 |
+
chunk_scores: dict[str, float],
|
| 129 |
+
visited: set[str],
|
| 130 |
+
) -> None:
|
| 131 |
+
"""
|
| 132 |
+
BFS from start_node up to max_hops.
|
| 133 |
+
Scores chunks by hop distance: 1.0 at hop 0, 0.5 at hop 1, 0.25 at hop 2.
|
| 134 |
+
"""
|
| 135 |
+
queue: list[tuple[str, int]] = [(start_node, 0)]
|
| 136 |
+
local_visited: set[str] = set()
|
| 137 |
+
|
| 138 |
+
while queue:
|
| 139 |
+
node, depth = queue.pop(0)
|
| 140 |
+
if node in local_visited or depth > max_hops:
|
| 141 |
+
continue
|
| 142 |
+
local_visited.add(node)
|
| 143 |
+
visited.add(node)
|
| 144 |
+
|
| 145 |
+
# Score = 1 / 2^depth (1.0 at anchor, 0.5 one hop away, etc.)
|
| 146 |
+
hop_score = 1.0 / (2 ** depth)
|
| 147 |
+
|
| 148 |
+
# Collect chunk_ids from this node
|
| 149 |
+
node_data = G.nodes[node]
|
| 150 |
+
for cid in node_data.get("chunk_ids", []):
|
| 151 |
+
chunk_scores[cid] = max(chunk_scores.get(cid, 0.0), hop_score)
|
| 152 |
+
|
| 153 |
+
# Collect chunk_ids from edges (relations)
|
| 154 |
+
for neighbour in G.neighbors(node):
|
| 155 |
+
edge_data = G[node][neighbour]
|
| 156 |
+
for cid in edge_data.get("chunk_ids", []):
|
| 157 |
+
chunk_scores[cid] = max(chunk_scores.get(cid, 0.0), hop_score * 0.8)
|
| 158 |
+
|
| 159 |
+
if depth < max_hops:
|
| 160 |
+
queue.append((neighbour, depth + 1))
|
| 161 |
+
|
| 162 |
+
# ββ Entity extraction ββββββββββββββββββββββββββββββββββββββ
|
| 163 |
+
|
| 164 |
+
def _extract_query_entities(self, query: str) -> list[str]:
|
| 165 |
+
"""
|
| 166 |
+
Extract named entities from the query using spaCy NER.
|
| 167 |
+
Falls back to noun chunks if NER finds nothing.
|
| 168 |
+
"""
|
| 169 |
+
try:
|
| 170 |
+
nlp = self._get_nlp()
|
| 171 |
+
doc = nlp(query)
|
| 172 |
+
entities = [ent.text.strip().title() for ent in doc.ents if len(ent.text.strip()) > 1]
|
| 173 |
+
if not entities:
|
| 174 |
+
# Fallback: try noun chunks (catches "attention mechanism", etc.)
|
| 175 |
+
entities = [
|
| 176 |
+
chunk.text.strip().title()
|
| 177 |
+
for chunk in doc.noun_chunks
|
| 178 |
+
if len(chunk.text.strip()) > 3
|
| 179 |
+
]
|
| 180 |
+
return entities
|
| 181 |
+
except Exception as exc:
|
| 182 |
+
logger.debug("Entity extraction failed: %s", exc)
|
| 183 |
+
return []
|
| 184 |
+
|
| 185 |
+
# ββ Node matching βββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def _find_anchor_nodes(query_entities: list[str], G) -> list[str]:
|
| 189 |
+
"""
|
| 190 |
+
Find graph nodes that match query entities.
|
| 191 |
+
Priority: exact match β partial match.
|
| 192 |
+
"""
|
| 193 |
+
all_nodes = list(G.nodes())
|
| 194 |
+
lower_nodes = {n.lower(): n for n in all_nodes}
|
| 195 |
+
|
| 196 |
+
anchors: list[str] = []
|
| 197 |
+
for qe in query_entities:
|
| 198 |
+
qe_lower = qe.lower()
|
| 199 |
+
|
| 200 |
+
# Exact match (case-insensitive)
|
| 201 |
+
if qe_lower in lower_nodes:
|
| 202 |
+
anchors.append(lower_nodes[qe_lower])
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
# Partial match: query entity is substring of a node label
|
| 206 |
+
for node_lower, node in lower_nodes.items():
|
| 207 |
+
if qe_lower in node_lower or node_lower in qe_lower:
|
| 208 |
+
if node not in anchors:
|
| 209 |
+
anchors.append(node)
|
| 210 |
+
|
| 211 |
+
return anchors[:10] # cap to avoid explosion on generic queries
|
| 212 |
+
|
| 213 |
+
# ββ Milvus fetch ββββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
+
|
| 215 |
+
def _fetch_chunks_from_milvus(
|
| 216 |
+
self,
|
| 217 |
+
chunk_ids: list[str],
|
| 218 |
+
scores: dict[str, float],
|
| 219 |
+
) -> list[RetrievedChunk]:
|
| 220 |
+
"""
|
| 221 |
+
Fetch specific chunks from Milvus by chunk_id.
|
| 222 |
+
Tags each chunk with retriever="graph".
|
| 223 |
+
"""
|
| 224 |
+
if not chunk_ids:
|
| 225 |
+
return []
|
| 226 |
+
|
| 227 |
+
try:
|
| 228 |
+
# Milvus IN query
|
| 229 |
+
id_list = '", "'.join(chunk_ids)
|
| 230 |
+
expr = f'chunk_id in ["{id_list}"]'
|
| 231 |
+
|
| 232 |
+
coll = self._store._ensure_collection()
|
| 233 |
+
results = coll.query(
|
| 234 |
+
expr=expr,
|
| 235 |
+
output_fields=["chunk_id", "doc_id", "source", "title",
|
| 236 |
+
"text", "parent_text", "chunk_index"],
|
| 237 |
+
limit=len(chunk_ids),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
chunks: list[RetrievedChunk] = []
|
| 241 |
+
for row in results:
|
| 242 |
+
cid = row["chunk_id"]
|
| 243 |
+
chunks.append(RetrievedChunk(
|
| 244 |
+
chunk_id=cid,
|
| 245 |
+
doc_id=row["doc_id"],
|
| 246 |
+
source=row["source"],
|
| 247 |
+
title=row["title"],
|
| 248 |
+
text=row["text"],
|
| 249 |
+
parent_text=row["parent_text"],
|
| 250 |
+
chunk_index=row["chunk_index"],
|
| 251 |
+
score=scores.get(cid, 0.1),
|
| 252 |
+
retriever="graph",
|
| 253 |
+
))
|
| 254 |
+
return sorted(chunks, key=lambda c: c.score, reverse=True)
|
| 255 |
+
|
| 256 |
+
except Exception as exc:
|
| 257 |
+
logger.warning("Milvus fetch for graph chunks failed: %s", exc)
|
| 258 |
+
return []
|
| 259 |
+
|
| 260 |
+
# ββ spaCy βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
+
|
| 262 |
+
def _get_nlp(self):
|
| 263 |
+
if self._nlp is None:
|
| 264 |
+
import spacy # type: ignore
|
| 265 |
+
try:
|
| 266 |
+
self._nlp = spacy.load("en_core_web_sm")
|
| 267 |
+
except OSError:
|
| 268 |
+
raise RuntimeError(
|
| 269 |
+
"Download spaCy model: python -m spacy download en_core_web_sm"
|
| 270 |
+
)
|
| 271 |
+
return self._nlp
|
retrieval/orchestrator.py
CHANGED
|
@@ -30,6 +30,8 @@ from retrieval.bm25 import BM25Retriever
|
|
| 30 |
from retrieval.dense import MilvusStore, RetrievedChunk
|
| 31 |
from retrieval.embedder import Embedder
|
| 32 |
from retrieval.fusion import CrossEncoderReranker, RRFFusion
|
|
|
|
|
|
|
| 33 |
from retrieval.router import QueryRouter, RoutingDecision
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
|
@@ -51,6 +53,7 @@ class MultiStrategyRetriever:
|
|
| 51 |
embedder: Optional[Embedder] = None,
|
| 52 |
store: Optional[MilvusStore] = None,
|
| 53 |
bm25: Optional[BM25Retriever] = None,
|
|
|
|
| 54 |
router: Optional[QueryRouter] = None,
|
| 55 |
fuser: Optional[RRFFusion] = None,
|
| 56 |
reranker: Optional[CrossEncoderReranker] = None,
|
|
@@ -58,6 +61,10 @@ class MultiStrategyRetriever:
|
|
| 58 |
self._embedder = embedder or Embedder()
|
| 59 |
self._dense = store or MilvusStore(embedder=self._embedder)
|
| 60 |
self._bm25 = bm25 or BM25Retriever()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self._router = router or QueryRouter()
|
| 62 |
self._fuser = fuser or RRFFusion()
|
| 63 |
self._reranker = reranker or CrossEncoderReranker()
|
|
@@ -112,6 +119,17 @@ class MultiStrategyRetriever:
|
|
| 112 |
Dense indexing is handled separately by MilvusStore.
|
| 113 |
"""
|
| 114 |
return self._bm25.add_chunks(chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# ββ Private: parallel retrieval βββββββββββββββββββββββββββ
|
| 117 |
|
|
@@ -128,7 +146,7 @@ class MultiStrategyRetriever:
|
|
| 128 |
retriever_map = {
|
| 129 |
"dense": lambda q, k: self._dense.search(q, top_k=k),
|
| 130 |
"bm25": lambda q, k: self._bm25.search(q, top_k=k),
|
| 131 |
-
|
| 132 |
}
|
| 133 |
|
| 134 |
results: dict[str, list[RetrievedChunk]] = {}
|
|
|
|
| 30 |
from retrieval.dense import MilvusStore, RetrievedChunk
|
| 31 |
from retrieval.embedder import Embedder
|
| 32 |
from retrieval.fusion import CrossEncoderReranker, RRFFusion
|
| 33 |
+
from retrieval.graph_builder import KnowledgeGraphBuilder
|
| 34 |
+
from retrieval.graph_retriever import GraphRetriever
|
| 35 |
from retrieval.router import QueryRouter, RoutingDecision
|
| 36 |
|
| 37 |
logger = logging.getLogger(__name__)
|
|
|
|
| 53 |
embedder: Optional[Embedder] = None,
|
| 54 |
store: Optional[MilvusStore] = None,
|
| 55 |
bm25: Optional[BM25Retriever] = None,
|
| 56 |
+
graph_builder: Optional[KnowledgeGraphBuilder] = None,
|
| 57 |
router: Optional[QueryRouter] = None,
|
| 58 |
fuser: Optional[RRFFusion] = None,
|
| 59 |
reranker: Optional[CrossEncoderReranker] = None,
|
|
|
|
| 61 |
self._embedder = embedder or Embedder()
|
| 62 |
self._dense = store or MilvusStore(embedder=self._embedder)
|
| 63 |
self._bm25 = bm25 or BM25Retriever()
|
| 64 |
+
self._graph_builder = graph_builder or KnowledgeGraphBuilder()
|
| 65 |
+
self._graph = GraphRetriever(
|
| 66 |
+
graph_builder=self._graph_builder, store=self._dense
|
| 67 |
+
)
|
| 68 |
self._router = router or QueryRouter()
|
| 69 |
self._fuser = fuser or RRFFusion()
|
| 70 |
self._reranker = reranker or CrossEncoderReranker()
|
|
|
|
| 119 |
Dense indexing is handled separately by MilvusStore.
|
| 120 |
"""
|
| 121 |
return self._bm25.add_chunks(chunks)
|
| 122 |
+
|
| 123 |
+
def build_graph(self, chunks: list) -> dict:
    """
    Feed ``chunks`` to the knowledge-graph builder and return its result.

    Delegates to ``KnowledgeGraphBuilder.process_chunks``, which extracts
    entities and relations and updates the graph. Intended to be called by
    the ingestion pipeline after dense and BM25 indexing.
    """
    builder = self._graph_builder
    return builder.process_chunks(chunks)
|
| 129 |
+
|
| 130 |
+
@property
def graph_builder(self) -> KnowledgeGraphBuilder:
    """Read-only access to the KnowledgeGraphBuilder held by this retriever."""
    builder = self._graph_builder
    return builder
|
| 133 |
|
| 134 |
# ββ Private: parallel retrieval βββββββββββββββββββββββββββ
|
| 135 |
|
|
|
|
| 146 |
retriever_map = {
|
| 147 |
"dense": lambda q, k: self._dense.search(q, top_k=k),
|
| 148 |
"bm25": lambda q, k: self._bm25.search(q, top_k=k),
|
| 149 |
+
"graph": lambda q, k: self._graph.search(q, top_k=k),
|
| 150 |
}
|
| 151 |
|
| 152 |
results: dict[str, list[RetrievedChunk]] = {}
|
retrieval/relation_extractors.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cortex RAG — Relation Extractors
|
| 3 |
+
|
| 4 |
+
Strategy pattern: both extractors share the same interface.
|
| 5 |
+
Switch between them with GRAPH_EXTRACTOR=rebel|llm in .env.
|
| 6 |
+
|
| 7 |
+
RelationExtractor (abstract)
|
| 8 |
+
├── REBELExtractor — local model, no API calls, Wikidata predicates
|
| 9 |
+
└── LLMExtractor — Mistral/LLM, free-form predicates, rate-limited
|
| 10 |
+
|
| 11 |
+
KnowledgeGraphBuilder accepts either via dependency injection, or
|
| 12 |
+
auto-selects based on config.get_settings().graph_extractor.
|
| 13 |
+
|
| 14 |
+
Adding a new extractor in the future:
|
| 15 |
+
1. Subclass RelationExtractor
|
| 16 |
+
2. Implement extract(chunk) → list[Triple]
|
| 17 |
+
3. Register the name in _EXTRACTOR_REGISTRY at the bottom of this file
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import re
|
| 24 |
+
import time
|
| 25 |
+
from abc import ABC, abstractmethod
|
| 26 |
+
from dataclasses import dataclass
|
| 27 |
+
from typing import Optional
|
| 28 |
+
|
| 29 |
+
from config import get_settings
|
| 30 |
+
from ingestion.chunker import Chunk
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ββ Shared dataclass βββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
@dataclass
class Triple:
    """A single (subject, predicate, object) relation extracted from one chunk."""

    # The relation itself.
    subject: str
    predicate: str
    object: str

    # Provenance: id and source of the chunk the relation was extracted from.
    chunk_id: str
    source: str

    # Tracks which extractor produced this triple.
    extractor: str = "unknown"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ββ Abstract base ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
|
| 49 |
+
class RelationExtractor(ABC):
    """
    Shared interface implemented by every relation-extraction strategy.

    Concrete subclasses provide ``extract()`` and the ``name`` property;
    ``extract_batch()`` has a sequential default they may override.
    """

    @abstractmethod
    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Return the (subject, predicate, object) triples found in one chunk.

        Implementations must never raise; on any failure they return [].
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier used in logging and in the ``Triple.extractor`` field."""
        ...

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        """
        Extract triples for several chunks.

        The default implementation simply calls ``extract()`` once per
        chunk, in order. Subclasses can override it for true batching
        (e.g. REBEL).

        Returns: {chunk_id: [Triple, ...]}
        """
        results: dict[str, list[Triple]] = {}
        for chunk in chunks:
            key = chunk.chunk_id
            results[key] = self.extract(chunk)
        return results
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ββ REBEL extractor ββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
|
| 82 |
+
# REBEL relation types that map cleanly to RAG-useful edges.
|
| 83 |
+
# The full Wikidata set has 220 types; we keep the ~40 most useful.
|
| 84 |
+
_REBEL_KEEP_RELATIONS = {
|
| 85 |
+
"author", "developer", "creator", "founded by", "owned by",
|
| 86 |
+
"instance of", "subclass of", "part of", "has part",
|
| 87 |
+
"country", "country of origin", "located in", "headquarters location",
|
| 88 |
+
"employer", "member of", "affiliation", "educated at",
|
| 89 |
+
"award received", "occupation", "field of work", "notable work",
|
| 90 |
+
"based on", "followed by", "follows", "influenced by", "has edition",
|
| 91 |
+
"product or material produced", "used by", "manufacturer",
|
| 92 |
+
"publication date", "academic degree", "applies to jurisdiction",
|
| 93 |
+
"published in", "platform", "programming language", "license",
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class REBELExtractor(RelationExtractor):
    """
    Local relation extraction using REBEL (Babelscape/rebel-large).

    Model facts:
      - 406M params (BART-large fine-tuned on Wikipedia + Wikidata)
      - Input: raw text sentence(s)
      - Output: decoded triplet string → parsed into (head, type, tail)
      - CPU inference: ~80–150ms per chunk on modern hardware
      - No API calls, no rate limits, fully offline after first download

    Batching:
      REBEL's tokeniser handles variable-length batches natively.
      extract_batch() runs the chunks through the model in groups of
      ``rebel_batch_size`` (config attribute, default 8), which is
      significantly faster than calling extract() in a loop.
      Default 8 is safe for 8GB RAM; raise to 16–32 with more RAM.

    Failure handling:
      Neither extract() nor extract_batch() raises. If the model cannot
      be loaded, or a batch fails, the affected chunks map to [] and the
      problem is logged.

    Predicate normalisation:
      REBEL outputs Wikidata relation labels (e.g. "country of origin").
      Only relations listed in _REBEL_KEEP_RELATIONS are kept; the rest
      are discarded — this prevents graph noise from obscure predicates
      like "Wikimedia disambiguation page" or "image" polluting the graph.

    Download:
      First run downloads ~1.6GB to ~/.cache/huggingface/hub/.
      Subsequent runs load from cache in ~3s.
    """

    _REBEL_MODEL = "Babelscape/rebel-large"
    _MAX_INPUT_TOKENS = 256   # REBEL was trained on short passages
    _MAX_OUTPUT_TOKENS = 512
    _MAX_CHUNK_CHARS = 1200   # character pre-truncation applied before tokenising
    _MAX_TRIPLES_PER_CHUNK = 8
    _DEFAULT_BATCH_SIZE = 8
    # Tokeniser artefacts left in the output because decoding keeps special tokens.
    _SPECIAL_TOKENS = ("</s>", "<s>", "<pad>")

    def __init__(self) -> None:
        # Tokenizer and model are loaded lazily by _load().
        self._tokenizer = None
        self._model = None

    @property
    def name(self) -> str:
        return "rebel"

    # ── Public ──────────────────────────────────────────────────

    def extract(self, chunk: Chunk) -> list[Triple]:
        """Single-chunk extraction (sequential). Prefer extract_batch for speed."""
        return self.extract_batch([chunk]).get(chunk.chunk_id, [])

    def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
        """
        Batched extraction: chunks are sent to the model in groups of
        ``rebel_batch_size``.

        Never raises. A model that cannot be loaded, or a batch that fails
        (for example through memory exhaustion), leaves the affected
        chunks with an empty triple list.

        Returns: {chunk_id: [Triple, ...]} with an entry for every input chunk.
        """
        if not chunks:
            return {}

        all_triples: dict[str, list[Triple]] = {c.chunk_id: [] for c in chunks}

        try:
            tok, model = self._load()
        except Exception as exc:
            # Honour the RelationExtractor contract: report, don't crash ingestion.
            logger.error("REBEL model unavailable - no triples extracted: %s", exc)
            return all_triples

        batch_size = getattr(get_settings(), "rebel_batch_size", self._DEFAULT_BATCH_SIZE)
        if not isinstance(batch_size, int) or batch_size < 1:
            # A zero/negative/non-int step would make range() raise below.
            batch_size = self._DEFAULT_BATCH_SIZE

        for batch_start in range(0, len(chunks), batch_size):
            batch_chunks = chunks[batch_start : batch_start + batch_size]

            try:
                # Chunk text is truncated to avoid exceeding REBEL's context window
                batch_texts = [c.text[: self._MAX_CHUNK_CHARS] for c in batch_chunks]
                inputs = tok(
                    batch_texts,
                    max_length=self._MAX_INPUT_TOKENS,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                generated = model.generate(
                    **inputs,
                    max_length=self._MAX_OUTPUT_TOKENS,
                    num_beams=3,
                    early_stopping=True,
                )
                # Special tokens must survive decoding: they delimit the triplets.
                decoded = tok.batch_decode(generated, skip_special_tokens=False)

                for chunk, raw_output in zip(batch_chunks, decoded):
                    all_triples[chunk.chunk_id] = self._parse_rebel_output(raw_output, chunk)

            except Exception as exc:
                logger.warning("REBEL batch %d failed: %s", batch_start, exc)
                # Mark as empty rather than crashing the whole ingestion
                for chunk in batch_chunks:
                    all_triples[chunk.chunk_id] = []

        return all_triples

    # ── REBEL output parser ─────────────────────────────────────

    def _strip_special(self, text: str) -> str:
        """Remove residual tokeniser tokens and surrounding whitespace."""
        text = text.strip()
        for token in self._SPECIAL_TOKENS:
            text = text.replace(token, "").strip()
        return text

    def _parse_rebel_output(self, decoded: str, chunk: Chunk) -> list[Triple]:
        """
        Parse REBEL's special-token output format.

        REBEL outputs a string like:
          <triplet> Vaswani <subj> Attention Is All You Need <obj> author
          <triplet> Transformer <subj> NLP <obj> field of work

        One head may be followed by several ``<subj> tail <obj> relation``
        segments; every segment yields its own triple with the same head:
          <triplet> Bert <subj> Google <obj> developer <subj> Nlp <obj> field of work

        Triples whose relation is not in _REBEL_KEEP_RELATIONS are dropped.
        Subject and object are title-cased, the predicate lower-cased, and
        at most 8 triples are returned per chunk.
        """
        triples: list[Triple] = []

        for raw in decoded.split("<triplet>"):
            raw = raw.strip()
            if not raw or "<subj>" not in raw or "<obj>" not in raw:
                continue

            try:
                # Format: "SUBJECT <subj> OBJECT <obj> RELATION [<subj> OBJECT <obj> RELATION ...]"
                head, *segments = raw.split("<subj>")
                subject = self._strip_special(head)
                if not subject:
                    continue

                for segment in segments:
                    parts = segment.split("<obj>")
                    if len(parts) < 2:
                        # Segment has a tail but no relation: malformed, skip it.
                        continue

                    obj = self._strip_special(parts[0])
                    relation = self._strip_special(parts[1]).lower()
                    if not obj or not relation:
                        continue

                    # Filter to useful relation types only
                    if relation not in _REBEL_KEEP_RELATIONS:
                        continue

                    triples.append(Triple(
                        subject=subject.title(),
                        predicate=relation,
                        object=obj.title(),
                        chunk_id=chunk.chunk_id,
                        source=chunk.source,
                        extractor=self.name,
                    ))

            except (IndexError, AttributeError):
                continue

        return triples[: self._MAX_TRIPLES_PER_CHUNK]  # cap per chunk

    # ── Model loading ───────────────────────────────────────────

    def _load(self):
        """
        Return ``(tokenizer, model)``, loading them on first use.

        Raises:
            RuntimeError: if the ``transformers`` package is not installed.
        """
        if self._tokenizer is None or self._model is None:
            try:
                from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
            except ImportError as exc:
                raise RuntimeError(
                    "Install transformers: pip install transformers"
                ) from exc

            logger.info("Loading REBEL model '%s' (first run downloads ~1.6GB)…", self._REBEL_MODEL)
            t0 = time.perf_counter()
            self._tokenizer = AutoTokenizer.from_pretrained(self._REBEL_MODEL)
            self._model = AutoModelForSeq2SeqLM.from_pretrained(self._REBEL_MODEL)
            self._model.eval()  # inference mode — disables dropout
            logger.info("REBEL loaded in %.1fs", time.perf_counter() - t0)

        return self._tokenizer, self._model
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# ββ LLM extractor (original method, preserved) βββββββββββββββββ
|
| 274 |
+
|
| 275 |
+
_LLM_PROMPT = """\
|
| 276 |
+
Extract factual relationships from the passage below.
|
| 277 |
+
Return ONLY a JSON array of triples. Each triple is:
|
| 278 |
+
{{"subject": "...", "predicate": "...", "object": "..."}}
|
| 279 |
+
|
| 280 |
+
Rules:
|
| 281 |
+
- subject and object must be named entities (people, orgs, systems, concepts)
|
| 282 |
+
- predicate is a short verb phrase ("developed", "is based on", "introduced", "authored")
|
| 283 |
+
- Extract 0β5 triples maximum. If there are none, return []
|
| 284 |
+
- Return ONLY the JSON array, no explanation, no markdown
|
| 285 |
+
|
| 286 |
+
Passage:
|
| 287 |
+
{text}
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class LLMExtractor(RelationExtractor):
    """
    Relation extraction via an LLM (the original Phase 3 method).

    The backend is chosen by ``llm_server`` in the settings: ``"ollama"``
    uses a local Ollama server, anything else uses the Mistral API.

    Produces free-form, human-readable predicates ("introduced the concept of",
    "co-authored with") rather than the fixed Wikidata vocabulary that REBEL uses.

    Use this when:
      - You want rich, domain-specific predicate labels
      - Your corpus is small enough that rate limits aren't a problem
      - You want to fine-tune the extraction prompt for your specific domain

    Rate limiting:
      Set MISTRAL_RELATION_RPM in .env to cap requests-per-minute.
      Default is 0 (no cap). Mistral free tier allows ~30 RPM.
      NOTE(review): this class does not throttle its own calls; the cap is
      presumably enforced by the caller — confirm.
    """

    _MAX_PROMPT_CHARS = 2000   # passage characters sent to the model
    _MAX_TRIPLES = 5           # matches the limit requested in the prompt
    _MAX_OUTPUT_TOKENS = 512

    def __init__(self) -> None:
        # Client is created lazily by _get_llm().
        self._llm = None

    @property
    def name(self) -> str:
        return "llm"

    def extract(self, chunk: Chunk) -> list[Triple]:
        """
        Ask the configured LLM for triples describing ``chunk``.

        Never raises: any client, network or parsing failure is logged at
        debug level and results in [].
        """
        try:
            client = self._get_llm()
            cfg = get_settings()
            messages = [{
                "role": "user",
                "content": _LLM_PROMPT.format(text=chunk.text[: self._MAX_PROMPT_CHARS]),
            }]

            if cfg.llm_server == "ollama":
                # The Ollama client exposes chat() as a method and returns
                # {"message": {"content": ...}}; it has no chat.complete()
                # and no .choices list.
                response = client.chat(
                    model=cfg.ollama_model,
                    messages=messages,
                    options={
                        "temperature": 0.0,
                        "num_predict": self._MAX_OUTPUT_TOKENS,
                    },
                )
                raw = response["message"]["content"] or "[]"
            else:
                response = client.chat.complete(
                    model=cfg.mistral_model,
                    messages=messages,
                    temperature=0.0,
                    max_tokens=self._MAX_OUTPUT_TOKENS,
                )
                raw = response.choices[0].message.content or "[]"

            return self._parse(raw, chunk)

        except Exception as exc:
            logger.debug("LLM extractor failed for chunk %s: %s", chunk.chunk_id, exc)
            return []

    @staticmethod
    def _clean_field(value) -> str:
        """Normalise one JSON field to a stripped string; ``None`` becomes ''."""
        if value is None:
            return ""
        return str(value).strip()

    @staticmethod
    def _load_json_array(raw: str):
        """
        Decode the model reply into a list, or return None if impossible.

        Tolerates a Markdown code fence around the JSON and, failing a
        direct decode, prose before/after the outermost ``[...]`` span.
        """
        raw = raw.strip()
        if raw.startswith("```"):
            raw = re.sub(r"^```[a-z]*\n?", "", raw)
            raw = re.sub(r"\n?```$", "", raw)

        try:
            items = json.loads(raw)
        except json.JSONDecodeError:
            bracketed = re.search(r"\[.*\]", raw, re.DOTALL)
            if bracketed is None:
                return None
            try:
                items = json.loads(bracketed.group(0))
            except json.JSONDecodeError:
                return None

        return items if isinstance(items, list) else None

    def _parse(self, raw: str, chunk: Chunk) -> list[Triple]:
        """
        Turn the raw model reply into Triple objects.

        Only the first 5 array items are considered. Items that are not
        objects, or that lack a non-empty subject, predicate or object,
        are skipped. Subject and object are title-cased and the predicate
        lower-cased. Returns [] when the reply is not a JSON array.
        """
        items = self._load_json_array(raw)
        if items is None:
            return []

        triples: list[Triple] = []
        for item in items[: self._MAX_TRIPLES]:
            if not isinstance(item, dict):
                continue
            s = self._clean_field(item.get("subject"))
            p = self._clean_field(item.get("predicate"))
            o = self._clean_field(item.get("object"))
            if s and p and o:
                triples.append(Triple(
                    subject=s.title(),
                    predicate=p.lower(),
                    object=o.title(),
                    chunk_id=chunk.chunk_id,
                    source=chunk.source,
                    extractor=self.name,
                ))
        return triples

    def _get_llm(self):
        """
        Return the LLM client, creating it on first use.

        Raises:
            RuntimeError: if the Ollama client package is missing, or if
                the Mistral backend is selected without an API key.
            ImportError: if the Mistral SDK cannot be imported.
        """
        if self._llm is None:
            cfg = get_settings()

            if cfg.llm_server == "ollama":
                try:
                    from ollama import Client as ollama_client  # type: ignore
                except ImportError as exc:
                    raise RuntimeError(
                        "Install ollama client: pip install ollama"
                    ) from exc
                self._llm = ollama_client(host=cfg.ollama_host)
            else:
                if not cfg.mistral_api_key:
                    raise RuntimeError("MISTRAL_API_KEY not set")
                try:
                    from mistralai.client import Mistral  # type: ignore
                except ImportError:
                    # Other SDK layouts export the client from the package root.
                    from mistralai import Mistral  # type: ignore
                self._llm = Mistral(api_key=cfg.mistral_api_key)
        return self._llm
|
| 392 |
+
|
| 393 |
+
# ββ Entity density filter (Option 4) ββββββββββββββββββββββββββ
|
| 394 |
+
|
| 395 |
+
class EntityDensityFilter(RelationExtractor):
|
| 396 |
+
"""
|
| 397 |
+
Decorator that wraps any extractor and skips low-entity-density chunks.
|
| 398 |
+
|
| 399 |
+
Rationale
|
| 400 |
+
βββββββββ
|
| 401 |
+
Chunks with 0β1 named entities rarely yield useful triples β a
|
| 402 |
+
paragraph of methodology boilerplate has no entities to link.
|
| 403 |
+
Scoring by entity density (entities per 100 tokens) and processing
|
| 404 |
+
only the top N% of chunks cuts extraction time by ~70% with
|
| 405 |
+
negligible graph quality loss.
|
| 406 |
+
|
| 407 |
+
How density is computed
|
| 408 |
+
βββββββββββββββββββββββ
|
| 409 |
+
density = (spaCy NER entity count) / (token count / 100)
|
| 410 |
+
|
| 411 |
+
This normalises for chunk length β a 50-token chunk with 3 entities
|
| 412 |
+
scores higher than a 500-token chunk with the same 3 entities.
|
| 413 |
+
|
| 414 |
+
Usage
|
| 415 |
+
βββββ
|
| 416 |
+
# Wrap REBEL, keep top 30% of chunks (default):
|
| 417 |
+
extractor = EntityDensityFilter(REBELExtractor())
|
| 418 |
+
|
| 419 |
+
# Wrap LLM, keep top 20%, only chunks with β₯2 entities:
|
| 420 |
+
extractor = EntityDensityFilter(
|
| 421 |
+
LLMExtractor(),
|
| 422 |
+
top_fraction=0.20,
|
| 423 |
+
min_entity_count=2,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Via config (wraps whatever GRAPH_EXTRACTOR is set to):
|
| 427 |
+
GRAPH_EXTRACTOR=rebel-filtered # rebel + density filter
|
| 428 |
+
GRAPH_EXTRACTOR=llm-filtered # llm + density filter
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
def __init__(
|
| 432 |
+
self,
|
| 433 |
+
inner: RelationExtractor,
|
| 434 |
+
top_fraction: Optional[float] = None,
|
| 435 |
+
min_entity_count: Optional[int] = None,
|
| 436 |
+
) -> None:
|
| 437 |
+
cfg = get_settings()
|
| 438 |
+
self._inner = inner
|
| 439 |
+
# top_fraction: process only the top X% most entity-dense chunks
|
| 440 |
+
self._top_fraction = top_fraction or getattr(cfg, "density_top_fraction", 0.30)
|
| 441 |
+
# min_entity_count: hard floor β never extract from chunks below this
|
| 442 |
+
self._min_entity_count = min_entity_count or getattr(cfg, "density_min_entities", 2)
|
| 443 |
+
self._nlp = None
|
| 444 |
+
|
| 445 |
+
@property
|
| 446 |
+
def name(self) -> str:
|
| 447 |
+
return f"{self._inner.name}-filtered"
|
| 448 |
+
|
| 449 |
+
# ββ Public ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 450 |
+
|
| 451 |
+
def extract(self, chunk: Chunk) -> list[Triple]:
|
| 452 |
+
"""Single-chunk extraction with density pre-check."""
|
| 453 |
+
if not self._passes_density_check([chunk]):
|
| 454 |
+
logger.debug("Chunk %s skipped (low entity density)", chunk.chunk_id)
|
| 455 |
+
return []
|
| 456 |
+
return self._inner.extract(chunk)
|
| 457 |
+
|
| 458 |
+
def extract_batch(self, chunks: list[Chunk]) -> dict[str, list[Triple]]:
|
| 459 |
+
"""
|
| 460 |
+
Filter chunks by density score, then delegate only the qualifying
|
| 461 |
+
subset to the inner extractor's batch method.
|
| 462 |
+
|
| 463 |
+
Steps:
|
| 464 |
+
1. Score every chunk by entity density (fast β pure spaCy)
|
| 465 |
+
2. Apply min_entity_count hard floor
|
| 466 |
+
3. Keep top_fraction of remaining chunks by density score
|
| 467 |
+
4. Pass filtered set to inner.extract_batch()
|
| 468 |
+
5. Return merged result (skipped chunks β empty list)
|
| 469 |
+
"""
|
| 470 |
+
if not chunks:
|
| 471 |
+
return {}
|
| 472 |
+
|
| 473 |
+
# Score all chunks
|
| 474 |
+
scored = self._score_chunks(chunks) # list of (chunk, density, entity_count)
|
| 475 |
+
|
| 476 |
+
# Hard floor: drop chunks below minimum entity count
|
| 477 |
+
above_floor = [(c, d, n) for c, d, n in scored if n >= self._min_entity_count]
|
| 478 |
+
|
| 479 |
+
# Top-fraction cut: sort by density desc, keep top N%
|
| 480 |
+
above_floor.sort(key=lambda x: x[1], reverse=True)
|
| 481 |
+
cutoff = max(1, int(len(above_floor) * self._top_fraction))
|
| 482 |
+
selected = [c for c, _, _ in above_floor[:cutoff]]
|
| 483 |
+
skipped = len(chunks) - len(selected)
|
| 484 |
+
|
| 485 |
+
if skipped:
|
| 486 |
+
logger.info(
|
| 487 |
+
"Density filter: %d/%d chunks selected (top %.0f%%, min_entities=%d)",
|
| 488 |
+
len(selected), len(chunks),
|
| 489 |
+
self._top_fraction * 100, self._min_entity_count,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# Delegate to inner extractor
|
| 493 |
+
if not selected:
|
| 494 |
+
return {c.chunk_id: [] for c in chunks}
|
| 495 |
+
|
| 496 |
+
inner_results = self._inner.extract_batch(selected)
|
| 497 |
+
|
| 498 |
+
# Merge: unselected chunks get empty lists
|
| 499 |
+
selected_ids = {c.chunk_id for c in selected}
|
| 500 |
+
return {
|
| 501 |
+
c.chunk_id: inner_results.get(c.chunk_id, []) if c.chunk_id in selected_ids else []
|
| 502 |
+
for c in chunks
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
# ── Density scoring ────────────────────────────────────────
|
| 506 |
+
|
| 507 |
+
def _score_chunks(
|
| 508 |
+
self, chunks: list[Chunk]
|
| 509 |
+
) -> list[tuple[Chunk, float, int]]:
|
| 510 |
+
"""
|
| 511 |
+
Returns list of (chunk, density_score, entity_count).
|
| 512 |
+
density_score = entities per 100 tokens (approx).
|
| 513 |
+
"""
|
| 514 |
+
nlp = self._get_nlp()
|
| 515 |
+
results = []
|
| 516 |
+
for chunk in chunks:
|
| 517 |
+
doc = nlp(chunk.text[:5000])
|
| 518 |
+
entity_count = len([e for e in doc.ents if len(e.text.strip()) > 1])
|
| 519 |
+
token_count = max(len(doc), 1)
|
| 520 |
+
density = (entity_count / token_count) * 100
|
| 521 |
+
results.append((chunk, density, entity_count))
|
| 522 |
+
return results
|
| 523 |
+
|
| 524 |
+
def _passes_density_check(self, chunks: list[Chunk]) -> bool:
|
| 525 |
+
"""Quick single-chunk density check for extract()."""
|
| 526 |
+
if not chunks:
|
| 527 |
+
return False
|
| 528 |
+
_, _, entity_count = self._score_chunks(chunks)[0]
|
| 529 |
+
return entity_count >= self._min_entity_count
|
| 530 |
+
|
| 531 |
+
# ── spaCy ──────────────────────────────────────────────────
|
| 532 |
+
|
| 533 |
+
def _get_nlp(self):
|
| 534 |
+
if self._nlp is None:
|
| 535 |
+
import spacy # type: ignore
|
| 536 |
+
try:
|
| 537 |
+
self._nlp = spacy.load("en_core_web_sm")
|
| 538 |
+
except OSError:
|
| 539 |
+
raise RuntimeError("Run: python -m spacy download en_core_web_sm")
|
| 540 |
+
return self._nlp
|
| 541 |
+
|
| 542 |
+
# ── Registry + factory ─────────────────────────────────────────
|
| 543 |
+
|
| 544 |
+
# Maps an extractor name to the class that build_extractor() instantiates
# (with no constructor arguments) for that name.
_EXTRACTOR_REGISTRY: dict[str, type[RelationExtractor]] = {
    "rebel": REBELExtractor,
    "llm": LLMExtractor,
    # Density-filtered variants are constructed specially — see build_extractor()
}

# Names that trigger density-filter wrapping.
# Key: the filtered variant's name; value: the _EXTRACTOR_REGISTRY key of the
# extractor that gets wrapped in EntityDensityFilter.
_FILTERED_VARIANTS: dict[str, str] = {
    "rebel-filtered": "rebel",
    "llm-filtered": "llm",
}
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def build_extractor(name: Optional[str] = None) -> RelationExtractor:
    """
    Build the relation extractor selected by ``name`` or by configuration.

    When ``name`` is empty or ``None`` the ``graph_extractor`` setting is
    used, and ``"rebel"`` if that is missing or empty too. Names are matched
    case-insensitively and surrounding whitespace is ignored.

    Available values for GRAPH_EXTRACTOR:
        "rebel"           REBEL local model, no API calls (default)
        "llm"             Groq LLM, free-form predicates
        "rebel-filtered"  REBEL + entity density pre-filter (option 4)
        "llm-filtered"    LLM + entity density pre-filter (option 4)

    Explicit usage in code:
        extractor = build_extractor("rebel-filtered")

        # Or compose manually for full control:
        extractor = EntityDensityFilter(
            REBELExtractor(),
            top_fraction=0.25,
            min_entity_count=3,
        )

    Raises:
        ValueError: If the resolved name is neither a registered extractor
            nor a filtered variant.
    """
    cfg = get_settings()
    # The trailing `or "rebel"` also covers a setting that exists but is
    # None/empty; strip() tolerates stray whitespace in .env values.
    raw_name = name or getattr(cfg, "graph_extractor", None) or "rebel"
    extractor_name = raw_name.strip().lower()

    # Density-filtered variant: build inner extractor then wrap it
    if extractor_name in _FILTERED_VARIANTS:
        inner_name = _FILTERED_VARIANTS[extractor_name]
        filtered = EntityDensityFilter(_EXTRACTOR_REGISTRY[inner_name]())
        # Report the thresholds the filter actually resolved rather than
        # re-deriving them from cfg, so the log cannot disagree with the
        # filter's behaviour or fail on a None setting.
        logger.info(
            "Using relation extractor: %s (inner=%s, top_fraction=%.0f%%, min_entities=%d)",
            extractor_name, inner_name,
            filtered._top_fraction * 100,
            filtered._min_entity_count,
        )
        return filtered

    # Plain extractor
    cls = _EXTRACTOR_REGISTRY.get(extractor_name)
    if cls is None:
        available = [*_EXTRACTOR_REGISTRY, *_FILTERED_VARIANTS]
        raise ValueError(
            f"Unknown extractor '{extractor_name}'. "
            f"Available: {available}. "
            f"Set GRAPH_EXTRACTOR in .env to one of these."
        )

    logger.info("Using relation extractor: %s", extractor_name)
    return cls()
|
ui/app.py
CHANGED
|
@@ -12,12 +12,18 @@ import json
|
|
| 12 |
import time
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
import requests
|
| 17 |
import streamlit as st
|
| 18 |
|
| 19 |
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
|
| 22 |
st.set_page_config(
|
| 23 |
page_title="Cortex RAG",
|
|
@@ -66,11 +72,24 @@ def _render_source_cards_raw(chunks: list[dict]):
|
|
| 66 |
title = chunk.get("title", "Unknown")
|
| 67 |
source = Path(chunk.get("source", "")).name
|
| 68 |
snippet = chunk.get("text_snippet", "")[:160]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
st.markdown(f"""
|
| 70 |
<div class="source-card">
|
| 71 |
<strong>[{i+1}] {title}</strong>
|
| 72 |
<span class="score-badge" style="float:right">{score_pct}%</span><br/>
|
| 73 |
-
<small style="color:#6b7280">{source}</small>
|
|
|
|
| 74 |
<div class="chunk-snippet">{snippet}β¦</div>
|
| 75 |
</div>""", unsafe_allow_html=True)
|
| 76 |
|
|
@@ -101,7 +120,7 @@ st.markdown(
|
|
| 101 |
)
|
| 102 |
st.divider()
|
| 103 |
|
| 104 |
-
tab_ask, tab_ingest, tab_system = st.tabs(["π Ask", "π₯ Ingest", "π©Ί System"])
|
| 105 |
|
| 106 |
|
| 107 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -173,6 +192,19 @@ with tab_ask:
|
|
| 173 |
sources_placeholder.markdown(payload.get("text", ""))
|
| 174 |
status_placeholder.empty()
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
elif event_type == "done":
|
| 177 |
answer_placeholder.markdown(answer_text)
|
| 178 |
status_placeholder.empty()
|
|
@@ -262,7 +294,129 @@ with tab_ingest:
|
|
| 262 |
|
| 263 |
|
| 264 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 265 |
-
# TAB 3 β
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 267 |
with tab_system:
|
| 268 |
st.subheader("System health")
|
|
@@ -299,5 +453,13 @@ with tab_system:
|
|
| 299 |
st.metric("Chunks indexed", stats.get("entity_count", "β"))
|
| 300 |
|
| 301 |
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
st.markdown("**Raw health response**")
|
| 303 |
st.json(health)
|
|
|
|
| 12 |
import time
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 18 |
+
from config import get_settings
|
| 19 |
|
| 20 |
import requests
|
| 21 |
import streamlit as st
|
| 22 |
|
| 23 |
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
# Settings are resolved once, when the script is loaded.
cfg = get_settings()
# Base URL of the backend HTTP API, built from the configured host and port.
API_BASE = f"http://{cfg.api_host}:{cfg.api_port}"
# NOTE(review): presumably a Redis connection string rather than an HTTP
# endpoint — confirm before using it as a base URL for `requests` calls.
REDIS_URL = cfg.redis_url
|
| 27 |
|
| 28 |
st.set_page_config(
|
| 29 |
page_title="Cortex RAG",
|
|
|
|
| 72 |
title = chunk.get("title", "Unknown")
|
| 73 |
source = Path(chunk.get("source", "")).name
|
| 74 |
snippet = chunk.get("text_snippet", "")[:160]
|
| 75 |
+
retriever = chunk.get("retriever", "dense")
|
| 76 |
+
retriever_colors = {
|
| 77 |
+
"dense": "#dbeafe:#1e40af",
|
| 78 |
+
"bm25": "#dcfce7:#166534",
|
| 79 |
+
"dense+bm25": "#f3e8ff:#6b21a8",
|
| 80 |
+
"bm25+dense": "#f3e8ff:#6b21a8",
|
| 81 |
+
"graph": "#fef9c3:#854d0e",
|
| 82 |
+
"web_search": "#fee2e2:#991b1b",
|
| 83 |
+
}
|
| 84 |
+
ret_style = retriever_colors.get(retriever, "#f3f4f6:#374151")
|
| 85 |
+
ret_bg, ret_fg = ret_style.split(":")
|
| 86 |
+
|
| 87 |
st.markdown(f"""
|
| 88 |
<div class="source-card">
|
| 89 |
<strong>[{i+1}] {title}</strong>
|
| 90 |
<span class="score-badge" style="float:right">{score_pct}%</span><br/>
|
| 91 |
+
<small style="color:#6b7280">{source}</small>
|
| 92 |
+
<span style="background:{ret_bg};color:{ret_fg};border-radius:4px;padding:1px 6px;font-size:0.72rem;font-weight:600">{retriever}</span>
|
| 93 |
<div class="chunk-snippet">{snippet}β¦</div>
|
| 94 |
</div>""", unsafe_allow_html=True)
|
| 95 |
|
|
|
|
| 120 |
)
|
| 121 |
st.divider()
|
| 122 |
|
| 123 |
+
tab_ask, tab_ingest, tab_eval, tab_system = st.tabs(["π Ask", "π₯ Ingest", "π Evaluation", "π©Ί System"])
|
| 124 |
|
| 125 |
|
| 126 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 192 |
sources_placeholder.markdown(payload.get("text", ""))
|
| 193 |
status_placeholder.empty()
|
| 194 |
|
| 195 |
+
elif event_type == "crag_update":
|
| 196 |
+
grade = payload.get("grade", "")
|
| 197 |
+
rewritten = payload.get("rewritten_query")
|
| 198 |
+
web_used = payload.get("web_search_used", False)
|
| 199 |
+
reasoning = payload.get("reasoning", "")
|
| 200 |
+
icon = {"POOR": "π", "ABSENT": "π"}.get(grade, "βΉοΈ")
|
| 201 |
+
msg = f"{icon} **CRAG {grade}**: {reasoning[:100]}"
|
| 202 |
+
if rewritten:
|
| 203 |
+
msg += " \n\u21a9 Rewritten: *" + rewritten + "*"
|
| 204 |
+
if web_used:
|
| 205 |
+
msg += " \n\U0001f310 Web search fallback used"
|
| 206 |
+
status_placeholder.info(msg)
|
| 207 |
+
|
| 208 |
elif event_type == "done":
|
| 209 |
answer_placeholder.markdown(answer_text)
|
| 210 |
status_placeholder.empty()
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 297 |
+
# TAB 3 — EVALUATION DASHBOARD
|
| 298 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 299 |
+
with tab_eval:
    # Evaluation dashboard: renders the payload of GET {API_BASE}/metrics.
    # Sections: summary KPIs, hourly RAGAS timeseries, CRAG-grade and
    # strategy distributions, cache statistics, recent query log.
    import json  # local imports keep this block self-contained
    import pandas as pd

    def _fmt_score(value) -> str:
        """Format a metric to two decimals; only a missing (None) value shows the placeholder."""
        # `is not None` (not truthiness) so a genuine score of 0.0 is displayed.
        return f"{value:.2f}" if value is not None else "β"

    st.subheader("RAG evaluation dashboard")
    st.caption("Metrics update automatically after each query. RAGAS scores compute in the background (~5s after response).")

    # Dropping the cached payload forces the fetch below to run again.
    if st.button("π Refresh metrics"):
        st.session_state.pop("metrics_data", None)

    if "metrics_data" not in st.session_state:
        try:
            resp = requests.get(f"{API_BASE}/metrics?limit=200&days=14", timeout=5)
            resp.raise_for_status()
            st.session_state.metrics_data = resp.json()
        except (requests.RequestException, ValueError) as exc:
            # Network/HTTP failures and undecodable bodies are shown in the UI
            # instead of aborting the script run.
            st.session_state.metrics_data = {"error": str(exc)}

    mdata = st.session_state.get("metrics_data", {})

    if "error" in mdata:
        st.error(f"Cannot reach API: {mdata['error']}")
    else:
        summary = mdata.get("summary", {})
        cache = mdata.get("cache", {})

        # ── Header KPI row ──
        # `or 0` guards against keys that are present but None, which would
        # otherwise make the numeric format specs raise TypeError.
        k1, k2, k3, k4, k5, k6 = st.columns(6)
        k1.metric("Total queries", summary.get("total_queries", 0))
        k2.metric("Faithfulness", f"{summary.get('avg_faithfulness') or 0:.2f}")
        k3.metric("Answer relevancy", f"{summary.get('avg_answer_relevancy') or 0:.2f}")
        k4.metric("Context precision", f"{summary.get('avg_context_precision') or 0:.2f}")
        k5.metric("Avg latency", f"{summary.get('avg_latency_ms') or 0:.0f} ms")
        k6.metric(
            "Cache hit rate",
            f"{cache.get('hit_rate') or 0:.0%}" if cache.get("enabled") else "off",
        )

        st.divider()

        # ── Metric timeseries ──
        ts = mdata.get("timeseries", [])
        if ts:
            df_ts = pd.DataFrame(ts)
            df_ts["hour"] = df_ts["hour_bucket"]
            st.markdown("#### RAGAS metrics over time")
            st.line_chart(
                df_ts.set_index("hour")[["faithfulness", "answer_relevancy", "context_precision"]],
                height=220,
            )
        else:
            st.info("No evaluation data yet. Run some queries to populate the dashboard.")

        st.divider()

        col_left, col_right = st.columns(2, gap="large")

        with col_left:
            # ── CRAG grade distribution ──
            grade_dist = summary.get("crag_grade_dist", {})
            if grade_dist:
                st.markdown("#### CRAG grade distribution")
                df_grades = pd.DataFrame(
                    list(grade_dist.items()), columns=["Grade", "Count"]
                )
                st.bar_chart(df_grades.set_index("Grade"), height=180)

            # ── Strategy distribution ──
            strat_dist = summary.get("strategy_dist", {})
            if strat_dist:
                st.markdown("#### Retrieval strategy mix")
                strat_rows = []
                for strat_json, cnt in strat_dist.items():
                    # Keys are expected to be JSON-encoded lists of strategy
                    # names; anything that does not decode is shown verbatim.
                    try:
                        label = "+".join(json.loads(strat_json)).upper()
                    except (ValueError, TypeError):
                        label = strat_json
                    strat_rows.append({"Strategy": label, "Count": cnt})
                df_strat = pd.DataFrame(strat_rows)
                st.bar_chart(df_strat.set_index("Strategy"), height=180)

        with col_right:
            # ── Cache stats ──
            st.markdown("#### Cache")
            if cache.get("enabled"):
                c1, c2 = st.columns(2)
                c1.metric("Hits", cache.get("hits", 0))
                c2.metric("Misses", cache.get("misses", 0))
                st.caption(f"TTL: {(cache.get('ttl_s') or 0)//60} min")
                if st.button("ποΈ Flush cache"):
                    try:
                        # Flushing goes through the HTTP API like every other
                        # call here; REDIS_URL is not an HTTP endpoint.
                        flush_resp = requests.post(f"{API_BASE}/cache/flush", timeout=5)
                        # Surface 4xx/5xx instead of reporting a false success.
                        flush_resp.raise_for_status()
                        st.success(f"Flushed {flush_resp.json().get('deleted', 0)} entries.")
                        st.session_state.pop("metrics_data", None)
                    except (requests.RequestException, ValueError) as e:
                        st.error(str(e))
            else:
                st.caption("Redis not connected. Start Redis to enable caching.")
                st.code("docker run -d -p 6379:6379 redis:7-alpine", language="bash")

        st.divider()

        # ── Recent query log table ──
        recent = mdata.get("recent", [])
        if recent:
            st.markdown("#### Recent queries")
            # Only the 50 first entries are shown, with the query text
            # truncated to 60 characters.
            log_rows = [
                {
                    "Query": (entry.get("query") or "")[:60],
                    "Intent": entry.get("intent", ""),
                    "CRAG": entry.get("crag_grade", ""),
                    "Faithful": _fmt_score(entry.get("faithfulness")),
                    "Relevancy": _fmt_score(entry.get("answer_relevancy")),
                    "Precision": _fmt_score(entry.get("context_precision")),
                    "Latency ms": f"{entry.get('latency_ms') or 0:.0f}",
                }
                for entry in recent[:50]
            ]
            st.dataframe(pd.DataFrame(log_rows), use_container_width=True, hide_index=True)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 419 |
+
# TAB 4 — SYSTEM HEALTH
|
| 420 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 421 |
with tab_system:
|
| 422 |
st.subheader("System health")
|
|
|
|
| 453 |
st.metric("Chunks indexed", stats.get("entity_count", "β"))
|
| 454 |
|
| 455 |
st.divider()
|
| 456 |
+
graph_stats = health.get("graph_stats", {})
|
| 457 |
+
if graph_stats:
|
| 458 |
+
col_d, col_e = st.columns(2)
|
| 459 |
+
with col_d:
|
| 460 |
+
st.metric("Graph nodes", graph_stats.get("nodes", "β"))
|
| 461 |
+
with col_e:
|
| 462 |
+
st.metric("Graph edges", graph_stats.get("edges", "β"))
|
| 463 |
+
st.divider()
|
| 464 |
st.markdown("**Raw health response**")
|
| 465 |
st.json(health)
|