SanskarModi commited on
Commit
765f1e4
·
1 Parent(s): 9ff42b8

Updated code to include Graph-RAG retrieval

Browse files
backend/app/api/routes_chat.py CHANGED
@@ -1,47 +1,45 @@
1
- """Chat routes for hybrid RAG-based Q&A."""
2
 
3
  from app.core.llm import llm_chat
4
  from app.core.prompts import build_rag_prompt
5
  from app.models.api import ChatRequest, ChatResponse, Citation
6
- from app.retrieval.hybrid import hybrid_search
7
  from fastapi import APIRouter
8
 
9
  router = APIRouter()
10
 
11
 
12
  @router.post("/ask", response_model=ChatResponse)
13
- def chat_hybrid(request: ChatRequest) -> ChatResponse:
14
- """Hybrid RAG Q&A endpoint (vector + BM25)."""
15
- # Hybrid retrieval
16
- chunks = hybrid_search(request.query, top_k=1)
17
 
18
- if not chunks:
19
  return ChatResponse(
20
  answer="I don't know based on the provided documents.",
21
  citations=[],
22
  )
23
 
24
- chunk = chunks[0]
 
 
 
 
25
 
26
- # Build prompt from ONLY the best chunk
27
  messages = build_rag_prompt(
28
- context=chunk.chunk.text,
29
  question=request.query,
30
  )
31
 
32
- # Ask LLM
33
  answer = llm_chat(messages=messages)
34
 
35
- # Cite ONLY what was used
36
  citations = [
37
  Citation(
38
- page_start=chunk.chunk.page_start,
39
- page_end=chunk.chunk.page_end,
40
- snippet=chunk.chunk.text[:300],
41
  )
 
42
  ]
43
 
44
- return ChatResponse(
45
- answer=answer,
46
- citations=citations,
47
- )
 
1
+ """Chat routes for Graph-RAG."""
2
 
3
  from app.core.llm import llm_chat
4
  from app.core.prompts import build_rag_prompt
5
  from app.models.api import ChatRequest, ChatResponse, Citation
6
+ from app.retrieval.retrieve import hybrid_graph_search
7
  from fastapi import APIRouter
8
 
9
  router = APIRouter()
10
 
11
 
12
@router.post("/ask", response_model=ChatResponse)
def chat(request: ChatRequest) -> ChatResponse:
    """Graph-augmented RAG endpoint.

    Retrieves chunks via hybrid (vector + BM25) search expanded with
    entity-graph recall, builds one context prompt from every retrieved
    chunk, and answers with the LLM.  Citations cover exactly the
    chunks placed into the context.
    """
    results = hybrid_graph_search(request.query, request.top_k)

    if not results:
        # No retrievable evidence: refuse rather than hallucinate.
        return ChatResponse(
            answer="I don't know based on the provided documents.",
            citations=[],
        )

    # One "(Pages a-b)\n<text>" section per retrieved chunk.
    # The original used a backslash continuation *inside* the f-string,
    # which splices the next line's leading whitespace into the page
    # label; a single-line f-string avoids that entirely.
    context = "\n\n".join(
        f"(Pages {sc.chunk.page_start}-{sc.chunk.page_end})\n{sc.chunk.text}"
        for sc in results
    )

    messages = build_rag_prompt(
        context=context,
        question=request.query,
    )

    answer = llm_chat(messages=messages)

    # Cite exactly the chunks that fed the prompt.
    citations = [
        Citation(
            page_start=sc.chunk.page_start,
            page_end=sc.chunk.page_end,
            snippet=sc.chunk.text[:300],
        )
        for sc in results
    ]

    return ChatResponse(answer=answer, citations=citations)
 
 
 
