GraphResearcher / scripts /fix_graph_guided_retrieval_quality.py
yugbirla's picture
Sync GraphRAG fusion quality cleanup and evaluation files
b7d0804
Raw
History Blame Contribute Delete
7.58 kB
# Maintenance script: overwrites app/graph/graph_guided_retriever.py with the
# module source held in the raw triple-quoted string below, then prints a
# one-line summary. Path.write_text replaces any existing file content.
# NOTE(review): the target path is relative, so it resolves against the current
# working directory -- presumably the repository root; confirm before running.
from pathlib import Path
path = Path("app/graph/graph_guided_retriever.py")
path.write_text(r'''
from typing import Dict, Any, List, Optional
import re
from app.graph.graph_context_service import build_graph_context_for_query
from app.storage.processed_storage import read_processed_chunks
from app.graph.graph_quality import (
is_low_quality_chunk_text,
is_meta_showcase_chunk_text
)
def get_value(obj, key: str, default=None):
    """Fetch ``key`` from ``obj``, which may be a dict or any other object.

    Dicts are read with ``dict.get``; every other object is read with
    ``getattr``. In both cases ``default`` is returned when the key or
    attribute is missing.
    """
    if not isinstance(obj, dict):
        return getattr(obj, key, default)
    return obj.get(key, default)
def normalize_chunk_id(value) -> str:
    """Convert a chunk identifier into the string key used for lookups.

    ``None`` maps to the empty string; every other value goes through ``str``
    (so ``0`` becomes ``"0"``).
    """
    return "" if value is None else str(value)
def build_chunk_lookup(chunks: List[Any]) -> Dict[str, Any]:
    """Index processed chunks by their normalized chunk id.

    The id is read from ``chunk_id`` and, failing that, ``id``. A chunk with
    neither (or only ``None``/empty values) is keyed ``chunk_<index>`` by its
    position in ``chunks``. When two chunks share a key, the later one wins.
    """
    lookup: Dict[str, Any] = {}
    for index, chunk in enumerate(chunks):
        key = ""
        for attr_name in ("chunk_id", "id"):
            # Test the normalized string instead of chaining with ``or`` so a
            # legitimate integer id of 0 is kept rather than treated as
            # missing; graph chunk ids are normalized the same way (0 -> "0").
            key = normalize_chunk_id(get_value(chunk, attr_name))
            if key:
                break
        lookup[key or f"chunk_{index}"] = chunk
    return lookup
def extract_text_preview(chunk, max_chars: int = 700) -> str:
    """Return a single-line, length-capped preview of a chunk's text.

    The text comes from ``content``, falling back to ``text`` (empty string if
    neither is set). Literal backslash-n escape sequences and real whitespace
    runs (newlines, tabs, repeated spaces) are collapsed to single spaces and
    the ends are trimmed. Text longer than ``max_chars`` is cut to
    ``max_chars`` characters and suffixed with ``"..."``.
    """
    raw = get_value(chunk, "content") or get_value(chunk, "text") or ""
    # Fold escaped backslash-n sequences as before, and also real newlines,
    # tabs and repeated spaces, which an escape-only replace never touched.
    text = " ".join(str(raw).replace("\\n", " ").split())
    if len(text) > max_chars:
        return text[:max_chars] + "..."
    return text
# Query words ignored when matching query terms against chunk text.
_QUERY_STOPWORDS = frozenset(
    {"what", "is", "are", "the", "a", "an", "of", "to", "and", "why", "how"}
)

# Phrases suggesting a chunk defines RAG; each one present adds a bonus.
_RAG_DEFINITION_MARKERS = (
    "rag is",
    "rag stands for",
    "retrieval-augmented generation",
    "retrieval augmented generation",
    "adds a retrieval step",
    "before generation",
    "document corpus",
)


def tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and split it into alphanumeric/underscore tokens.

    ``None`` (or any falsy value) yields an empty list.
    """
    return re.findall(r"[a-zA-Z0-9_]+", str(text or "").lower())


def query_text_relevance(query: str, text: str) -> float:
    """Score how well ``text`` matches ``query`` at the word level.

    Adds text-level relevance so graph retrieval does not rank a chunk only
    because it has connected entities:

    * +4.0 for each non-stopword query term that is a token of ``text``;
    * +1.5 instead when a term of 4+ characters only occurs as a substring
      of ``text`` (e.g. inside a longer word);
    * +5.0 per definition marker found in ``text`` when the query is a
      "what ... RAG" question.

    Repeated query terms are counted each time they occur.
    """
    query_tokens = tokenize(query)
    text_lower = str(text or "").lower()
    text_tokens = set(tokenize(text))

    score = 0.0
    for term in query_tokens:
        if term in _QUERY_STOPWORDS:
            continue
        if term in text_tokens:
            score += 4.0
        elif len(term) >= 4 and term in text_lower:
            score += 1.5

    # Definition questions should prefer chunks with definition-like language.
    # Match whole query tokens (singular or plural "rag"): a substring test
    # would also fire on words like "storage", "average" or "whatever".
    query_token_set = set(query_tokens)
    if "what" in query_token_set and query_token_set & {"rag", "rags"}:
        for marker in _RAG_DEFINITION_MARKERS:
            if marker in text_lower:
                score += 5.0
    return score
def score_graph_chunks(
    graph_context: Dict[str, Any]
) -> Dict[str, Dict[str, Any]]:
    """Accumulate graph-evidence scores for every chunk the graph points at.

    Args:
        graph_context: Mapping with optional ``matched_entities`` and
            ``matched_relations`` lists. Each entity/relation may carry
            ``chunk_ids``; entities may carry ``mention_count`` and ``name``,
            relations ``weight``, ``source``, ``relation_type`` and ``target``.

    Returns:
        ``{chunk_id: {"score", "matched_entities", "matched_relations"}}``
        in first-seen order. Each entity adds
        ``3.0 + 0.2 * min(mention_count, 10)`` to every chunk it lists; each
        relation adds ``2.0 + 0.3 * min(weight, 10)``. Missing or ``None``
        lists are treated as empty, and empty/``None`` chunk ids are skipped.
    """
    chunk_scores: Dict[str, Dict[str, Any]] = {}

    def _entry_for(cid: str) -> Dict[str, Any]:
        # Lazily create the accumulator the first time a chunk is referenced.
        entry = chunk_scores.get(cid)
        if entry is None:
            entry = {"score": 0.0, "matched_entities": [], "matched_relations": []}
            chunk_scores[cid] = entry
        return entry

    for entity in graph_context.get("matched_entities") or []:
        # A missing or zero mention_count counts as a single mention.
        mention_count = entity.get("mention_count") or 1
        entity_score = 3.0 + min(mention_count, 10) * 0.2
        for chunk_id in entity.get("chunk_ids") or []:
            cid = normalize_chunk_id(chunk_id)
            if not cid:
                continue
            entry = _entry_for(cid)
            entry["score"] += entity_score
            entry["matched_entities"].append(entity.get("name"))

    for relation in graph_context.get("matched_relations") or []:
        weight = relation.get("weight") or 1
        relation_score = 2.0 + min(weight, 10) * 0.3
        relation_label = (
            f'{relation.get("source")} '
            f'--{relation.get("relation_type")}--> '
            f'{relation.get("target")}'
        )
        for chunk_id in relation.get("chunk_ids") or []:
            cid = normalize_chunk_id(chunk_id)
            if not cid:
                continue
            entry = _entry_for(cid)
            entry["score"] += relation_score
            entry["matched_relations"].append(relation_label)

    return chunk_scores
def graph_guided_retrieve(
    document_id: Optional[str],
    query: str,
    graph_entity_limit: int = 8,
    top_k: int = 5
) -> Dict[str, Any]:
    """Rank a document's chunks using knowledge-graph matches plus text relevance.

    Steps:
      1. Load the document's processed chunks and its graph context for
         ``query`` (``build_graph_context_for_query`` with
         ``limit=graph_entity_limit``).
      2. Score chunks referenced by matched entities/relations
         (``score_graph_chunks``) and add ``query_text_relevance`` computed on
         each chunk's text preview.
      3. Drop chunks rejected by ``is_low_quality_chunk_text`` or
         ``is_meta_showcase_chunk_text``, sort by score (descending), keep the
         first ``top_k`` and number them from rank 1.

    Returns:
        On success, a dict with ``status="success"``, match counts, the matched
        entities/relations and the ranked ``results``. Otherwise a dict with
        ``status="failed"``, a ``message`` and empty ``results`` -- when
        ``document_id`` is empty, ``read_processed_chunks`` returns ``None``,
        or the graph context's ``graph_available`` is falsy.
    """
    if not document_id:
        return {
            "status": "failed",
            "message": "document_id is required.",
            "results": []
        }

    chunks = read_processed_chunks(document_id)
    if chunks is None:
        return {
            "status": "failed",
            "message": "No processed chunks found. Upload/process the document first.",
            "document_id": document_id,
            "results": []
        }

    graph_context = build_graph_context_for_query(
        document_id=document_id,
        query=query,
        limit=graph_entity_limit
    )
    if not graph_context.get("graph_available"):
        return {
            "status": "failed",
            "message": graph_context.get("reason", "Graph context not available."),
            "document_id": document_id,
            "graph_context": graph_context,
            "results": []
        }

    # Treat a missing or None match list as empty so len() below cannot fail.
    matched_entities = graph_context.get("matched_entities") or []
    matched_relations = graph_context.get("matched_relations") or []

    chunk_lookup = build_chunk_lookup(chunks)
    candidate_results = []
    for chunk_id, info in score_graph_chunks(graph_context).items():
        chunk = chunk_lookup.get(chunk_id)
        if chunk is None:
            # The graph references a chunk id absent from the processed chunks.
            continue

        text_preview = extract_text_preview(chunk)
        # Drop chunks rejected by the graph_quality filters before scoring.
        if is_low_quality_chunk_text(text_preview) or is_meta_showcase_chunk_text(text_preview):
            continue

        final_score = info["score"] + query_text_relevance(query, text_preview)
        candidate_results.append(
            {
                "chunk_id": chunk_id,
                "graph_score": round(final_score, 4),
                "page_number": get_value(chunk, "page_number"),
                "source_file_name": (
                    get_value(chunk, "source_file_name")
                    or get_value(chunk, "file_name")
                    or get_value(chunk, "filename")
                ),
                # Entity names come from entity.get("name") and may be None;
                # drop them so sorting a mix of None and str cannot raise.
                "matched_entities": sorted(
                    {name for name in info["matched_entities"] if name is not None}
                ),
                "matched_relations": sorted(set(info["matched_relations"])),
                "text_preview": text_preview
            }
        )

    candidate_results.sort(key=lambda item: item["graph_score"], reverse=True)
    results = candidate_results[:top_k]
    for rank, item in enumerate(results, start=1):
        item["rank"] = rank

    return {
        "status": "success",
        "document_id": document_id,
        "query": query,
        "graph_available": True,
        "graph_entity_limit": graph_entity_limit,
        "top_k": top_k,
        "matched_entity_count": len(matched_entities),
        "matched_relation_count": len(matched_relations),
        "returned_chunks": len(results),
        "matched_entities": matched_entities,
        "matched_relations": matched_relations,
        "results": results
    }
''', encoding="utf-8")
# Report completion on stdout.
# NOTE(review): the LinkedIn/resume/showcase filtering named in this message is
# performed by is_meta_showcase_chunk_text / is_low_quality_chunk_text from
# app.graph.graph_quality, which are not visible in this script -- confirm.
print("Graph-guided retriever now filters LinkedIn/resume/showcase chunks and boosts definition evidence.")