"""One-off patch script that regenerates ``app/graph/graph_context_service.py``.

The generated module scores knowledge-graph entities against a user query
using token-level matching, so short acronyms such as "rag" no longer match
inside unrelated words such as "paragraph".

The target path is relative, so it is resolved against the current working
directory when the script runs.
"""

from pathlib import Path

# File that is overwritten with GRAPH_CONTEXT_SERVICE_SOURCE.
TARGET_PATH = Path("app/graph/graph_context_service.py")

# Raw string: backslash sequences such as "\n" inside the template must reach
# the generated file verbatim instead of being interpreted here.
GRAPH_CONTEXT_SERVICE_SOURCE = r'''
"""Build query-specific context from a stored document graph."""

import re
from typing import Dict, Any, List, Optional

from app.graph.graph_storage import read_document_graph


STOPWORDS = {
    "what", "is", "are", "the", "a", "an", "of", "to", "and", "or",
    "in", "on", "for", "with", "from", "by", "how", "why", "explain",
    "define", "meaning", "does", "do", "it", "this", "that"
}

# Compiled once: the tokenizers run for every entity on every query.
_TOKEN_PATTERN = re.compile(r"[a-zA-Z0-9_]+")


def tokenize_query(query: str) -> List[str]:
    """Return the lowercase search terms found in *query*.

    Stopwords and single-character tokens are dropped. A missing or empty
    query yields an empty list.
    """
    words = _TOKEN_PATTERN.findall((query or "").lower())

    return [
        word for word in words
        if word not in STOPWORDS and len(word) > 1
    ]


def tokenize_entity_name(name: str) -> List[str]:
    """Return the lowercase alphanumeric/underscore tokens of *name*."""
    return _TOKEN_PATTERN.findall((name or "").lower())


def entity_relevance_score(entity, query_terms: List[str]) -> float:
    """Score how well *entity* matches *query_terms*.

    Args:
        entity: Object exposing ``name``, ``entity_id`` and ``mention_count``.
        query_terms: Lowercase terms, e.g. the output of ``tokenize_query``.

    Returns:
        0.0 when nothing matches. Otherwise the sum of the per-term scores
        (10 exact match, 6 name token, 5 entity-id token, 2 substring of the
        name for terms of four or more characters) plus a bonus of 0.15 per
        mention, counted up to 10 mentions.
    """
    if not query_terms:
        return 0.0

    # Guard missing values the same way tokenize_entity_name does, so an
    # entity without a name scores on its id instead of raising.
    name_lower = (entity.name or "").lower()
    entity_id_lower = (entity.entity_id or "").lower()

    # Sets give constant-time membership checks inside the term loop.
    name_tokens = set(tokenize_entity_name(name_lower))
    # The token pattern keeps underscores, so split ids on them explicitly.
    entity_id_tokens = set(
        tokenize_entity_name(entity_id_lower.replace("_", " "))
    )

    score = 0.0

    for term in query_terms:
        # Exact entity match
        if term == name_lower or term == entity_id_lower:
            score += 10.0
            continue

        # Token-level match. This prevents rag matching paragraph.
        if term in name_tokens:
            score += 6.0
            continue

        if term in entity_id_tokens:
            score += 5.0
            continue

        # Only allow substring match for longer terms.
        # Example: "retrieval" can match "retrieval-augmented generation".
        # But short acronyms like rag/api/llm should not match inside random words.
        if len(term) >= 4 and term in name_lower:
            score += 2.0

    if score > 0:
        score += min(entity.mention_count, 10) * 0.15

    return score


def _graph_unavailable(reason: str) -> Dict[str, Any]:
    """Return the payload used when no graph context can be produced."""
    return {
        "graph_available": False,
        "reason": reason,
        "matched_entities": [],
        "matched_relations": [],
        "context_text": ""
    }


def build_graph_context_for_query(
    document_id: Optional[str],
    query: str,
    limit: int = 8
) -> Dict[str, Any]:
    """Collect the graph entities and relations relevant to *query*.

    Args:
        document_id: Id passed to ``read_document_graph``; falsy means the
            graph cannot be looked up.
        query: Free-text query; tokenized with ``tokenize_query``.
        limit: Maximum number of entities and, separately, of relations to
            return. Negative values are treated as 0.

    Returns:
        A dict whose ``graph_available`` flag tells whether a graph was
        loaded. When it is False the dict carries a ``reason`` and empty
        results; otherwise it carries the matched entities, the relations
        touching them, and a plain-text ``context_text`` summary.
    """
    if not document_id:
        return _graph_unavailable("No document_id provided.")

    graph = read_document_graph(document_id)

    if graph is None:
        return _graph_unavailable("Graph not built for this document.")

    # A negative limit would make the slices below drop items from the end
    # instead of capping the result, so clamp it. None still means "no cap".
    if limit is not None and limit < 0:
        limit = 0

    query_terms = tokenize_query(query)

    scored_entities = []

    for entity in graph.entities:
        score = entity_relevance_score(entity, query_terms)

        if score > 0:
            scored_entities.append((score, entity))

    # Sort on the score only; entities themselves are not compared.
    scored_entities.sort(key=lambda item: item[0], reverse=True)

    matched_entities = [
        entity for score, entity in scored_entities[:limit]
    ]

    matched_entity_ids = {
        entity.entity_id for entity in matched_entities
    }

    # Keep every relation that touches at least one matched entity.
    matched_relations = [
        relation for relation in graph.relations
        if (
            relation.source_entity_id in matched_entity_ids
            or relation.target_entity_id in matched_entity_ids
        )
    ]

    matched_relations = sorted(
        matched_relations,
        key=lambda relation: relation.weight,
        reverse=True
    )[:limit]

    context_text = build_graph_context_text(
        matched_entities=matched_entities,
        matched_relations=matched_relations
    )

    return {
        "graph_available": True,
        "document_id": document_id,
        "source_file_name": graph.source_file_name,
        "query_terms": query_terms,
        "matched_entities": [
            {
                "entity_id": entity.entity_id,
                "name": entity.name,
                "entity_type": entity.entity_type,
                "mention_count": entity.mention_count,
                "pages": entity.pages[:10],
                "chunk_ids": entity.chunk_ids[:10]
            }
            for entity in matched_entities
        ],
        "matched_relations": [
            {
                "relation_id": relation.relation_id,
                "source": relation.source_name,
                "relation_type": relation.relation_type,
                "target": relation.target_name,
                "weight": relation.weight,
                "pages": relation.pages[:10],
                "chunk_ids": relation.chunk_ids[:10]
            }
            for relation in matched_relations
        ],
        "context_text": context_text
    }


def build_graph_context_text(
    matched_entities,
    matched_relations
) -> str:
    """Render matched entities and relations as a plain-text summary.

    Each non-empty group gets a heading followed by one bullet per item; at
    most five pages are listed per entity. Returns an empty string when both
    inputs are empty.
    """
    lines = []

    if matched_entities:
        lines.append("Relevant graph entities:")

        for entity in matched_entities:
            pages = ", ".join(str(page) for page in entity.pages[:5])
            lines.append(
                f"- {entity.name} ({entity.entity_type}), mentions={entity.mention_count}, pages={pages}"
            )

    if matched_relations:
        lines.append("")
        lines.append("Relevant graph relations:")

        for relation in matched_relations:
            lines.append(
                f"- {relation.source_name} --{relation.relation_type}--> {relation.target_name} "
                f"(weight={relation.weight})"
            )

    return "\n".join(lines).strip()
'''


def main() -> None:
    """Overwrite TARGET_PATH with the module source and report success."""
    TARGET_PATH.write_text(GRAPH_CONTEXT_SERVICE_SOURCE, encoding="utf-8")

    print("Fixed graph query matching. Short acronyms like RAG will no longer match inside words like Paragraph.")


# Guarded so that importing this file does not touch the filesystem.
if __name__ == "__main__":
    main()