Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
# Destination of the generated graph-context module written below.
path = Path("app") / "graph" / "graph_context_service.py"
| path.write_text(r''' | |
| import re | |
| from typing import Dict, Any, List, Optional | |
| from app.graph.graph_storage import read_document_graph | |
# Function words and question scaffolding ("what is", "explain", "define", ...)
# that tokenize_query drops before matching terms against graph entities.
STOPWORDS = set(
    "what is are the a an of to and or "
    "in on for with from by how why explain "
    "define meaning does do it this that".split()
)
def tokenize_query(query: str) -> List[str]:
    """Split a user query into lower-cased search terms.

    Terms are maximal runs of ASCII letters, digits or underscores. Tokens
    found in STOPWORDS and single-character tokens are discarded; the
    remaining terms keep their original order and duplicates.
    A falsy query (None or "") yields an empty list.
    """
    text = query.lower() if query else ""
    candidates = re.findall(r"[a-zA-Z0-9_]+", text)
    return [tok for tok in candidates if len(tok) > 1 and tok not in STOPWORDS]
def tokenize_entity_name(name: str) -> List[str]:
    """Return the lower-cased ASCII alphanumeric/underscore tokens of ``name``.

    None or an empty string yields an empty list.
    """
    if not name:
        return []
    return re.findall(r"[a-zA-Z0-9_]+", name.lower())
def entity_relevance_score(entity, query_terms: List[str]) -> float:
    """Score how well a graph entity matches the tokenized query.

    Each query term contributes at most once, using the strongest tier that
    applies:

    * 10.0 - the term equals the whole lower-cased entity name or entity_id.
    * 6.0  - the term equals one of the entity name's tokens.
    * 5.0  - the term equals one of the entity_id's tokens (underscores split).
    * 2.0  - the term has at least 4 characters and is a prefix of a name
             token (e.g. "embedding" matches "embeddings").

    If anything matched, a bonus of 0.15 per mention (capped at 10 mentions)
    is added so frequently mentioned entities rank higher among equals.

    Args:
        entity: graph entity exposing ``name``, ``entity_id`` and
            ``mention_count`` attributes.
        query_terms: lower-cased terms, as produced by ``tokenize_query``.

    Returns:
        The relevance score; 0.0 when ``query_terms`` is empty or no term matched.
    """
    if not query_terms:
        return 0.0

    name = entity.name or ""
    entity_id = entity.entity_id or ""
    name_lower = name.lower()
    entity_id_lower = entity_id.lower()
    name_tokens = tokenize_entity_name(name)
    # Token regex keeps "_" inside tokens, so split ids on underscores first.
    entity_id_tokens = tokenize_entity_name(entity_id.replace("_", " "))

    score = 0.0
    for term in query_terms:
        if term == name_lower or term == entity_id_lower:
            score += 10.0
        elif term in name_tokens:
            # Whole-token match: "rag" matches "RAG pipeline", not "paragraph".
            score += 6.0
        elif term in entity_id_tokens:
            score += 5.0
        elif len(term) >= 4 and any(token.startswith(term) for token in name_tokens):
            # Partial matches must start at a word boundary. A raw substring
            # test would let "graph" hit "paragraph" or "data" hit "metadata",
            # the same leak that short acronyms had.
            score += 2.0

    if score > 0:
        score += min(entity.mention_count or 0, 10) * 0.15
    return score
def _unavailable_graph_context(reason: str) -> Dict[str, Any]:
    """Build the empty payload returned when no graph can be consulted."""
    return {
        "graph_available": False,
        "reason": reason,
        "matched_entities": [],
        "matched_relations": [],
        "context_text": ""
    }


def _rank_entities(entities, query_terms: List[str], limit: Optional[int]) -> list:
    """Return up to ``limit`` entities with a positive score, best first (stable)."""
    scored = [
        (entity_relevance_score(entity, query_terms), entity)
        for entity in entities
    ]
    scored = [item for item in scored if item[0] > 0]
    scored.sort(key=lambda item: item[0], reverse=True)
    return [entity for _, entity in scored[:limit]]


def _select_relations(relations, entity_ids, limit: Optional[int]) -> list:
    """Return up to ``limit`` relations touching ``entity_ids``, heaviest first."""
    touching = [
        relation for relation in relations
        if relation.source_entity_id in entity_ids
        or relation.target_entity_id in entity_ids
    ]
    return sorted(touching, key=lambda relation: relation.weight, reverse=True)[:limit]


def _entity_summary(entity) -> Dict[str, Any]:
    """Serialize an entity for the response (pages/chunk_ids capped at 10)."""
    return {
        "entity_id": entity.entity_id,
        "name": entity.name,
        "entity_type": entity.entity_type,
        "mention_count": entity.mention_count,
        "pages": entity.pages[:10],
        "chunk_ids": entity.chunk_ids[:10]
    }


def _relation_summary(relation) -> Dict[str, Any]:
    """Serialize a relation for the response (pages/chunk_ids capped at 10)."""
    return {
        "relation_id": relation.relation_id,
        "source": relation.source_name,
        "relation_type": relation.relation_type,
        "target": relation.target_name,
        "weight": relation.weight,
        "pages": relation.pages[:10],
        "chunk_ids": relation.chunk_ids[:10]
    }


def build_graph_context_for_query(
    document_id: Optional[str],
    query: str,
    limit: int = 8
) -> Dict[str, Any]:
    """Collect graph entities and relations relevant to ``query`` for one document.

    Args:
        document_id: document whose stored graph is loaded via
            ``read_document_graph``.
        query: free-text user question, tokenized with ``tokenize_query``.
        limit: maximum number of entities and, separately, of relations to
            return. Negative values are treated as 0; ``None`` means no cap.

    Returns:
        When ``document_id`` is falsy or no graph is stored: a dict with
        ``graph_available`` False, a ``reason`` string, empty match lists and
        an empty ``context_text``.
        Otherwise: ``graph_available`` True, ``document_id``,
        ``source_file_name``, ``query_terms``, the serialized
        ``matched_entities`` and ``matched_relations`` (relations touching at
        least one matched entity, heaviest first), and ``context_text``
        rendered by ``build_graph_context_text``.
    """
    if not document_id:
        return _unavailable_graph_context("No document_id provided.")

    graph = read_document_graph(document_id)
    if graph is None:
        return _unavailable_graph_context("Graph not built for this document.")

    # A negative slice bound would drop items from the end instead of
    # returning nothing, so clamp it; None still means "no cap".
    if limit is not None:
        limit = max(limit, 0)

    query_terms = tokenize_query(query)
    matched_entities = _rank_entities(graph.entities, query_terms, limit)
    matched_entity_ids = {entity.entity_id for entity in matched_entities}
    matched_relations = _select_relations(graph.relations, matched_entity_ids, limit)

    context_text = build_graph_context_text(
        matched_entities=matched_entities,
        matched_relations=matched_relations
    )

    return {
        "graph_available": True,
        "document_id": document_id,
        "source_file_name": graph.source_file_name,
        "query_terms": query_terms,
        "matched_entities": [_entity_summary(entity) for entity in matched_entities],
        "matched_relations": [_relation_summary(relation) for relation in matched_relations],
        "context_text": context_text
    }
def build_graph_context_text(
    matched_entities,
    matched_relations
) -> str:
    """Render matched entities and relations as a plain-text block.

    Each non-empty group becomes a section: a header line followed by one
    "- ..." bullet per item. Sections are separated by a blank line. Entity
    bullets list at most their first five pages. Returns "" when both groups
    are empty; the result is stripped of surrounding whitespace.
    """
    sections = []

    if matched_entities:
        entity_section = ["Relevant graph entities:"]
        for ent in matched_entities:
            page_list = ", ".join(map(str, ent.pages[:5]))
            entity_section.append(
                f"- {ent.name} ({ent.entity_type}), "
                f"mentions={ent.mention_count}, pages={page_list}"
            )
        sections.append("\n".join(entity_section))

    if matched_relations:
        relation_section = ["Relevant graph relations:"]
        relation_section.extend(
            f"- {rel.source_name} --{rel.relation_type}--> {rel.target_name} "
            f"(weight={rel.weight})"
            for rel in matched_relations
        )
        sections.append("\n".join(relation_section))

    return "\n\n".join(sections).strip()
| ''', encoding="utf-8") | |
# Confirm to whoever ran the script that the module above was rewritten.
print(
    "Fixed graph query matching. "
    "Short acronyms like RAG will no longer match inside words like Paragraph."
)