Spaces:
Running
Running
Merge pull request #255 from Kishalll/feature/graphrag-knowledge-graph
Browse files- .env.example +10 -0
- Dockerfile +3 -2
- backend/app/config.py +16 -0
- backend/app/rag/agent.py +33 -2
- backend/app/rag/graph_builder.py +185 -0
- backend/app/rag/graph_retriever.py +123 -0
- backend/app/routes/documents.py +17 -0
- backend/requirements.txt +3 -0
- backend/tests/test_documents.py +79 -0
- backend/tests/test_graph_builder.py +89 -0
- backend/tests/test_graph_retriever.py +97 -0
- backend/tests/test_graphrag_agent.py +92 -0
.env.example
CHANGED
|
@@ -122,6 +122,16 @@ HF_TOKEN=your_huggingface_token_here
|
|
| 122 |
|
| 123 |
# ── RAG Config (Optional — defaults shown) ───────────
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# ── ChromaDB (Vector Store) ─────────────────────────────────
|
| 126 |
|
| 127 |
# Directory where ChromaDB persists its vector index to disk.
|
|
|
|
| 122 |
|
| 123 |
# ── RAG Config (Optional — defaults shown) ───────────
|
| 124 |
|
| 125 |
+
# ── Knowledge Graph / GraphRAG (Optional — defaults shown) ─────────────────
|
| 126 |
+
|
| 127 |
+
# Directory where GraphRAG stores per-document knowledge graphs.
|
| 128 |
+
# Optional — defaults to "./data/graphs"
|
| 129 |
+
# GRAPH_PERSIST_DIR=./data/graphs
|
| 130 |
+
|
| 131 |
+
# Maximum number of graph relationships appended to the RAG prompt.
|
| 132 |
+
# Optional — defaults to 12
|
| 133 |
+
# GRAPH_MAX_RELATIONSHIPS=12
|
| 134 |
+
|
| 135 |
# ── ChromaDB (Vector Store) ─────────────────────────────────
|
| 136 |
|
| 137 |
# Directory where ChromaDB persists its vector index to disk.
|
Dockerfile
CHANGED
|
@@ -33,7 +33,8 @@ RUN python -m venv "$VIRTUAL_ENV"
|
|
| 33 |
|
| 34 |
COPY backend/requirements.txt ./requirements.txt
|
| 35 |
RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
|
| 36 |
-
pip install --no-cache-dir -r requirements.txt
|
|
|
|
| 37 |
|
| 38 |
# --------------------------------------------------------
|
| 39 |
# Stage 3: Runtime image with only app code and artifacts
|
|
@@ -68,7 +69,7 @@ COPY backend/__init__.py ./backend/__init__.py
|
|
| 68 |
COPY --from=frontend-builder /app/frontend/out ./frontend/out
|
| 69 |
|
| 70 |
# Create data directories with proper permissions
|
| 71 |
-
RUN mkdir -p /app/data/uploads /app/data/chroma_db /app/data/huggingface && \
|
| 72 |
chown -R appuser:appuser /app
|
| 73 |
|
| 74 |
# Copy entrypoint
|
|
|
|
| 33 |
|
| 34 |
COPY backend/requirements.txt ./requirements.txt
|
| 35 |
RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
|
| 36 |
+
pip install --no-cache-dir -r requirements.txt && \
|
| 37 |
+
python -m spacy download en_core_web_sm
|
| 38 |
|
| 39 |
# --------------------------------------------------------
|
| 40 |
# Stage 3: Runtime image with only app code and artifacts
|
|
|
|
| 69 |
COPY --from=frontend-builder /app/frontend/out ./frontend/out
|
| 70 |
|
| 71 |
# Create data directories with proper permissions
|
| 72 |
+
RUN mkdir -p /app/data/uploads /app/data/chroma_db /app/data/graphs /app/data/huggingface && \
|
| 73 |
chown -R appuser:appuser /app
|
| 74 |
|
| 75 |
# Copy entrypoint
|
backend/app/config.py
CHANGED
|
@@ -45,6 +45,22 @@ class Settings(BaseSettings):
|
|
| 45 |
TOP_K_RETRIEVAL: int = 10
|
| 46 |
TOP_K_RERANK: int = 5
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# ── Embeddings (local HuggingFace model) ─────────────
|
| 49 |
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
|
| 50 |
EMBEDDING_DIMENSION: int = 384
|
|
|
|
| 45 |
TOP_K_RETRIEVAL: int = 10
|
| 46 |
TOP_K_RERANK: int = 5
|
| 47 |
|
| 48 |
+
# ── Knowledge Graph (GraphRAG) ───────────────────────
|
| 49 |
+
GRAPH_PERSIST_DIR: str = "./data/graphs"
|
| 50 |
+
GRAPH_ENTITY_LABELS: set = {
|
| 51 |
+
"PERSON",
|
| 52 |
+
"ORG",
|
| 53 |
+
"GPE",
|
| 54 |
+
"LOC",
|
| 55 |
+
"PRODUCT",
|
| 56 |
+
"EVENT",
|
| 57 |
+
"WORK_OF_ART",
|
| 58 |
+
"LAW",
|
| 59 |
+
"NORP",
|
| 60 |
+
"FAC",
|
| 61 |
+
}
|
| 62 |
+
GRAPH_MAX_RELATIONSHIPS: int = 12
|
| 63 |
+
|
| 64 |
# ── Embeddings (local HuggingFace model) ─────────────
|
| 65 |
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
|
| 66 |
EMBEDDING_DIMENSION: int = 384
|
backend/app/rag/agent.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import List, Dict, Any, Optional, Generator
|
|
| 9 |
from huggingface_hub import InferenceClient
|
| 10 |
from app.config import get_settings
|
| 11 |
from app.rag.retriever import retrieve
|
|
|
|
| 12 |
from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
|
| 13 |
from app.rag.tracing import trace_function
|
| 14 |
|
|
@@ -48,6 +49,26 @@ def build_context(chunks: List[Dict[str, Any]]) -> str:
|
|
| 48 |
return "\n\n---\n\n".join(context_parts)
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def _chat_messages(system: str, user_content: str) -> list:
|
| 52 |
"""Build messages list for chat completion API."""
|
| 53 |
return [
|
|
@@ -108,7 +129,12 @@ def generate_answer(
|
|
| 108 |
|
| 109 |
# ── Build prompt ─────────────────────────────────
|
| 110 |
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 111 |
-
context =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 113 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 114 |
|
|
@@ -222,7 +248,12 @@ def generate_answer_stream(
|
|
| 222 |
|
| 223 |
# ── Build prompt ─────────────────────────────────
|
| 224 |
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 225 |
-
context =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 227 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 228 |
|
|
|
|
| 9 |
from huggingface_hub import InferenceClient
|
| 10 |
from app.config import get_settings
|
| 11 |
from app.rag.retriever import retrieve
|
| 12 |
+
from app.rag.graph_retriever import get_entity_context
|
| 13 |
from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
|
| 14 |
from app.rag.tracing import trace_function
|
| 15 |
|
|
|
|
| 49 |
return "\n\n---\n\n".join(context_parts)
|
| 50 |
|
| 51 |
|
| 52 |
+
def build_augmented_context(
|
| 53 |
+
chunks: List[Dict[str, Any]],
|
| 54 |
+
question: str,
|
| 55 |
+
user_id: str,
|
| 56 |
+
document_id: Optional[str] = None,
|
| 57 |
+
) -> str:
|
| 58 |
+
"""Combine vector-retrieved excerpts with GraphRAG relationships."""
|
| 59 |
+
context = build_context(chunks)
|
| 60 |
+
graph_context = get_entity_context(
|
| 61 |
+
query=question,
|
| 62 |
+
user_id=user_id,
|
| 63 |
+
document_id=document_id,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
if not graph_context:
|
| 67 |
+
return context
|
| 68 |
+
|
| 69 |
+
return f"{context}\n\n---\n\n{graph_context}"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
def _chat_messages(system: str, user_content: str) -> list:
|
| 73 |
"""Build messages list for chat completion API."""
|
| 74 |
return [
|
|
|
|
| 129 |
|
| 130 |
# ── Build prompt ─────────────────────────────────
|
| 131 |
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 132 |
+
context = build_augmented_context(
|
| 133 |
+
chunks=chunks,
|
| 134 |
+
question=question,
|
| 135 |
+
user_id=user_id,
|
| 136 |
+
document_id=document_id,
|
| 137 |
+
)
|
| 138 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 139 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 140 |
|
|
|
|
| 248 |
|
| 249 |
# ── Build prompt ─────────────────────────────────
|
| 250 |
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 251 |
+
context = build_augmented_context(
|
| 252 |
+
chunks=chunks,
|
| 253 |
+
question=question,
|
| 254 |
+
user_id=user_id,
|
| 255 |
+
document_id=document_id,
|
| 256 |
+
)
|
| 257 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 258 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 259 |
|
backend/app/rag/graph_builder.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge graph construction and persistence for GraphRAG.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 10 |
+
|
| 11 |
+
import networkx as nx
|
| 12 |
+
|
| 13 |
+
from app.config import get_settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
settings = get_settings()
|
| 17 |
+
|
| 18 |
+
_nlp = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(frozen=True)
|
| 22 |
+
class Entity:
|
| 23 |
+
id: str
|
| 24 |
+
text: str
|
| 25 |
+
label: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _safe_id(value: str) -> str:
|
| 29 |
+
safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("._")
|
| 30 |
+
return safe or "unknown"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_graph_path(user_id: str, document_id: str) -> Path:
|
| 34 |
+
"""Return the on-disk graph path for one user/document pair."""
|
| 35 |
+
filename = f"{_safe_id(user_id)}_{_safe_id(document_id)}.json"
|
| 36 |
+
return Path(settings.GRAPH_PERSIST_DIR) / filename
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def iter_graph_paths(user_id: str) -> Iterable[Path]:
|
| 40 |
+
"""Yield every persisted graph path for a user."""
|
| 41 |
+
graph_dir = Path(settings.GRAPH_PERSIST_DIR)
|
| 42 |
+
if not graph_dir.exists():
|
| 43 |
+
return []
|
| 44 |
+
|
| 45 |
+
prefix = f"{_safe_id(user_id)}_"
|
| 46 |
+
return sorted(graph_dir.glob(f"{prefix}*.json"))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _get_nlp():
|
| 50 |
+
"""Load the spaCy English NER model lazily."""
|
| 51 |
+
global _nlp
|
| 52 |
+
if _nlp is None:
|
| 53 |
+
import spacy
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
_nlp = spacy.load("en_core_web_sm")
|
| 57 |
+
except OSError as exc:
|
| 58 |
+
raise RuntimeError(
|
| 59 |
+
"spaCy model 'en_core_web_sm' is required for GraphRAG entity extraction. "
|
| 60 |
+
"Install it with: python -m spacy download en_core_web_sm"
|
| 61 |
+
) from exc
|
| 62 |
+
return _nlp
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _entity_id(text: str, label: str) -> str:
|
| 66 |
+
normalized = " ".join(text.split()).casefold()
|
| 67 |
+
return f"{label}:{normalized}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def extract_entities(text: str) -> List[Entity]:
|
| 71 |
+
"""Extract configured named entities from text."""
|
| 72 |
+
if not text or not text.strip():
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
doc = _get_nlp()(text)
|
| 76 |
+
entities: Dict[str, Entity] = {}
|
| 77 |
+
|
| 78 |
+
for ent in doc.ents:
|
| 79 |
+
value = " ".join(ent.text.split()).strip()
|
| 80 |
+
if not value or ent.label_ not in settings.GRAPH_ENTITY_LABELS:
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
entity_id = _entity_id(value, ent.label_)
|
| 84 |
+
entities.setdefault(
|
| 85 |
+
entity_id,
|
| 86 |
+
Entity(id=entity_id, text=value, label=ent.label_),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return list(entities.values())
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def build_graph(chunks: List[Dict[str, Any]]) -> nx.Graph:
|
| 93 |
+
"""Build an entity co-occurrence graph from document chunks."""
|
| 94 |
+
graph = nx.Graph()
|
| 95 |
+
|
| 96 |
+
for chunk in chunks:
|
| 97 |
+
text = chunk.get("text", "")
|
| 98 |
+
page = chunk.get("page")
|
| 99 |
+
chunk_index = chunk.get("chunk_index")
|
| 100 |
+
entities = extract_entities(text)
|
| 101 |
+
|
| 102 |
+
for entity in entities:
|
| 103 |
+
if graph.has_node(entity.id):
|
| 104 |
+
graph.nodes[entity.id]["mentions"] += 1
|
| 105 |
+
graph.nodes[entity.id]["pages"].add(page)
|
| 106 |
+
graph.nodes[entity.id]["chunks"].add(chunk_index)
|
| 107 |
+
else:
|
| 108 |
+
graph.add_node(
|
| 109 |
+
entity.id,
|
| 110 |
+
name=entity.text,
|
| 111 |
+
label=entity.label,
|
| 112 |
+
mentions=1,
|
| 113 |
+
pages={page},
|
| 114 |
+
chunks={chunk_index},
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
for left_index, left in enumerate(entities):
|
| 118 |
+
for right in entities[left_index + 1:]:
|
| 119 |
+
if graph.has_edge(left.id, right.id):
|
| 120 |
+
graph[left.id][right.id]["weight"] += 1
|
| 121 |
+
graph[left.id][right.id]["pages"].add(page)
|
| 122 |
+
graph[left.id][right.id]["chunks"].add(chunk_index)
|
| 123 |
+
else:
|
| 124 |
+
graph.add_edge(
|
| 125 |
+
left.id,
|
| 126 |
+
right.id,
|
| 127 |
+
weight=1,
|
| 128 |
+
pages={page},
|
| 129 |
+
chunks={chunk_index},
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
_convert_sets_for_json(graph)
|
| 133 |
+
return graph
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _convert_sets_for_json(graph: nx.Graph) -> None:
|
| 137 |
+
for _, data in graph.nodes(data=True):
|
| 138 |
+
data["pages"] = sorted(item for item in data.get("pages", []) if item is not None)
|
| 139 |
+
data["chunks"] = sorted(item for item in data.get("chunks", []) if item is not None)
|
| 140 |
+
|
| 141 |
+
for _, _, data in graph.edges(data=True):
|
| 142 |
+
data["pages"] = sorted(item for item in data.get("pages", []) if item is not None)
|
| 143 |
+
data["chunks"] = sorted(item for item in data.get("chunks", []) if item is not None)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def save_graph(graph: nx.Graph, user_id: str, document_id: str) -> Path:
|
| 147 |
+
"""Persist a graph to disk as node-link JSON."""
|
| 148 |
+
graph_path = get_graph_path(user_id, document_id)
|
| 149 |
+
graph_path.parent.mkdir(parents=True, exist_ok=True)
|
| 150 |
+
|
| 151 |
+
data = nx.node_link_data(graph)
|
| 152 |
+
data["metadata"] = {
|
| 153 |
+
"user_id": user_id,
|
| 154 |
+
"document_id": document_id,
|
| 155 |
+
"node_count": graph.number_of_nodes(),
|
| 156 |
+
"edge_count": graph.number_of_edges(),
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
graph_path.write_text(json.dumps(data, ensure_ascii=True, indent=2), encoding="utf-8")
|
| 160 |
+
logger.info(
|
| 161 |
+
"Saved knowledge graph for document %s with %s nodes and %s edges",
|
| 162 |
+
document_id,
|
| 163 |
+
graph.number_of_nodes(),
|
| 164 |
+
graph.number_of_edges(),
|
| 165 |
+
)
|
| 166 |
+
return graph_path
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def load_graph(user_id: str, document_id: str) -> Optional[nx.Graph]:
|
| 170 |
+
"""Load a persisted graph for one user/document pair."""
|
| 171 |
+
return load_graph_path(get_graph_path(user_id, document_id))
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def load_graph_path(graph_path: Path) -> Optional[nx.Graph]:
|
| 175 |
+
"""Load a graph from a concrete JSON path."""
|
| 176 |
+
if not graph_path.exists():
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
data = json.loads(graph_path.read_text(encoding="utf-8"))
|
| 180 |
+
return nx.node_link_graph(data)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def delete_graph(user_id: str, document_id: str) -> None:
|
| 184 |
+
"""Delete a persisted graph file if it exists."""
|
| 185 |
+
get_graph_path(user_id, document_id).unlink(missing_ok=True)
|
backend/app/rag/graph_retriever.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge graph retrieval for augmenting RAG context.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
| 6 |
+
|
| 7 |
+
import networkx as nx
|
| 8 |
+
|
| 9 |
+
from app.config import get_settings
|
| 10 |
+
from app.rag.graph_builder import (
|
| 11 |
+
extract_entities,
|
| 12 |
+
iter_graph_paths,
|
| 13 |
+
load_graph,
|
| 14 |
+
load_graph_path,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
settings = get_settings()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _candidate_graphs(user_id: str, document_id: Optional[str]) -> Iterable[nx.Graph]:
|
| 22 |
+
if document_id:
|
| 23 |
+
graph = load_graph(user_id, document_id)
|
| 24 |
+
return [graph] if graph is not None else []
|
| 25 |
+
|
| 26 |
+
graphs = []
|
| 27 |
+
for path in iter_graph_paths(user_id):
|
| 28 |
+
graph = load_graph_path(path)
|
| 29 |
+
if graph is not None:
|
| 30 |
+
graphs.append(graph)
|
| 31 |
+
return graphs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _node_name(graph: nx.Graph, node_id: str) -> str:
|
| 35 |
+
return graph.nodes[node_id].get("name", node_id.split(":", 1)[-1])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _match_query_nodes(graph: nx.Graph, query: str) -> Set[str]:
|
| 39 |
+
query_entities = extract_entities(query)
|
| 40 |
+
matched = {entity.id for entity in query_entities if graph.has_node(entity.id)}
|
| 41 |
+
|
| 42 |
+
if matched:
|
| 43 |
+
return matched
|
| 44 |
+
|
| 45 |
+
query_text = query.casefold()
|
| 46 |
+
for node_id, data in graph.nodes(data=True):
|
| 47 |
+
name = data.get("name", "").casefold()
|
| 48 |
+
if name and name in query_text:
|
| 49 |
+
matched.add(node_id)
|
| 50 |
+
|
| 51 |
+
return matched
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _format_pages(pages: List[int]) -> str:
|
| 55 |
+
if not pages:
|
| 56 |
+
return "unknown pages"
|
| 57 |
+
if len(pages) == 1:
|
| 58 |
+
return f"page {pages[0]}"
|
| 59 |
+
return "pages " + ", ".join(str(page) for page in pages[:4])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _relationship_key(left: str, right: str) -> Tuple[str, str]:
|
| 63 |
+
return tuple(sorted((left, right)))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_entity_context(
|
| 67 |
+
query: str,
|
| 68 |
+
user_id: str,
|
| 69 |
+
document_id: Optional[str] = None,
|
| 70 |
+
) -> str:
|
| 71 |
+
"""Return compact graph relationship context relevant to the query."""
|
| 72 |
+
relationships: Dict[Tuple[str, str], Dict[str, object]] = {}
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
graphs = _candidate_graphs(user_id=user_id, document_id=document_id)
|
| 76 |
+
for graph in graphs:
|
| 77 |
+
matched_nodes = _match_query_nodes(graph, query)
|
| 78 |
+
|
| 79 |
+
for node_id in matched_nodes:
|
| 80 |
+
neighbors = sorted(
|
| 81 |
+
graph.neighbors(node_id),
|
| 82 |
+
key=lambda neighbor: graph[node_id][neighbor].get("weight", 0),
|
| 83 |
+
reverse=True,
|
| 84 |
+
)
|
| 85 |
+
for neighbor_id in neighbors:
|
| 86 |
+
edge = graph[node_id][neighbor_id]
|
| 87 |
+
left = _node_name(graph, node_id)
|
| 88 |
+
right = _node_name(graph, neighbor_id)
|
| 89 |
+
key = _relationship_key(left.casefold(), right.casefold())
|
| 90 |
+
existing = relationships.setdefault(
|
| 91 |
+
key,
|
| 92 |
+
{
|
| 93 |
+
"left": left,
|
| 94 |
+
"right": right,
|
| 95 |
+
"weight": 0,
|
| 96 |
+
"pages": set(),
|
| 97 |
+
},
|
| 98 |
+
)
|
| 99 |
+
existing["weight"] = int(existing["weight"]) + int(edge.get("weight", 1))
|
| 100 |
+
existing["pages"].update(edge.get("pages", []))
|
| 101 |
+
except Exception as exc:
|
| 102 |
+
logger.warning("GraphRAG context retrieval failed: %s", exc)
|
| 103 |
+
return ""
|
| 104 |
+
|
| 105 |
+
if not relationships:
|
| 106 |
+
return ""
|
| 107 |
+
|
| 108 |
+
ranked = sorted(
|
| 109 |
+
relationships.values(),
|
| 110 |
+
key=lambda item: int(item["weight"]),
|
| 111 |
+
reverse=True,
|
| 112 |
+
)[: settings.GRAPH_MAX_RELATIONSHIPS]
|
| 113 |
+
|
| 114 |
+
lines = ["## Knowledge Graph Context"]
|
| 115 |
+
for item in ranked:
|
| 116 |
+
pages = sorted(item["pages"])
|
| 117 |
+
lines.append(
|
| 118 |
+
f"- {item['left']} is related to {item['right']} "
|
| 119 |
+
f"through document co-occurrence on {_format_pages(pages)} "
|
| 120 |
+
f"(strength: {item['weight']})."
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return "\n".join(lines)
|
backend/app/routes/documents.py
CHANGED
|
@@ -172,6 +172,15 @@ def _ingest_document(document_id: str, filepath: str, original_name: str, user_i
|
|
| 172 |
db.commit()
|
| 173 |
return
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# Store embeddings in ChromaDB
|
| 176 |
chunk_count = store_chunks(
|
| 177 |
chunks=chunks,
|
|
@@ -629,6 +638,14 @@ def delete_document(
|
|
| 629 |
except Exception as e:
|
| 630 |
logger.warning(f"Error deleting vectors: {e}")
|
| 631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
# Delete from database (cascades to chat messages)
|
| 633 |
db.delete(doc)
|
| 634 |
db.commit()
|
|
|
|
| 172 |
db.commit()
|
| 173 |
return
|
| 174 |
|
| 175 |
+
# Build and persist a lightweight entity co-occurrence graph for GraphRAG.
|
| 176 |
+
try:
|
| 177 |
+
from app.rag.graph_builder import build_graph, save_graph
|
| 178 |
+
|
| 179 |
+
graph = build_graph(chunks)
|
| 180 |
+
save_graph(graph, user_id=user_id, document_id=document_id)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
logger.warning(f"Could not build knowledge graph for document {document_id}: {e}")
|
| 183 |
+
|
| 184 |
# Store embeddings in ChromaDB
|
| 185 |
chunk_count = store_chunks(
|
| 186 |
chunks=chunks,
|
|
|
|
| 638 |
except Exception as e:
|
| 639 |
logger.warning(f"Error deleting vectors: {e}")
|
| 640 |
|
| 641 |
+
# Delete persisted knowledge graph
|
| 642 |
+
try:
|
| 643 |
+
from app.rag.graph_builder import delete_graph
|
| 644 |
+
|
| 645 |
+
delete_graph(user_id=user.id, document_id=document_id)
|
| 646 |
+
except Exception as e:
|
| 647 |
+
logger.warning(f"Error deleting knowledge graph: {e}")
|
| 648 |
+
|
| 649 |
# Delete from database (cascades to chat messages)
|
| 650 |
db.delete(doc)
|
| 651 |
db.commit()
|
backend/requirements.txt
CHANGED
|
@@ -41,6 +41,9 @@ transformers
|
|
| 41 |
|
| 42 |
# Vector Database
|
| 43 |
chromadb
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# LLM Inference
|
| 46 |
huggingface-hub
|
|
|
|
| 41 |
|
| 42 |
# Vector Database
|
| 43 |
chromadb
|
| 44 |
+
networkx>=3.3
|
| 45 |
+
spacy>=3.7
|
| 46 |
+
neo4j>=5.0
|
| 47 |
|
| 48 |
# LLM Inference
|
| 49 |
huggingface-hub
|
backend/tests/test_documents.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def test_api_health(client):
|
| 2 |
response = client.get("/api/health")
|
| 3 |
|
|
@@ -32,3 +38,76 @@ def test_upload_rejects_unsupported_extension_before_deep_validation(client, aut
|
|
| 32 |
|
| 33 |
assert response.status_code == 400
|
| 34 |
assert "not supported" in response.json()["detail"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
|
| 3 |
+
from app.models import Document
|
| 4 |
+
from app.routes.documents import _ingest_document
|
| 5 |
+
|
| 6 |
+
|
| 7 |
def test_api_health(client):
|
| 8 |
response = client.get("/api/health")
|
| 9 |
|
|
|
|
| 38 |
|
| 39 |
assert response.status_code == 400
|
| 40 |
assert "not supported" in response.json()["detail"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_ingest_document_builds_and_saves_graph(db_session, monkeypatch, tmp_path, user):
|
| 44 |
+
document = Document(
|
| 45 |
+
user_id=user.id,
|
| 46 |
+
filename="graph.txt",
|
| 47 |
+
original_name="graph.txt",
|
| 48 |
+
file_size=128,
|
| 49 |
+
status="pending",
|
| 50 |
+
)
|
| 51 |
+
db_session.add(document)
|
| 52 |
+
db_session.commit()
|
| 53 |
+
db_session.refresh(document)
|
| 54 |
+
user_id = user.id
|
| 55 |
+
document_id = document.id
|
| 56 |
+
chunks = [{"text": "OpenAI works with Microsoft.", "page": 1, "chunk_index": 0}]
|
| 57 |
+
saved = {}
|
| 58 |
+
|
| 59 |
+
monkeypatch.setattr("app.routes.documents.get_page_count", lambda filepath: 1)
|
| 60 |
+
monkeypatch.setattr("app.routes.documents.chunk_document", lambda filepath: chunks)
|
| 61 |
+
monkeypatch.setattr("app.routes.documents.store_chunks", lambda **kwargs: len(chunks))
|
| 62 |
+
monkeypatch.setattr("app.database.SessionLocal", lambda: db_session)
|
| 63 |
+
|
| 64 |
+
fake_summary = types.ModuleType("app.rag.summarizer")
|
| 65 |
+
fake_summary.generate_document_summary = lambda filepath, max_sentences=2: "Summary"
|
| 66 |
+
monkeypatch.setitem(__import__("sys").modules, "app.rag.summarizer", fake_summary)
|
| 67 |
+
|
| 68 |
+
monkeypatch.setattr(
|
| 69 |
+
"app.rag.graph_builder.build_graph",
|
| 70 |
+
lambda received_chunks: {"chunks": received_chunks},
|
| 71 |
+
)
|
| 72 |
+
monkeypatch.setattr(
|
| 73 |
+
"app.rag.graph_builder.save_graph",
|
| 74 |
+
lambda graph, user_id, document_id: saved.update(
|
| 75 |
+
{"graph": graph, "user_id": user_id, "document_id": document_id}
|
| 76 |
+
),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
_ingest_document(
|
| 80 |
+
document_id=document_id,
|
| 81 |
+
filepath=str(tmp_path / "graph.txt"),
|
| 82 |
+
original_name=document.original_name,
|
| 83 |
+
user_id=user_id,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
assert saved == {
|
| 87 |
+
"graph": {"chunks": chunks},
|
| 88 |
+
"user_id": user_id,
|
| 89 |
+
"document_id": document_id,
|
| 90 |
+
}
|
| 91 |
+
refreshed = db_session.get(Document, document_id)
|
| 92 |
+
assert refreshed.status == "ready"
|
| 93 |
+
assert refreshed.chunk_count == 1
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_delete_document_removes_knowledge_graph(client, auth_headers, ready_document, monkeypatch):
|
| 97 |
+
deleted = {}
|
| 98 |
+
|
| 99 |
+
monkeypatch.setattr("app.routes.documents.delete_document_chunks", lambda **kwargs: None)
|
| 100 |
+
monkeypatch.setattr(
|
| 101 |
+
"app.rag.graph_builder.delete_graph",
|
| 102 |
+
lambda user_id, document_id: deleted.update(
|
| 103 |
+
{"user_id": user_id, "document_id": document_id}
|
| 104 |
+
),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
response = client.delete(
|
| 108 |
+
f"/api/v1/documents/{ready_document.id}",
|
| 109 |
+
headers=auth_headers,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
assert response.status_code == 200
|
| 113 |
+
assert deleted["document_id"] == ready_document.id
|
backend/tests/test_graph_builder.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from app.rag import graph_builder
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FakeEntity:
|
| 7 |
+
def __init__(self, text, label):
|
| 8 |
+
self.text = text
|
| 9 |
+
self.label_ = label
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FakeDoc:
|
| 13 |
+
def __init__(self, entities):
|
| 14 |
+
self.ents = entities
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FakeNlp:
|
| 18 |
+
def __call__(self, text):
|
| 19 |
+
entities = []
|
| 20 |
+
for value, label in (
|
| 21 |
+
("OpenAI", "ORG"),
|
| 22 |
+
("Microsoft", "ORG"),
|
| 23 |
+
("Azure", "PRODUCT"),
|
| 24 |
+
("Ignored Date", "DATE"),
|
| 25 |
+
):
|
| 26 |
+
if value in text:
|
| 27 |
+
entities.append(FakeEntity(value, label))
|
| 28 |
+
return FakeDoc(entities)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_extract_entities_filters_configured_labels(monkeypatch):
|
| 32 |
+
monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
|
| 33 |
+
|
| 34 |
+
entities = graph_builder.extract_entities("OpenAI works with Microsoft on Ignored Date")
|
| 35 |
+
|
| 36 |
+
assert {entity.text for entity in entities} == {"OpenAI", "Microsoft"}
|
| 37 |
+
assert {entity.label for entity in entities} == {"ORG"}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_build_graph_tracks_entity_edges_and_weights(monkeypatch):
|
| 41 |
+
monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
|
| 42 |
+
chunks = [
|
| 43 |
+
{
|
| 44 |
+
"text": "OpenAI works with Microsoft.",
|
| 45 |
+
"page": 1,
|
| 46 |
+
"chunk_index": 0,
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"text": "OpenAI and Microsoft use Azure.",
|
| 50 |
+
"page": 2,
|
| 51 |
+
"chunk_index": 1,
|
| 52 |
+
},
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
graph = graph_builder.build_graph(chunks)
|
| 56 |
+
|
| 57 |
+
openai_id = "ORG:openai"
|
| 58 |
+
microsoft_id = "ORG:microsoft"
|
| 59 |
+
azure_id = "PRODUCT:azure"
|
| 60 |
+
assert graph.nodes[openai_id]["name"] == "OpenAI"
|
| 61 |
+
assert graph.nodes[openai_id]["pages"] == [1, 2]
|
| 62 |
+
assert graph[openai_id][microsoft_id]["weight"] == 2
|
| 63 |
+
assert graph[openai_id][microsoft_id]["pages"] == [1, 2]
|
| 64 |
+
assert graph.has_edge(microsoft_id, azure_id)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_save_load_and_delete_graph_roundtrip(tmp_path, monkeypatch):
|
| 68 |
+
monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
|
| 69 |
+
graph = graph_builder.build_graph([])
|
| 70 |
+
graph.add_node("ORG:openai", name="OpenAI", label="ORG", mentions=1, pages=[1], chunks=[0])
|
| 71 |
+
|
| 72 |
+
path = graph_builder.save_graph(graph, user_id="user-1", document_id="doc-1")
|
| 73 |
+
payload = json.loads(path.read_text(encoding="utf-8"))
|
| 74 |
+
loaded = graph_builder.load_graph(user_id="user-1", document_id="doc-1")
|
| 75 |
+
|
| 76 |
+
assert payload["metadata"]["document_id"] == "doc-1"
|
| 77 |
+
assert loaded.nodes["ORG:openai"]["name"] == "OpenAI"
|
| 78 |
+
|
| 79 |
+
graph_builder.delete_graph(user_id="user-1", document_id="doc-1")
|
| 80 |
+
assert not path.exists()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def test_empty_chunks_produce_empty_graph(monkeypatch):
|
| 84 |
+
monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
|
| 85 |
+
|
| 86 |
+
graph = graph_builder.build_graph([])
|
| 87 |
+
|
| 88 |
+
assert graph.number_of_nodes() == 0
|
| 89 |
+
assert graph.number_of_edges() == 0
|
backend/tests/test_graph_retriever.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.rag import graph_builder, graph_retriever
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FakeEntity:
|
| 5 |
+
def __init__(self, text, label):
|
| 6 |
+
self.text = text
|
| 7 |
+
self.label_ = label
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FakeDoc:
|
| 11 |
+
def __init__(self, entities):
|
| 12 |
+
self.ents = entities
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FakeNlp:
|
| 16 |
+
def __call__(self, text):
|
| 17 |
+
entities = []
|
| 18 |
+
for value, label in (
|
| 19 |
+
("OpenAI", "ORG"),
|
| 20 |
+
("Microsoft", "ORG"),
|
| 21 |
+
("Azure", "PRODUCT"),
|
| 22 |
+
):
|
| 23 |
+
if value in text:
|
| 24 |
+
entities.append(FakeEntity(value, label))
|
| 25 |
+
return FakeDoc(entities)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _save_sample_graph(tmp_path, monkeypatch, user_id="user-1", document_id="doc-1"):
|
| 29 |
+
monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
|
| 30 |
+
monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
|
| 31 |
+
graph = graph_builder.build_graph(
|
| 32 |
+
[
|
| 33 |
+
{
|
| 34 |
+
"text": "OpenAI works with Microsoft.",
|
| 35 |
+
"page": 1,
|
| 36 |
+
"chunk_index": 0,
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"text": "Microsoft deploys Azure.",
|
| 40 |
+
"page": 2,
|
| 41 |
+
"chunk_index": 1,
|
| 42 |
+
},
|
| 43 |
+
]
|
| 44 |
+
)
|
| 45 |
+
graph_builder.save_graph(graph, user_id=user_id, document_id=document_id)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_get_entity_context_returns_one_hop_relationships(tmp_path, monkeypatch):
|
| 49 |
+
_save_sample_graph(tmp_path, monkeypatch)
|
| 50 |
+
|
| 51 |
+
context = graph_retriever.get_entity_context(
|
| 52 |
+
query="How is OpenAI related to Microsoft?",
|
| 53 |
+
user_id="user-1",
|
| 54 |
+
document_id="doc-1",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
assert "## Knowledge Graph Context" in context
|
| 58 |
+
assert "OpenAI" in context
|
| 59 |
+
assert "Microsoft" in context
|
| 60 |
+
assert "page 1" in context
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_get_entity_context_returns_empty_for_no_match(tmp_path, monkeypatch):
|
| 64 |
+
_save_sample_graph(tmp_path, monkeypatch)
|
| 65 |
+
|
| 66 |
+
context = graph_retriever.get_entity_context(
|
| 67 |
+
query="What about Google?",
|
| 68 |
+
user_id="user-1",
|
| 69 |
+
document_id="doc-1",
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
assert context == ""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_get_entity_context_returns_empty_for_missing_graph(tmp_path, monkeypatch):
|
| 76 |
+
monkeypatch.setattr(graph_builder.settings, "GRAPH_PERSIST_DIR", str(tmp_path))
|
| 77 |
+
monkeypatch.setattr(graph_builder, "_nlp", FakeNlp())
|
| 78 |
+
|
| 79 |
+
context = graph_retriever.get_entity_context(
|
| 80 |
+
query="OpenAI",
|
| 81 |
+
user_id="user-1",
|
| 82 |
+
document_id="missing",
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
assert context == ""
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def test_get_entity_context_isolates_users(tmp_path, monkeypatch):
|
| 89 |
+
_save_sample_graph(tmp_path, monkeypatch, user_id="user-1", document_id="doc-1")
|
| 90 |
+
|
| 91 |
+
context = graph_retriever.get_entity_context(
|
| 92 |
+
query="OpenAI",
|
| 93 |
+
user_id="user-2",
|
| 94 |
+
document_id="doc-1",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
assert context == ""
|
backend/tests/test_graphrag_agent.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.rag import agent
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FakeMessage:
|
| 5 |
+
content = "Graph answer"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FakeChoice:
|
| 9 |
+
message = FakeMessage()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FakeResponse:
|
| 13 |
+
choices = [FakeChoice()]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FakeClient:
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.messages = None
|
| 19 |
+
|
| 20 |
+
def chat_completion(self, messages, **kwargs):
|
| 21 |
+
self.messages = messages
|
| 22 |
+
return FakeResponse()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_generate_answer_appends_graph_context_without_changing_sources(monkeypatch):
|
| 26 |
+
client = FakeClient()
|
| 27 |
+
chunks = [
|
| 28 |
+
{
|
| 29 |
+
"text": "Vector context",
|
| 30 |
+
"filename": "doc.pdf",
|
| 31 |
+
"page": 1,
|
| 32 |
+
"score": 0.9,
|
| 33 |
+
"confidence": 100.0,
|
| 34 |
+
}
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
monkeypatch.setattr(agent, "get_llm_client", lambda: client)
|
| 38 |
+
monkeypatch.setattr(agent, "retrieve", lambda **kwargs: chunks)
|
| 39 |
+
monkeypatch.setattr(
|
| 40 |
+
agent,
|
| 41 |
+
"get_entity_context",
|
| 42 |
+
lambda **kwargs: "## Knowledge Graph Context\n- OpenAI is related to Microsoft on page 1.",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
result = agent.generate_answer("How are OpenAI and Microsoft related?", "user-1", "doc-1")
|
| 46 |
+
|
| 47 |
+
prompt = client.messages[1]["content"]
|
| 48 |
+
assert "Vector context" in prompt
|
| 49 |
+
assert "Knowledge Graph Context" in prompt
|
| 50 |
+
assert result["sources"] == [
|
| 51 |
+
{
|
| 52 |
+
"text": "Vector context",
|
| 53 |
+
"filename": "doc.pdf",
|
| 54 |
+
"page": 1,
|
| 55 |
+
"score": 0.9,
|
| 56 |
+
"confidence": 100.0,
|
| 57 |
+
}
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_generate_answer_stream_appends_graph_context(monkeypatch):
|
| 62 |
+
captured = {}
|
| 63 |
+
|
| 64 |
+
class StreamingClient:
|
| 65 |
+
def chat_completion(self, messages, **kwargs):
|
| 66 |
+
captured["messages"] = messages
|
| 67 |
+
return iter([])
|
| 68 |
+
|
| 69 |
+
monkeypatch.setattr(agent, "get_llm_client", lambda: StreamingClient())
|
| 70 |
+
monkeypatch.setattr(
|
| 71 |
+
agent,
|
| 72 |
+
"retrieve",
|
| 73 |
+
lambda **kwargs: [
|
| 74 |
+
{
|
| 75 |
+
"text": "Vector stream context",
|
| 76 |
+
"filename": "doc.pdf",
|
| 77 |
+
"page": 1,
|
| 78 |
+
"score": 0.9,
|
| 79 |
+
"confidence": 100.0,
|
| 80 |
+
}
|
| 81 |
+
],
|
| 82 |
+
)
|
| 83 |
+
monkeypatch.setattr(
|
| 84 |
+
agent,
|
| 85 |
+
"get_entity_context",
|
| 86 |
+
lambda **kwargs: "## Knowledge Graph Context\n- OpenAI is related to Microsoft on page 1.",
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
events = list(agent.generate_answer_stream("OpenAI Microsoft", "user-1", "doc-1"))
|
| 90 |
+
|
| 91 |
+
assert events[0].startswith("data:")
|
| 92 |
+
assert "Knowledge Graph Context" in captured["messages"][1]["content"]
|