Paramjit Singh commited on
Commit
3e08504
·
unverified ·
2 Parent(s): efb7f42c752a7a

Merge pull request #255 from Kishalll/feature/graphrag-knowledge-graph

Browse files
.env.example CHANGED
@@ -122,6 +122,16 @@ HF_TOKEN=your_huggingface_token_here
122
 
123
  # ── RAG Config (Optional — defaults shown) ───────────
124
 
 
 
 
 
 
 
 
 
 
 
125
  # ── ChromaDB (Vector Store) ─────────────────────────────────
126
 
127
  # Directory where ChromaDB persists its vector index to disk.
 
122
 
123
  # ── RAG Config (Optional — defaults shown) ───────────
124
 
125
+ # ── Knowledge Graph / GraphRAG (Optional — defaults shown) ─────────────────
126
+
127
+ # Directory where GraphRAG stores per-document knowledge graphs.
128
+ # Optional — defaults to "./data/graphs"
129
+ # GRAPH_PERSIST_DIR=./data/graphs
130
+
131
+ # Maximum number of graph relationships appended to the RAG prompt.
132
+ # Optional — defaults to 12
133
+ # GRAPH_MAX_RELATIONSHIPS=12
134
+
135
  # ── ChromaDB (Vector Store) ─────────────────────────────────
136
 
137
  # Directory where ChromaDB persists its vector index to disk.
Dockerfile CHANGED
@@ -33,7 +33,8 @@ RUN python -m venv "$VIRTUAL_ENV"
33
 
34
  COPY backend/requirements.txt ./requirements.txt
35
  RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
36
- pip install --no-cache-dir -r requirements.txt
 
37
 
38
  # --------------------------------------------------------
39
  # Stage 3: Runtime image with only app code and artifacts
@@ -68,7 +69,7 @@ COPY backend/__init__.py ./backend/__init__.py
68
  COPY --from=frontend-builder /app/frontend/out ./frontend/out
69
 
70
  # Create data directories with proper permissions
71
- RUN mkdir -p /app/data/uploads /app/data/chroma_db /app/data/huggingface && \
72
  chown -R appuser:appuser /app
73
 
74
  # Copy entrypoint
 
33
 
34
  COPY backend/requirements.txt ./requirements.txt
35
  RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
36
+ pip install --no-cache-dir -r requirements.txt && \
37
+ python -m spacy download en_core_web_sm
38
 
39
  # --------------------------------------------------------
40
  # Stage 3: Runtime image with only app code and artifacts
 
69
  COPY --from=frontend-builder /app/frontend/out ./frontend/out
70
 
71
  # Create data directories with proper permissions
72
+ RUN mkdir -p /app/data/uploads /app/data/chroma_db /app/data/graphs /app/data/huggingface && \
73
  chown -R appuser:appuser /app
74
 
75
  # Copy entrypoint
backend/app/config.py CHANGED
@@ -45,6 +45,22 @@ class Settings(BaseSettings):
45
  TOP_K_RETRIEVAL: int = 10
46
  TOP_K_RERANK: int = 5
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # ── Embeddings (local HuggingFace model) ─────────────
49
  EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
50
  EMBEDDING_DIMENSION: int = 384
 
45
  TOP_K_RETRIEVAL: int = 10
46
  TOP_K_RERANK: int = 5
47
 
48
+ # ── Knowledge Graph (GraphRAG) ───────────────────────
49
+ GRAPH_PERSIST_DIR: str = "./data/graphs"
50
+ GRAPH_ENTITY_LABELS: set = {
51
+ "PERSON",
52
+ "ORG",
53
+ "GPE",
54
+ "LOC",
55
+ "PRODUCT",
56
+ "EVENT",
57
+ "WORK_OF_ART",
58
+ "LAW",
59
+ "NORP",
60
+ "FAC",
61
+ }
62
+ GRAPH_MAX_RELATIONSHIPS: int = 12
63
+
64
  # ── Embeddings (local HuggingFace model) ─────────────
65
  EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
66
  EMBEDDING_DIMENSION: int = 384
backend/app/rag/agent.py CHANGED
@@ -9,6 +9,7 @@ from typing import List, Dict, Any, Optional, Generator
9
  from huggingface_hub import InferenceClient