backend/app/ingestion/entities.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load entities from chunk texts."""
2
+
3
+ from typing import List, Set
4
+
5
+ import spacy
6
+
7
+ # Load once at module import
8
+ NLP = spacy.load("en_core_web_sm")
9
+
10
+ # Entity labels we accept.
11
+ # Keep this BROAD on purpose.
12
+ ALLOWED_ENTITY_LABELS = {
13
+ "PERSON",
14
+ "ORG",
15
+ "GPE",
16
+ "LOC",
17
+ "PRODUCT",
18
+ "EVENT",
19
+ "WORK_OF_ART",
20
+ "LAW",
21
+ "LANGUAGE",
22
+ "NORP",
23
+ "FAC",
24
+ }
25
+
26
+
27
def extract_entities(text: str) -> List[str]:
    """Extract named entities from *text* with spaCy.

    Deterministic (no LLM involved).  Preserves each entity's surface
    form, deduplicates, and drops spans that are shorter than three
    characters or longer than five words.  Returns a sorted list.
    """
    if not text.strip():
        return []

    # Keep only entities whose label we accept, in surface form.
    candidates = {
        ent.text.strip()
        for ent in NLP(text).ents
        if ent.label_ in ALLOWED_ENTITY_LABELS
    }

    # Filter trivially short or suspiciously long spans, then sort
    # for a deterministic ordering.
    return sorted(
        surface
        for surface in candidates
        if len(surface) >= 3 and len(surface.split()) <= 5
    )
backend/app/ingestion/pipeline.py CHANGED
@@ -5,26 +5,34 @@ from typing import List
5
 
6
  from app.ingestion.chunking import chunk_segments
7
  from app.ingestion.cleaning import clean_text
 
8
  from app.ingestion.indexing import index_chunks
9
  from app.ingestion.pdf_loader import extract_pages
10
  from app.models.ingestion import Chunk, RawSegment
 
 
11
  from app.retrieval.keyword_index import build_bm25_index
12
 
13
 
14
  def ingest_pdf(file_path: Path, doc_id: str) -> List[Chunk]:
15
- """Ingest a PDF document into indexed chunks."""
16
- raw_segments = extract_pages(file_path=file_path, doc_id=doc_id)
17
  cleaned_segments = _clean_segments(raw_segments)
 
18
  chunks = chunk_segments(cleaned_segments)
19
 
 
 
 
 
 
20
  index_chunks(chunks)
21
  build_bm25_index(chunks)
22
-
23
  return chunks
24
 
25
 
26
  def _clean_segments(segments: List[RawSegment]) -> List[RawSegment]:
27
- """Apply text cleaning to raw segments."""
28
  return [
29
  RawSegment(
30
  doc_id=s.doc_id,
 
5
 
6
  from app.ingestion.chunking import chunk_segments
7
  from app.ingestion.cleaning import clean_text
8
+ from app.ingestion.entities import extract_entities
9
  from app.ingestion.indexing import index_chunks
10
  from app.ingestion.pdf_loader import extract_pages
11
  from app.models.ingestion import Chunk, RawSegment
12
+ from app.retrieval.chunk_registry import register_chunks
13
+ from app.retrieval.graph_utils import index_entities
14
  from app.retrieval.keyword_index import build_bm25_index
15
 
16
 
17
def ingest_pdf(file_path: Path, doc_id: str) -> List[Chunk]:
    """Ingest a PDF: extract, clean, chunk, tag entities, and index.

    Pipeline: page extraction -> cleaning -> chunking -> entity
    extraction -> chunk registry + entity / vector / BM25 indexing.
    Returns the list of indexed chunks.
    """
    segments = _clean_segments(extract_pages(file_path, doc_id))
    chunks = chunk_segments(segments)

    # Attach spaCy entities so graph retrieval can recall by entity.
    for chunk in chunks:
        chunk.entities = extract_entities(chunk.text)

    # NOTE(review): the entity index appears to only accumulate across
    # calls, so re-ingesting a document may leave stale entries — confirm.
    register_chunks(chunks)
    index_entities(chunks)
    index_chunks(chunks)
    build_bm25_index(chunks)

    return chunks
32
 
33
 
34
  def _clean_segments(segments: List[RawSegment]) -> List[RawSegment]:
35
+ """Apply text cleaning."""
36
  return [
37
  RawSegment(
38
  doc_id=s.doc_id,
backend/app/models/ingestion.py CHANGED
@@ -1,10 +1,12 @@
1
- """Pydantic models for Ingestion artifacts."""
2
 
3
- from pydantic import BaseModel
 
 
4
 
5
 
6
  class RawSegment(BaseModel):
7
- """Represents raw page-level text extracted from a PDF."""
8
 
9
  doc_id: str
10
  page: int
@@ -19,3 +21,4 @@ class Chunk(BaseModel):
19
  page_start: int
20
  page_end: int
21
  text: str
 
 
1
+ """Pydantic models for ingestion artifacts."""
2
 
3
+ from typing import List
4
+
5
+ from pydantic import BaseModel, Field
6
 
7
 
8
  class RawSegment(BaseModel):
9
+ """Represents raw page-level text extracted from a document."""
10
 
11
  doc_id: str
12
  page: int
 
21
  page_start: int
22
  page_end: int
23
  text: str
24
+ entities: List[str] = Field(default_factory=list)
backend/app/retrieval/chunk_registry.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""In-memory chunk registry.

Single source of truth for all ingested chunks; graph-based retrieval
uses it to map entities back to chunk objects.

Note:
    - Ephemeral by design (non-persistent)
    - Rebuilt on each ingestion cycle
"""