10
  from app.config import get_settings
11
  from app.rag.retriever import retrieve
 
12
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
13
  from app.rag.tracing import trace_function
14
 
@@ -48,6 +49,26 @@ def build_context(chunks: List[Dict[str, Any]]) -> str:
48
  return "\n\n---\n\n".join(context_parts)
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def _chat_messages(system: str, user_content: str) -> list:
52
  """Build messages list for chat completion API."""
53
  return [
@@ -108,7 +129,12 @@ def generate_answer(
108
 
109
  # ── Build prompt ─────────────────────────────────
110
  # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
111
- context = build_context(chunks)
 
 
 
 
 
112
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
113
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
114
 
@@ -222,7 +248,12 @@ def generate_answer_stream(
222
 
223
  # ── Build prompt ─────────────────────────────────
224
  # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
225
- context = build_context(chunks)
 
 
 
 
 
226
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
227
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
228
 
 
9
  from huggingface_hub import InferenceClient
10
  from app.config import get_settings
11
  from app.rag.retriever import retrieve
12
+ from app.rag.graph_retriever import get_entity_context
13
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
14
  from app.rag.tracing import trace_function
15
 
 
49
  return "\n\n---\n\n".join(context_parts)
50
 
51
 
52
+ def build_augmented_context(
53
+ chunks: List[Dict[str, Any]],
54
+ question: str,
55
+ user_id: str,
56
+ document_id: Optional[str] = None,
57
+ ) -> str:
58
+ """Combine vector-retrieved excerpts with GraphRAG relationships."""
59
+ context = build_context(chunks)
60
+ graph_context = get_entity_context(
61
+ query=question,
62
+ user_id=user_id,
63
+ document_id=document_id,
64
+ )
65
+
66
+ if not graph_context:
67
+ return context
68
+
69
+ return f"{context}\n\n---\n\n{graph_context}"
70
+
71
+
72
  def _chat_messages(system: str, user_content: str) -> list:
73
  """Build messages list for chat completion API."""
74
  return [
 
129
 
130
  # ── Build prompt ─────────────────────────────────
131
  # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
132
+ context = build_augmented_context(
133
+ chunks=chunks,
134
+ question=question,
135
+ user_id=user_id,
136
+ document_id=document_id,
137
+ )
138
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
139
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
140
 
 
248
 
249
  # ── Build prompt ─────────────────────────────────
250
  # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
251
+ context = build_augmented_context(
252
+ chunks=chunks,
253
+ question=question,
254
+ user_id=user_id,
255
+ document_id=document_id,
256
+ )
257
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
258
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
259
 
backend/app/rag/graph_builder.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge graph construction and persistence for GraphRAG.
3
+ """
4
+ import json
5
+ import logging
6
+ import re
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Iterable, List, Optional
10
+
11
+ import networkx as nx
12
+
13
+ from app.config import get_settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+ settings = get_settings()
17
+
18
+ _nlp = None
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class Entity:
23
+ id: str
24
+ text: str
25
+ label: str
26
+
27
+
28
+ def _safe_id(value: str) -> str:
29
+ safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("._")
30
+ return safe or "unknown"
31
+
32
+
33
+ def get_graph_path(user_id: str, document_id: str) -> Path:
34
+ """Return the on-disk graph path for one user/document pair."""
35
+ filename = f"{_safe_id(user_id)}_{_safe_id(document_id)}.json"
36
+ return Path(settings.GRAPH_PERSIST_DIR) / filename
37
+
38
+
39
+ def iter_graph_paths(user_id: str) -> Iterable[Path]:
40
+ """Yield every persisted graph path for a user."""
41
+ graph_dir = Path(settings.GRAPH_PERSIST_DIR)
42
+ if not graph_dir.exists():
43
+ return []
44
+
45
+ prefix = f"{_safe_id(user_id)}_"
46
+ return sorted(graph_dir.glob(f"{prefix}*.json"))
47
+
48
+
49
+ def _get_nlp():
50
+ """Load the spaCy English NER model lazily."""
51
+ global _nlp
52
+ if _nlp is None:
53
+ import spacy
54
+
55
+ try:
56
+ _nlp = spacy.load("en_core_web_sm")
57
+ except OSError as exc:
58
+ raise RuntimeError(
59
+ "spaCy model 'en_core_web_sm' is required for GraphRAG entity extraction. "
60
+ "Install it with: python -m spacy download en_core_web_sm"
61
+ ) from exc
62
+ return _nlp
63
+
64
+
65
+ def _entity_id(text: str, label: str) -> str:
66
+ normalized = " ".join(text.split()).casefold()
67
+ return f"{label}:{normalized}"
68
+
69
+
70
+ def extract_entities(text: str) -> List[Entity]:
71
+ """Extract configured named entities from text."""
72
+ if not text or not text.strip():
73
+ return []
74
+
75
+ doc = _get_nlp()(text)
76
+ entities: Dict[str, Entity] = {}
77
+
78
+ for ent in doc.ents:
79
+ value = " ".join(ent.text.split()).strip()
80
+ if not value or ent.label_ not in settings.GRAPH_ENTITY_LABELS:
81
+ continue
82
+
83
+ entity_id = _entity_id(value, ent.label_)
84
+ entities.setdefault(
85
+ entity_id,
86
+ Entity(id=entity_id, text=value, label=ent.label_),
87
+ )
88
+
89
+ return list(entities.values())
90
+
91
+
92
+ def build_graph(chunks: List[Dict[str, Any]]) -> nx.Graph:
93
+ """Build an entity co-occurrence graph from document chunks."""
94
+ graph = nx.Graph()
95
+
96
+ for chunk in chunks:
97
+ text = chunk.get("text", "")
98
+ page = chunk.get("page")
99
+ chunk_index = chunk.get("chunk_index")
100
+ entities = extract_entities(text)
101
+
102
+ for entity in entities:
103
+ if graph.has_node(entity.id):
104
+ graph.nodes[entity.id]["mentions"] += 1
105
+ graph.nodes[entity.id]["pages"].add(page)
106
+ graph.nodes[entity.id]["chunks"].add(chunk_index)
107
+ else:
108
+ graph.add_node(
109
+ entity.id,
110
+ name=entity.text,
111
+ label=entity.label,
112
+ mentions=1,
113
+ pages={page},
114
+ chunks={chunk_index},
115
+ )
116
+
117
+ for left_index, left in enumerate(entities):
118
+ for right in entities[left_index + 1:]:
119
+ if graph.has_edge(left.id, right.id):
120
+ graph[left.id][right.id]["weight"] += 1
121
+ graph[left.id][right.id]["pages"].add(page)
122
+ graph[left.id][right.id]["chunks"].add(chunk_index)
123
+ else:
124
+ graph.add_edge(
125
+ left.id,
126
+ right.id,
127
+ weight=1,
128
+ pages={page},
129
+ chunks={chunk_index},
130
+ )
131
+
132
+ _convert_sets_for_json(graph)
133
+ return graph
134
+
135
+
136
+ def _convert_sets_for_json(graph: nx.Graph) -> None:
137
+ for _, data in graph.nodes(data=True):
138
+ data["pages"] = sorted(item for item in data.get("pages", []) if item is not None)
139
+ data["chunks"] = sorted(item for item in data.get("chunks", []) if item is not None)
140
+
141
+ for _, _, data in graph.edges(data=True):
142
+ data["pages"] = sorted(item for item in data.get("pages", []) if item is not None)
143
+ data["chunks"] = sorted(item for item in data.get("chunks", []) if item is not None)
144
+
145
+
146
+ def save_graph(graph: nx.Graph, user_id: str, document_id: str) -> Path:
147
+ """Persist a graph to disk as node-link JSON."""
148
+ graph_path = get_graph_path(user_id, document_id)
149
+ graph_path.parent.mkdir(parents=True, exist_ok=True)
150
+
151
+ data = nx.node_link_data(graph)
152
+ data["metadata"] = {
153
+ "user_id": user_id,
154
+ "document_id": document_id,
155
+ "node_count": graph.number_of_nodes(),
156
+ "edge_count": graph.number_of_edges(),
157
+ }
158
+
159
+ graph_path.write_text(json.dumps(data, ensure_ascii=True, indent=2), encoding="utf-8")
160
+ logger.info(
161
+ "Saved knowledge graph for document %s with %s nodes and %s edges",
162
+ document_id,
163
+ graph.number_of_nodes(),
164
+ graph.number_of_edges(),
165
+ )
166
+ return graph_path
167
+
168
+
169
+ def load_graph(user_id: str, document_id: str) -> Optional[nx.Graph]:
170
+ """Load a persisted graph for one user/document pair."""
171
+ return load_graph_path(get_graph_path(user_id, document_id))
172
+
173
+
174
+ def load_graph_path(graph_path: Path) -> Optional[nx.Graph]:
175
+ """Load a graph from a concrete JSON path."""
176
+ if not graph_path.exists():
177
+ return None
178
+
179
+ data = json.loads(graph_path.read_text(encoding="utf-8"))
180
+ return nx.node_link_graph(data)
181
+
182
+
183
+ def delete_graph(user_id: str, document_id: str) -> None:
184
+ """Delete a persisted graph file if it exists."""
185
+ get_graph_path(user_id, document_id).unlink(missing_ok=True)
backend/app/rag/graph_retriever.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge graph retrieval for augmenting RAG context.
3
+ """
4
+ import logging
5
+ from typing import Dict, Iterable, List, Optional, Set, Tuple
6
+
7
+ import networkx as nx
8
+
9
+ from app.config import get_settings
10
+ from app.rag.graph_builder import (
11
+ extract_entities,
12
+ iter_graph_paths,
13
+ load_graph,
14
+ load_graph_path,
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+ settings = get_settings()
19
+
20
+
21
+ def _candidate_graphs(user_id: str, document_id: Optional[str]) -> Iterable[nx.Graph]:
22
+ if document_id:
23
+ graph = load_graph(user_id, document_id)
24
+ return [graph] if graph is not None else []
25
+
26
+ graphs = []
27
+ for path in iter_graph_paths(user_id):
28
+ graph = load_graph_path(path)
29
+ if graph is not None:
30
+ graphs.append(graph)
31
+ return graphs
32
+
33
+
34
+ def _node_name(graph: nx.Graph, node_id: str) -> str:
35
+ return graph.nodes[node_id].get("name", node_id.split(":", 1)[-1])
36
+
37
+
38
+ def _match_query_nodes(graph: nx.Graph, query: str) -> Set[str]:
39
+ query_entities = extract_entities(query)
40
+ matched = {entity.id for entity in query_entities if graph.has_node(entity.id)}
41
+
42
+ if matched:
43
+ return matched
44
+
45
+ query_text = query.casefold()
46
+ for node_id, data in graph.nodes(data=True):
47
+ name = data.get("name", "").casefold()
48
+ if name and name in query_text:
49
+ matched.add(node_id)
50
+
51
+ return matched
52
+
53
+
54
+ def _format_pages(pages: List[int]) -> str:
55
+ if not pages:
56
+ return "unknown pages"
57
+ if len(pages) == 1:
58
+ return f"page {pages[0]}"
59
+ return "pages " + ", ".join(str(page) for page in pages[:4])
60
+
61
+
62
+ def _relationship_key(left: str, right: str) -> Tuple[str, str]:
63
+ return tuple(sorted((left, right)))
64
+
65
+
66
+ def get_entity_context(
67
+ query: str,
68
+ user_id: str,
69
+ document_id: Optional[str] = None,
70
+ ) -> str:
71
+ """Return compact graph relationship context relevant to the query."""
72
+ relationships: Dict[Tuple[str, str], Dict[str, object]] = {}
73
+
74
+ try:
75
+ graphs = _candidate_graphs(user_id=user_id, document_id=document_id)
76
+ for graph in graphs:
77
+ matched_nodes = _match_query_nodes(graph, query)
78
+
79
+ for node_id in matched_nodes:
80
+ neighbors = sorted(
81
+ graph.neighbors(node_id),
82
+ key=lambda neighbor: graph[node_id][neighbor].get("weight", 0),
83
+ reverse=True,
84
+ )
85
+ for neighbor_id in neighbors:
86
+ edge = graph[node_id][neighbor_id]
87
+ left = _node_name(graph, node_id)
88
+ right = _node_name(graph, neighbor_id)
89
+ key = _relationship_key(left.casefold(), right.casefold())
90
+ existing = relationships.setdefault(
91
+ key,
92
+ {
93
+ "left": left,
94
+ "right": right,
95
+ "weight": 0,
96
+ "pages": set(),
97
+ },
98
+ )
99
+ existing["weight"] = int(existing["weight"]) + int(edge.get("weight", 1))
100
+ existing["pages"].update(edge.get("pages", []))
101
+ except Exception as exc:
102
+ logger.warning("GraphRAG context retrieval failed: %s", exc)
103
+ return ""
104
+
105
+ if not relationships:
106
+ return ""
107
+
108
+ ranked = sorted(
109
+ relationships.values(),
110
+ key=lambda item: int(item["weight"]),
111
+ reverse=True,
112
+ )[: settings.GRAPH_MAX_RELATIONSHIPS]
113
+
114
+ lines = ["## Knowledge Graph Context"]
115
+ for item in ranked:
116
+ pages = sorted(item["pages"])
117
+ lines.append(
118
+ f"- {item['left']} is related to {item['right']} "
119
+ f"through document co-occurrence on {_format_pages(pages)} "
120
+ f"(strength: {item['weight']})."
121
+ )
122
+
123
+ return "\n".join(lines)
backend/app/routes/documents.py CHANGED
@@ -172,6 +172,15 @@ def _ingest_document(document_id: str, filepath: str, original_name: str, user_i
172
  db.commit()
173
  return
174
 
 
 
 
 
 
 
 
 
 
175
  # Store embeddings in ChromaDB
176
  chunk_count = store_chunks(
177
  chunks=chunks,
@@ -629,6 +638,14 @@ def delete_document(
629
  except Exception as e:
630
  logger.warning(f"Error deleting vectors: {e}")
631
 
 
 
 
 
 
 
 
 
632
  # Delete from database (cascades to chat messages)
633
  db.delete(doc)
634
  db.commit()
 
172
  db.commit()
173
  return
174
 
175
+ # Build and persist a lightweight entity co-occurrence graph for GraphRAG.
176
+ try:
177
+ from app.rag.graph_builder import build_graph, save_graph
178
+
179
+ graph = build_graph(chunks)
180
+ save_graph(graph, user_id=user_id, document_id=document_id)
181
+ except Exception as e:
182
+ logger.warning(f"Could not build knowledge graph for document {document_id}: {e}")
183
+
184
  # Store embeddings in ChromaDB
185
  chunk_count = store_chunks(
186
  chunks=chunks,
 
638
  except Exception as e:
639
  logger.warning(f"Error deleting vectors: {e}")
640
 
641
+ # Delete persisted knowledge graph
642
+ try:
643
+ from app.rag.graph_builder import delete_graph
644
+
645
+ delete_graph(user_id=user.id, document_id=document_id)
646
+ except Exception as e:
647
+ logger.warning(f"Error deleting knowledge graph: {e}")
648
+
649
  # Delete from database (cascades to chat messages)
650
  db.delete(doc)
651
  db.commit()
backend/requirements.txt CHANGED
@@ -41,6 +41,9 @@ transformers
41
 
42
  # Vector Database
43
  chromadb
 
 
 
44
 
45
  # LLM Inference
46
  huggingface-hub
 
41
 
42
  # Vector Database
43
  chromadb
44
+ networkx>=3.3
45
+ spacy>=3.7
46
+ neo4j>=5.0
47
 
48
  # LLM Inference
49
  huggingface-hub
backend/tests/test_documents.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  def test_api_health(client):
2
  response = client.get("/api/health")
3
 
@@ -32,3 +38,76 @@ def test_upload_rejects_unsupported_extension_before_deep_validation(client, aut
32
 
33
  assert response.status_code == 400
34
  assert "not supported" in response.json()["detail"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ from app.models import Document
4
+ from app.routes.documents import _ingest_document
5
+
6
+
7
  def test_api_health(client):
8
  response = client.get("/api/health")
9
 
 
38
 
39
  assert response.status_code == 400
40
  assert "not supported" in response.json()["detail"]
41
+
42
+
43
+ def test_ingest_document_builds_and_saves_graph(db_session, monkeypatch, tmp_path, user):
44
+ document = Document(
45
+ user_id=user.id,
46
+ filename="graph.txt",
47
+ original_name="graph.txt",
48
+ file_size=128,
49
+ status="pending",
50
+ )
51
+ db_session.add(document)
52
+ db_session.commit()
53
+ db_session.refresh(document)
54
+ user_id = user.id
55
+ document_id = document.id
56
+ chunks = [{"text": "OpenAI works with Microsoft.", "page": 1, "chunk_index": 0}]
57
+ saved = {}
58
+
59
+ monkeypatch.setattr("app.routes.documents.get_page_count", lambda filepath: 1)
60
+ monkeypatch.setattr("app.routes.documents.chunk_document", lambda filepath: chunks)
61
+ monkeypatch.setattr("app.routes.documents.store_chunks", lambda **kwargs: len(chunks))
62
+ monkeypatch.setattr("app.database.SessionLocal", lambda: db_session)
63
+
64
+ fake_summary = types.ModuleType("app.rag.summarizer")
65
+ fake_summary.generate_document_summary = lambda filepath, max_sentences=2: "Summary"
66
+ monkeypatch.setitem(__import__("sys").modules, "app.rag.summarizer", fake_summary)
67
+
68
+ monkeypatch.setattr(
69
+ "app.rag.graph_builder.build_graph",
70
+ lambda received_chunks: {"chunks": received_chunks},
71
+ )
72
+ monkeypatch.setattr(
73
+ "app.rag.graph_builder.save_graph",
74
+ lambda graph, user_id, document_id: saved.update(
75
+ {"graph": graph, "user_id": user_id, "document_id": document_id}
76
+ ),
77
+ )
78
+
79
+ _ingest_document(
80
+ document_id=document_id,
81
+ filepath=str(tmp_path / "graph.txt"),
82
+ original_name=document.original_name,
83
+ user_id=user_id,
84
+ )
85
+
86
+ assert saved == {
87
+ "graph": {"chunks": chunks},
88
+ "user_id": user_id,
89
+ "document_id": document_id,
90
+ }
91
+ refreshed = db_session.get(Document, document_id)
92
+ assert refreshed.status == "ready"
93
+ assert refreshed.chunk_count == 1
94
+
95
+
96
+ def test_delete_document_removes_knowledge_graph(client, auth_headers, ready_document, monkeypatch):
97
+ deleted = {}
98
+
99
+ monkeypatch.setattr("app.routes.documents.delete_document_chunks", lambda **kwargs: None)
100
+ monkeypatch.setattr(
101
+ "app.rag.graph_builder.delete_graph",
102
+ lambda user_id, document_id: deleted.update(
103
+ {"user_id": user_id, "document_id": document_id}
104
+ ),
105
+ )
106
+
107
+ response = client.delete(
108
+ f"/api/v1/documents/{ready_document.id}",
109
+ headers=auth_headers,
110
+ )
111
+
112
+ assert response.status_code == 200
113
+ assert deleted["document_id"] == ready_document.id
backend/tests/test_graph_builder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from app.rag import graph_builder
4
+
5
+
6
+ class FakeEntity:
7
+ def __init__(self, text, label):
8
+ self.text = text
9
+ self.label_ = label
10
+
11
+
12
+ class FakeDoc:
13
+ def __init__(self, entities):
14
+ self.ents = entities
15
+
16
+
17
+ class FakeNlp:
18
+ def __call__(self, text):
19
+ entities = []
20
+ for value, label in (
21
+ ("OpenAI", "ORG"),
22
+ ("Microsoft", "ORG"),
23
+ ("Azure", "PRODUCT"),
24
+ ("Ignored Date", "DATE"),
25
+ ):
26
+ if value in text:
27
+ entities.append(FakeEntity(value, label))
28
+ return FakeDoc(entities)
29
+
30
+
31
+ def test_extract_entities_filters_configured_labels(monkeypatch):
32
+ monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
33
+
34
+ entities = graph_builder.extract_entities("OpenAI works with Microsoft on Ignored Date")
35
+
36
+ assert {entity.text for entity in entities} == {"OpenAI", "Microsoft"}
37
+ assert {entity.label for entity in entities} == {"ORG"}
38
+
39
+
40
+ def test_build_graph_tracks_entity_edges_and_weights(monkeypatch):
41
+ monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
42
+ chunks = [
43
+ {
44
+ "text": "OpenAI works with Microsoft.",
45
+ "page": 1,
46
+ "chunk_index": 0,
47
+ },
48
+ {
49
+ "text": "OpenAI and Microsoft use Azure.",
50
+ "page": 2,
51
+ "chunk_index": 1,
52
+ },
53
+ ]
54
+
55
+ graph = graph_builder.build_graph(chunks)
56
+
57
+ openai_id = "ORG:openai"
58
+ microsoft_id = "ORG:microsoft"
59
+ azure_id = "PRODUCT:azure"
60
+ assert graph.nodes[openai_id]["name"] == "OpenAI"
61
+ assert graph.nodes[openai_id]["pages"] == [1, 2]
62
+ assert graph[openai_id][microsoft_id]["weight"] == 2
63
+ assert graph[openai_id][microsoft_id]["pages"] == [1, 2]
64
+ assert graph.has_edge(microsoft_id, azure_id)
65
+
66
+
67
+ def test_save_load_and_delete_graph_roundtrip(tmp_path, monkeypatch):
68
+ monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
69
+ graph = graph_builder.build_graph([])
70
+ graph.add_node("ORG:openai", name="OpenAI", label="ORG", mentions=1, pages=[1], chunks=[0])
71
+
72
+ path = graph_builder.save_graph(graph, user_id="user-1", document_id="doc-1")
73
+ payload = json.loads(path.read_text(encoding="utf-8"))
74
+ loaded = graph_builder.load_graph(user_id="user-1", document_id="doc-1")
75
+
76
+ assert payload["metadata"]["document_id"] == "doc-1"
77
+ assert loaded.nodes["ORG:openai"]["name"] == "OpenAI"
78
+
79
+ graph_builder.delete_graph(user_id="user-1", document_id="doc-1")
80
+ assert not path.exists()
81
+
82
+
83
+ def test_empty_chunks_produce_empty_graph(monkeypatch):
84
+ monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
85
+
86
+ graph = graph_builder.build_graph([])
87
+
88
+ assert graph.number_of_nodes() == 0
89
+ assert graph.number_of_edges() == 0
backend/tests/test_graph_retriever.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.rag import graph_builder, graph_retriever
2
+
3
+
4
+ class FakeEntity:
5
+ def __init__(self, text, label):
6
+ self.text = text
7
+ self.label_ = label
8
+
9
+
10
+ class FakeDoc:
11
+ def __init__(self, entities):
12
+ self.ents = entities
13
+
14
+
15
+ class FakeNlp:
16
+ def __call__(self, text):
17
+ entities = []
18
+ for value, label in (
19
+ ("OpenAI", "ORG"),
20
+ ("Microsoft", "ORG"),
21
+ ("Azure", "PRODUCT"),
22
+ ):
23
+ if value in text:
24
+ entities.append(FakeEntity(value, label))
25
+ return FakeDoc(entities)
26
+
27
+
28
+ def _save_sample_graph(tmp_path, monkeypatch, user_id="user-1", document_id="doc-1"):
29
+ monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
30
+ monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
31
+ graph = graph_builder.build_graph(
32
+ [
33
+ {
34
+ "text": "OpenAI works with Microsoft.",
35
+ "page": 1,
36
+ "chunk_index": 0,
37
+ },
38
+ {
39
+ "text": "Microsoft deploys Azure.",
40
+ "page": 2,
41
+ "chunk_index": 1,
42
+ },
43
+ ]
44
+ )
45
+ graph_builder.save_graph(graph, user_id=user_id, document_id=document_id)
46
+
47
+
48
+ def test_get_entity_context_returns_one_hop_relationships(tmp_path, monkeypatch):
49
+ _save_sample_graph(tmp_path, monkeypatch)
50
+
51
+ context = graph_retriever.get_entity_context(
52
+ query="How is OpenAI related to Microsoft?",
53
+ user_id="user-1",
54
+ document_id="doc-1",
55
+ )
56
+
57
+ assert "## Knowledge Graph Context" in context
58
+ assert "OpenAI" in context
59
+ assert "Microsoft" in context
60
+ assert "page 1" in context
61
+
62
+
63
+ def test_get_entity_context_returns_empty_for_no_match(tmp_path, monkeypatch):
64
+ _save_sample_graph(tmp_path, monkeypatch)
65
+
66
+ context = graph_retriever.get_entity_context(
67
+ query="What about Google?",
68
+ user_id="user-1",
69
+ document_id="doc-1",
70
+ )
71
+
72
+ assert context == ""
73
+
74
+
75
+ def test_get_entity_context_returns_empty_for_missing_graph(tmp_path, monkeypatch):
76
+ monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
77
+ monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
78
+
79
+ context = graph_retriever.get_entity_context(
80
+ query="OpenAI",
81
+ user_id="user-1",
82
+ document_id="missing",
83
+ )
84
+
85
+ assert context == ""
86
+
87
+
88
+ def test_get_entity_context_isolates_users(tmp_path, monkeypatch):
89
+ _save_sample_graph(tmp_path, monkeypatch, user_id="user-1", document_id="doc-1")
90
+
91
+ context = graph_retriever.get_entity_context(
92
+ query="OpenAI",
93
+ user_id="user-2",
94
+ document_id="doc-1",
95
+ )
96
+
97
+ assert context == ""
backend/tests/test_graphrag_agent.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.rag import agent
2
+
3
+
4
+ class FakeMessage:
5
+ content = "Graph answer"
6
+
7
+
8
+ class FakeChoice:
9
+ message = FakeMessage()
10
+
11
+
12
+ class FakeResponse:
13
+ choices = [FakeChoice()]
14
+
15
+
16
+ class FakeClient:
17
+ def __init__(self):
18
+ self.messages = None
19
+
20
+ def chat_completion(self, messages, **kwargs):
21
+ self.messages = messages
22
+ return FakeResponse()
23
+
24
+
25
+ def test_generate_answer_appends_graph_context_without_changing_sources(monkeypatch):
26
+ client = FakeClient()
27
+ chunks = [
28
+ {
29
+ "text": "Vector context",
30
+ "filename": "doc.pdf",
31
+ "page": 1,
32
+ "score": 0.9,
33
+ "confidence": 100.0,
34
+ }
35
+ ]
36
+
37
+ monkeypatch.setattr(agent, "get_llm_client", lambda: client)
38
+ monkeypatch.setattr(agent, "retrieve", lambda **kwargs: chunks)
39
+ monkeypatch.setattr(
40
+ agent,
41
+ "get_entity_context",
42
+ lambda **kwargs: "## Knowledge Graph Context\n- OpenAI is related to Microsoft on page 1.",
43
+ )
44
+
45
+ result = agent.generate_answer("How are OpenAI and Microsoft related?", "user-1", "doc-1")
46
+
47
+ prompt = client.messages[1]["content"]
48
+ assert "Vector context" in prompt
49
+ assert "Knowledge Graph Context" in prompt
50
+ assert result["sources"] == [
51
+ {
52
+ "text": "Vector context",
53
+ "filename": "doc.pdf",
54
+ "page": 1,
55
+ "score": 0.9,
56
+ "confidence": 100.0,
57
+ }
58
+ ]
59
+
60
+
61
+ def test_generate_answer_stream_appends_graph_context(monkeypatch):
62
+ captured = {}
63
+
64
+ class StreamingClient:
65
+ def chat_completion(self, messages, **kwargs):
66
+ captured["messages"] = messages
67
+ return iter([])
68
+
69
+ monkeypatch.setattr(agent, "get_llm_client", lambda: StreamingClient())
70
+ monkeypatch.setattr(
71
+ agent,
72
+ "retrieve",
73
+ lambda **kwargs: [
74
+ {
75
+ "text": "Vector stream context",
76
+ "filename": "doc.pdf",
77
+ "page": 1,
78
+ "score": 0.9,
79
+ "confidence": 100.0,
80
+ }
81
+ ],
82
+ )
83
+ monkeypatch.setattr(
84
+ agent,
85
+ "get_entity_context",
86
+ lambda **kwargs: "## Knowledge Graph Context\n- OpenAI is related to Microsoft on page 1.",
87
+ )
88
+
89
+ events = list(agent.generate_answer_stream("OpenAI Microsoft", "user-1", "doc-1"))
90
+
91
+ assert events[0].startswith("data:")
92
+ assert "Knowledge Graph Context" in captured["messages"][1]["content"]