from typing import Dict, List

from app.models.ingestion import Chunk

# chunk_id -> Chunk; later registrations overwrite earlier ones.
_CHUNKS: Dict[str, Chunk] = {}


def register_chunks(chunks: List[Chunk]) -> None:
    """Store each chunk under its chunk_id (idempotent per id)."""
    _CHUNKS.update((chunk.chunk_id, chunk) for chunk in chunks)


def get_chunks() -> List[Chunk]:
    """Return a snapshot list of every registered chunk."""
    return list(_CHUNKS.values())


def clear_chunks() -> None:
    """Empty the registry (handy in tests)."""
    _CHUNKS.clear()
backend/app/retrieval/graph_utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graph utilities for Graph-RAG.
2
+
3
+ Responsibilities:
4
+ - Build entity co-occurrence graph
5
+ - Index entity → chunk mappings
6
+ - Extract entities from queries
7
+ - Expand entities via graph traversal
8
+ - Recall chunks via entity relationships
9
+ """
10
+
11
+ from collections import defaultdict
12
+ from typing import Dict, Iterable, List, Set
13
+
14
+ import networkx as nx
15
+ from app.models.ingestion import Chunk
16
+
17
+ # In-memory entity → chunk index
18
+ _ENTITY_TO_CHUNKS: Dict[str, Set[str]] = defaultdict(set)
19
+
20
+
21
def index_entities(chunks: List[Chunk]) -> None:
    """Record, for every entity, which chunk IDs mention it.

    Invoked during ingestion; entries accumulate in the module-level
    index across calls.
    """
    for chunk in chunks:
        cid = chunk.chunk_id
        for name in chunk.entities:
            _ENTITY_TO_CHUNKS[name].add(cid)
29
+
30
+
31
def build_graph(chunks: List[Chunk]) -> nx.Graph:
    """Build the entity co-occurrence graph.

    Nodes are entities; an edge joins two entities appearing in the
    same chunk, weighted by how many chunks they share.
    """
    graph = nx.Graph()

    for chunk in chunks:
        ents = chunk.entities
        graph.add_nodes_from(ents)

        # Every unordered pair within this chunk co-occurs once.
        for i, first in enumerate(ents):
            for second in ents[i + 1:]:
                data = graph.get_edge_data(first, second)
                if data is None:
                    graph.add_edge(first, second, weight=1)
                else:
                    data["weight"] += 1

    return graph
53
+
54
+
55
def extract_query_entities(text: str, nlp) -> Set[str]:
    """Pull entity surface forms out of a user query.

    Deterministic: delegates to the supplied spaCy pipeline and keeps
    any entity whose stripped text is at least three characters long.
    """
    if not text.strip():
        return set()

    found: Set[str] = set()
    for ent in nlp(text).ents:
        surface = ent.text.strip()
        if len(surface) >= 3:
            found.add(surface)
    return found
65
+
66
+
67
def expand_entities(
    graph: nx.Graph,
    entities: Iterable[str],
    hops: int = 1,
) -> Set[str]:
    """Grow an entity set by walking the co-occurrence graph.

    Each hop unions in the direct neighbors of everything gathered so
    far: hops=1 adds direct neighbors, hops=2 adds neighbors of
    neighbors, and so on.  Entities absent from the graph are retained
    but contribute no neighbors.
    """
    gathered: Set[str] = set(entities)

    for _ in range(hops):
        frontier: Set[str] = set()
        for node in gathered:
            if node in graph:
                frontier |= set(graph.neighbors(node))
        gathered.update(frontier)

    return gathered
87
+
88
+
89
def chunks_from_entities(
    chunks: List[Chunk],
    entities: Set[str],
) -> List[Chunk]:
    """Return the chunks mentioning at least one of *entities*.

    This is the Graph-RAG recall step: entity -> chunk-id lookups come
    from the in-memory index populated at ingestion time.
    """
    hit_ids: Set[str] = set()
    for name in entities:
        hit_ids.update(_ENTITY_TO_CHUNKS.get(name, set()))

    return [c for c in chunks if c.chunk_id in hit_ids]
backend/app/retrieval/hybrid.py DELETED
@@ -1,31 +0,0 @@
1
- """Hybrid retrieval (Vector + BM25)."""
2
-
3
- from typing import List
4
-
5
- from app.models.retrieval import ScoredChunk
6
- from app.retrieval.keyword_index import bm25_search
7
- from app.retrieval.normalize import normalize_scores
8
- from app.retrieval.vector_store import vector_search
9
-
10
-
11
- def hybrid_search(query: str, top_k: int = 5) -> List[ScoredChunk]:
12
- """Run hybrid retrieval with score fusion."""
13
- vector_chunks = normalize_scores(vector_search(query, top_k=10))
14
- bm25_chunks = normalize_scores(bm25_search(query, top_k=10))
15
-
16
- merged: dict[str, ScoredChunk] = {}
17
-
18
- for chunk in vector_chunks + bm25_chunks:
19
- cid = chunk.chunk.chunk_id
20
- if cid not in merged:
21
- merged[cid] = chunk
22
- else:
23
- merged[cid].score += chunk.score
24
-
25
- ranked = sorted(
26
- merged.values(),
27
- key=lambda x: x.score,
28
- reverse=True,
29
- )
30
-
31
- return ranked[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/app/retrieval/normalize.py DELETED
@@ -1,25 +0,0 @@
1
- """Score normalization utilities."""
2
-
3
- from typing import List
4
-
5
- from app.models.retrieval import ScoredChunk
6
-
7
-
8
- def normalize_scores(chunks: List[ScoredChunk]) -> List[ScoredChunk]:
9
- """Min-max normalize scores to [0, 1]."""
10
- if not chunks:
11
- return []
12
-
13
- scores = [c.score for c in chunks]
14
- min_score = min(scores)
15
- max_score = max(scores)
16
-
17
- if min_score == max_score:
18
- for c in chunks:
19
- c.score = 1.0
20
- return chunks
21
-
22
- for c in chunks:
23
- c.score = (c.score - min_score) / (max_score - min_score)
24
-
25
- return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/app/retrieval/retrieve.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified Hybrid + Graph retrieval.
2
+
3
+ Pipeline:
4
+ Vector → BM25 → merge → graph recall → rank
5
+ """
6
+
7
+ from typing import Dict, List
8
+
9
+ from app.ingestion.entities import NLP
10
+ from app.models.retrieval import ScoredChunk
11
+ from app.retrieval.chunk_registry import get_chunks
12
+ from app.retrieval.graph_utils import (
13
+ build_graph,
14
+ chunks_from_entities,
15
+ expand_entities,
16
+ extract_query_entities,
17
+ )
18
+ from app.retrieval.keyword_index import bm25_search
19
+ from app.retrieval.vector_store import vector_search
20
+
21
+
22
def hybrid_graph_search(query: str, top_k: int) -> List[ScoredChunk]:
    """Hybrid + Graph-RAG retrieval.

    Pipeline:
        1. Broad vector and BM25 seed retrieval.
        2. Min-max normalize each retriever's scores to [0, 1] and fuse
           by summation.  Raw vector-similarity and BM25 scores live on
           incomparable scales, so sorting them together (or silently
           dropping the BM25 score of a chunk both retrievers found, as
           ``setdefault`` did) produces meaningless rankings.
        3. Entity-graph recall adds chunks the seeds missed, with a
           small fixed score so they rank below strong seed hits on the
           fused [0, 2] scale.

    Args:
        query: Natural-language user query.
        top_k: Number of chunks returned (final context size only;
            retrieval breadth is independent).

    Returns:
        Top-k ScoredChunks sorted by fused score, descending.
    """
    # 1. Broad seed retrieval.
    seed_k = max(top_k * 4, 8)
    vector_results = _minmax_normalize(vector_search(query, top_k=seed_k))
    bm25_results = _minmax_normalize(bm25_search(query, top_k=seed_k))

    combined: Dict[str, ScoredChunk] = {}
    for sc in vector_results + bm25_results:
        cid = sc.chunk.chunk_id
        if cid in combined:
            # Found by both retrievers: sum the normalized scores.
            combined[cid].score += sc.score
        else:
            combined[cid] = sc

    # 2. Graph recall expansion.
    # NOTE(review): the graph is rebuilt on every query; cache it at
    # ingestion time if this becomes a hot path.
    all_chunks = get_chunks()
    graph = build_graph(all_chunks)

    query_entities = extract_query_entities(query, NLP)

    if query_entities:
        expanded_entities = expand_entities(graph, query_entities, hops=1)
        for chunk in chunks_from_entities(all_chunks, expanded_entities):
            if chunk.chunk_id not in combined:
                combined[chunk.chunk_id] = ScoredChunk(
                    chunk=chunk,
                    score=0.25,  # low but non-zero recall score
                )

    # 3. Rank and return.
    ranked = sorted(combined.values(), key=lambda sc: sc.score, reverse=True)
    return ranked[:top_k]


def _minmax_normalize(results: List[ScoredChunk]) -> List[ScoredChunk]:
    """Min-max normalize scores to [0, 1] in place; 1.0 if all equal."""
    if not results:
        return results

    scores = [sc.score for sc in results]
    lo, hi = min(scores), max(scores)

    if lo == hi:
        for sc in results:
            sc.score = 1.0
        return results

    span = hi - lo
    for sc in results:
        sc.score = (sc.score - lo) / span
    return results
requirements.txt CHANGED
@@ -19,6 +19,7 @@ chromadb==0.5.0
19
  # Text Processing & NLP
20
  pymupdf==1.24.7
21
  spacy==3.7.4
 
22
  sentence-transformers==2.6.1
23
  rank-bm25==0.2.2
24
  whoosh==2.7.4
 
19
  # Text Processing & NLP
20
  pymupdf==1.24.7
21
  spacy==3.7.4
22
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
23
  sentence-transformers==2.6.1
24
  rank-bm25==0.2.2
25
  whoosh==2.7.4