#!/usr/bin/env python3
"""
Knowledge Graph Module

This module provides efficient loading and querying of pre-computed knowledge
graphs in Streamlit applications. It is designed to work with graphs generated
by the build_knowledge_graphs.py script.

Key features:
- Fast graph loading with caching
- Rich query interface for graph exploration
- Integration with the existing document processor workflow
- Memory-efficient graph operations
"""

import json
import pickle
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import streamlit as st

from app.core.config import get_config
from app.core.logging import logger


@st.cache_resource(ttl=3600)  # Cache for 1 hour
def _load_graph_artifacts(store_name: str) -> Optional[Tuple[nx.MultiDiGraph, Optional[Dict], Optional[Dict]]]:
    """
    Load the pickled graph, metadata, and entity files for a store.

    Cached at module level (keyed by store_name) so the potentially large
    graph is read from disk only once per store, regardless of how many
    manager instances are created across Streamlit reruns.
    """
    config = get_config()
    graphs_dir = config.paths['faiss_dir'] / 'knowledge_graphs'

    graph_file = graphs_dir / f"{store_name}_knowledge_graph.pkl"
    if not graph_file.exists():
        logger.warning(f"Knowledge graph not found: {graph_file}")
        return None

    with open(graph_file, 'rb') as f:
        graph = pickle.load(f)

    metadata = None
    metadata_file = graphs_dir / f"{store_name}_graph_metadata.json"
    if metadata_file.exists():
        with open(metadata_file, 'r') as f:
            metadata = json.load(f)

    entities = None
    entities_file = graphs_dir / f"{store_name}_entities.json"
    if entities_file.exists():
        with open(entities_file, 'r') as f:
            entities = json.load(f)

    return graph, metadata, entities


class KnowledgeGraphManager:
    """
    Manages loading and querying of knowledge graphs for due diligence analysis.

    This class provides a clean interface for working with pre-computed
    knowledge graphs in Streamlit applications, with efficient caching and
    query capabilities.
    """

    def __init__(self, store_name: str):
        """
        Initialize the knowledge graph manager for a specific company.

        Args:
            store_name: The company store name (matches the FAISS index name)
        """
        self.store_name = store_name
        self.graph: Optional[nx.MultiDiGraph] = None
        self.metadata: Optional[Dict[str, Any]] = None
        self.entities: Optional[Dict[str, List[Dict]]] = None
        self.document_processor = None  # Loaded on demand for semantic search
        self._config = get_config()

    def load_graph(self) -> bool:
        """
        Load the knowledge graph from disk via the cached module-level loader.

        Returns:
            bool: True if the graph was loaded successfully, False otherwise
        """
        try:
            artifacts = _load_graph_artifacts(self.store_name)
            if artifacts is None:
                return False

            self.graph, self.metadata, self.entities = artifacts

            logger.info(f"Loaded knowledge graph for {self.store_name}: "
                        f"{len(self.graph.nodes())} nodes, {len(self.graph.edges())} edges")
            return True

        except Exception as e:
            logger.error(f"Failed to load knowledge graph for {self.store_name}: {e}")
            return False

    def is_available(self) -> bool:
        """Check whether the knowledge graph is available and loaded."""
        return self.graph is not None and len(self.graph.nodes()) > 0

    def get_summary_stats(self) -> Dict[str, Any]:
        """Get summary statistics about the knowledge graph."""
        if not self.is_available():
            return {}

        stats = {
            'num_entities': len(self.graph.nodes()),
            'num_relationships': len(self.graph.edges()),
            'entity_types': {},
            'relationship_types': {},
            'created_at': self.metadata.get('created_at') if self.metadata else None
        }

        # Count entity types
        for node in self.graph.nodes():
            node_type = self.graph.nodes[node].get('type', 'unknown')
            stats['entity_types'][node_type] = stats['entity_types'].get(node_type, 0) + 1

        # Count relationship types
        for _, _, edge_data in self.graph.edges(data=True):
            rel_type = edge_data.get('relationship', 'unknown')
            stats['relationship_types'][rel_type] = stats['relationship_types'].get(rel_type, 0) + 1

        return stats
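
    # Illustrative usage sketch (not executed): how a Streamlit page might
    # surface these summary statistics. The store name "example_co" is a
    # hypothetical FAISS index name, not one defined by this module.
    #
    #     kg = KnowledgeGraphManager("example_co")
    #     if kg.load_graph() and kg.is_available():
    #         stats = kg.get_summary_stats()
    #         st.metric("Entities", stats["num_entities"])
    #         st.metric("Relationships", stats["num_relationships"])
    #         st.write(stats["entity_types"])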

    def search_entities(self, query: str, entity_type: Optional[str] = None,
                        limit: int = 10) -> List[Dict[str, Any]]:
        """
        Search for entities by name (case-insensitive substring match).

        Args:
            query: Search query string
            entity_type: Filter by entity type (companies, people, etc.)
            limit: Maximum number of results

        Returns:
            List of matching entities with metadata
        """
        if not self.is_available():
            return []

        query_lower = query.lower()
        results = []

        for node in self.graph.nodes():
            node_data = self.graph.nodes[node]
            node_name = node_data.get('name', '').lower()
            node_type = node_data.get('type', '')

            # Filter by type if specified
            if entity_type and node_type != entity_type:
                continue

            # Score exact name matches above substring matches
            if query_lower in node_name:
                score = 1.0 if query_lower == node_name else 0.8
                results.append({
                    'node_id': node,
                    'name': node_data.get('name', ''),
                    'type': node_type,
                    'score': score,
                    'sources': node_data.get('sources', ''),
                    'document_type': node_data.get('document_type', 'unknown'),
                    'context_samples': node_data.get('context_samples', [])[:2]  # Limit context
                })

        # Sort by score and limit results
        results.sort(key=lambda x: x['score'], reverse=True)
        return results[:limit]

    def get_entity_relationships(self, entity_name: str) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get all relationships for a specific entity.

        Args:
            entity_name: Name of the entity to find relationships for

        Returns:
            Dictionary with 'incoming' and 'outgoing' relationship lists
        """
        if not self.is_available():
            return {'incoming': [], 'outgoing': []}

        # Find nodes whose names contain the entity name (case-insensitive)
        matching_nodes = []
        for node in self.graph.nodes():
            if entity_name.lower() in self.graph.nodes[node].get('name', '').lower():
                matching_nodes.append(node)

        if not matching_nodes:
            return {'incoming': [], 'outgoing': []}

        relationships = {'incoming': [], 'outgoing': []}

        for node in matching_nodes:
            # Outgoing relationships
            for _, target, edge_data in self.graph.out_edges(node, data=True):
                relationships['outgoing'].append({
                    'target': self.graph.nodes[target].get('name', target),
                    'target_type': self.graph.nodes[target].get('type', 'unknown'),
                    'relationship': edge_data.get('relationship', 'unknown'),
                    'source_document': edge_data.get('source_document', ''),
                    'context': edge_data.get('context', '')[:200],  # Truncate context
                    'confidence': edge_data.get('confidence', 0.0)
                })

            # Incoming relationships
            for source, _, edge_data in self.graph.in_edges(node, data=True):
                relationships['incoming'].append({
                    'source': self.graph.nodes[source].get('name', source),
                    'source_type': self.graph.nodes[source].get('type', 'unknown'),
                    'relationship': edge_data.get('relationship', 'unknown'),
                    'source_document': edge_data.get('source_document', ''),
                    'context': edge_data.get('context', '')[:200],  # Truncate context
                    'confidence': edge_data.get('confidence', 0.0)
                })

        return relationships
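
    # Illustrative usage sketch (not executed): combining the two lookups
    # above to render a small relationship table. "Acme Holdings" is a
    # hypothetical entity name, `kg` is a loaded manager as in the earlier
    # sketch, and pandas is an assumed dependency of the calling page.
    #
    #     import pandas as pd
    #     hits = kg.search_entities("Acme Holdings", entity_type="companies", limit=5)
    #     if hits:
    #         rels = kg.get_entity_relationships(hits[0]["name"])
    #         st.dataframe(pd.DataFrame(rels["outgoing"]))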

    def find_paths(self, source_entity: str, target_entity: str,
                   max_length: int = 3) -> List[List[str]]:
        """
        Find paths between two entities in the knowledge graph.

        Args:
            source_entity: Starting entity name
            target_entity: Target entity name
            max_length: Maximum path length to search

        Returns:
            List of paths (each path is a list of entity names)
        """
        if not self.is_available():
            return []

        # Find matching nodes
        source_nodes = [n for n in self.graph.nodes()
                        if source_entity.lower() in self.graph.nodes[n].get('name', '').lower()]
        target_nodes = [n for n in self.graph.nodes()
                        if target_entity.lower() in self.graph.nodes[n].get('name', '').lower()]

        if not source_nodes or not target_nodes:
            return []

        paths = []
        for source_node in source_nodes:
            for target_node in target_nodes:
                if source_node == target_node:
                    continue

                try:
                    # Find all simple paths up to max_length
                    simple_paths = list(nx.all_simple_paths(
                        self.graph, source_node, target_node, cutoff=max_length
                    ))

                    # Convert node IDs to entity names
                    for path in simple_paths[:5]:  # Limit to 5 paths per pair
                        entity_path = [self.graph.nodes[node].get('name', node) for node in path]
                        paths.append(entity_path)

                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    continue

        return paths[:10]  # Return at most 10 paths in total

    def get_central_entities(self, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get the most central/important entities in the graph.

        Args:
            limit: Maximum number of entities to return

        Returns:
            List of entities with centrality scores
        """
        if not self.is_available() or len(self.graph.nodes()) < 2:
            return []

        try:
            # Calculate degree centrality
            centrality = nx.degree_centrality(self.graph)

            # Get the top central entities
            top_entities = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:limit]

            results = []
            for node, score in top_entities:
                node_data = self.graph.nodes[node]
                results.append({
                    'name': node_data.get('name', ''),
                    'type': node_data.get('type', 'unknown'),
                    'centrality_score': round(score, 3),
                    'num_connections': len(list(self.graph.neighbors(node))),
                    'sources': node_data.get('sources', '')
                })

            return results

        except Exception as e:
            logger.error(f"Error calculating centrality: {e}")
            return []

    def get_entity_clusters(self) -> List[List[str]]:
        """
        Find clusters of related entities, using connected components as a
        lightweight stand-in for community detection.

        Returns:
            List of clusters (each cluster is a list of entity names)
        """
        if not self.is_available() or len(self.graph.nodes()) < 3:
            return []

        try:
            # Convert to an undirected graph for component analysis
            undirected = self.graph.to_undirected()

            # Use connected components as clusters
            components = list(nx.connected_components(undirected))

            clusters = []
            for component in components:
                if len(component) > 1:  # Only include clusters with multiple entities
                    cluster_names = [self.graph.nodes[node].get('name', node) for node in component]
                    clusters.append(cluster_names)

            # Sort clusters by size
            clusters.sort(key=len, reverse=True)
            return clusters[:5]  # Return the top 5 clusters

        except Exception as e:
            logger.error(f"Error finding clusters: {e}")
            return []
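
    # Illustrative usage sketch (not executed): listing the most connected
    # entities on a page. Only the method names come from this module; the
    # display format is an arbitrary choice.
    #
    #     for entity in kg.get_central_entities(limit=10):
    #         st.write(f"{entity['name']} ({entity['type']}): "
    #                  f"centrality {entity['centrality_score']}, "
    #                  f"{entity['num_connections']} connections")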

    def export_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data for visualization or further analysis.

        Returns:
            Dictionary with nodes and edges data suitable for visualization
        """
        if not self.is_available():
            return {'nodes': [], 'edges': []}

        # Export nodes
        nodes = []
        for node in self.graph.nodes():
            node_data = self.graph.nodes[node]
            nodes.append({
                'id': node,
                'name': node_data.get('name', ''),
                'type': node_data.get('type', 'unknown'),
                'sources': node_data.get('sources', ''),
                'document_type': node_data.get('document_type', 'unknown')
            })

        # Export edges
        edges = []
        for source, target, edge_data in self.graph.edges(data=True):
            edges.append({
                'source': source,
                'target': target,
                'relationship': edge_data.get('relationship', 'unknown'),
                'source_document': edge_data.get('source_document', ''),
                'confidence': edge_data.get('confidence', 0.0)
            })

        return {
            'nodes': nodes,
            'edges': edges,
            'metadata': self.metadata or {}
        }

    def _load_document_processor(self):
        """Load the document processor for semantic search capabilities."""
        if self.document_processor is None:
            try:
                from app.core.utils import create_document_processor
                self.document_processor = create_document_processor(store_name=self.store_name)
                if not self.document_processor.vector_store:
                    logger.warning(f"No FAISS vector store available for {self.store_name}")
                    self.document_processor = None
            except Exception as e:
                logger.error(f"Failed to load document processor for {self.store_name}: {e}")
                self.document_processor = None
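
    # Illustrative usage sketch (not executed): offering the exported graph as
    # a JSON download. json and st are already imported by this module; the
    # file name is an arbitrary choice.
    #
    #     data = kg.export_graph_data()
    #     st.download_button(
    #         "Download graph (JSON)",
    #         data=json.dumps(data, indent=2),
    #         file_name=f"{kg.store_name}_graph.json",
    #         mime="application/json",
    #     )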

    def semantic_search_entities(self, query: str, limit: int = 10,
                                 similarity_threshold: float = 0.3) -> List[Dict[str, Any]]:
        """
        Perform semantic search on entities using FAISS embeddings.

        This method finds entities whose source contexts are semantically
        similar to the query, providing more intelligent search than simple
        text matching.

        Args:
            query: Natural language query
            limit: Maximum number of results
            similarity_threshold: Minimum similarity score to include

        Returns:
            List of entities with similarity scores and context
        """
        if not self.is_available():
            return []

        # Load the document processor if not already loaded
        self._load_document_processor()

        if not self.document_processor or not self.document_processor.vector_store:
            logger.warning("Semantic search not available - falling back to text search")
            return self.search_entities(query, limit=limit)

        try:
            # Perform semantic search on the FAISS index
            relevant_docs = self.document_processor.vector_store.similarity_search_with_score(
                query, k=min(50, limit * 5)  # Get more candidates for filtering
            )

            # Map document chunks back to entities
            entity_matches = []
            seen_entities = set()

            for doc, score in relevant_docs:
                # FAISS returns a distance (lower is better); convert it to a
                # similarity before applying the threshold
                similarity = 1.0 - score
                if similarity < similarity_threshold:
                    continue

                # Find entities that originated from this document chunk
                chunk_id = doc.metadata.get('chunk_id', '')
                doc_source = doc.metadata.get('source', '')

                # Search for entities that came from this chunk/document
                for node in self.graph.nodes():
                    node_data = self.graph.nodes[node]
                    entity_sources = node_data.get('sources', '')

                    # Check if the entity came from this document
                    if ((doc_source and doc_source in entity_sources) or
                            (chunk_id and chunk_id in str(node_data.get('context_samples', [])))):

                        entity_key = f"{node_data.get('name', '')}_{node_data.get('type', '')}"
                        if entity_key not in seen_entities:
                            seen_entities.add(entity_key)
                            entity_matches.append({
                                'node_id': node,
                                'name': node_data.get('name', ''),
                                'type': node_data.get('type', 'unknown'),
                                'similarity_score': similarity,
                                'sources': entity_sources,
                                'document_type': node_data.get('document_type', 'unknown'),
                                'context_samples': node_data.get('context_samples', [])[:2],
                                'matching_context': doc.page_content[:300]  # Show relevant context
                            })

                            if len(entity_matches) >= limit:
                                break

                if len(entity_matches) >= limit:
                    break

            # Sort by similarity score
            entity_matches.sort(key=lambda x: x['similarity_score'], reverse=True)
            return entity_matches[:limit]

        except Exception as e:
            logger.error(f"Semantic search failed: {e}")
            # Fall back to regular text search
            return self.search_entities(query, limit=limit)

    def find_related_entities_by_context(self, entity_name: str,
                                         limit: int = 5) -> List[Dict[str, Any]]:
        """
        Find entities related to the given entity based on semantic similarity
        of their contexts.

        Args:
            entity_name: Name of the reference entity
            limit: Maximum number of related entities to return

        Returns:
            List of related entities with similarity scores
        """
        if not self.is_available():
            return []

        # Find the reference entity
        reference_entities = [n for n in self.graph.nodes()
                              if entity_name.lower() in self.graph.nodes[n].get('name', '').lower()]

        if not reference_entities:
            return []

        # Load the document processor
        self._load_document_processor()

        if not self.document_processor or not self.document_processor.vector_store:
            return []

        try:
            # Get context samples from the reference entity
            reference_node = reference_entities[0]
            reference_data = self.graph.nodes[reference_node]
            context_samples = reference_data.get('context_samples', [])

            if not context_samples:
                return []

            # Use the first context sample as a query
            query_context = context_samples[0][:500]  # Limit context length

            # Find semantically similar contexts
            similar_docs = self.document_processor.vector_store.similarity_search_with_score(
                query_context, k=20
            )

            # Map back to entities
            related_entities = []
            seen_entities = {reference_data.get('name', '')}

            for doc, score in similar_docs:
                doc_source = doc.metadata.get('source', '')

                # Find entities from this document
                for node in self.graph.nodes():
                    if node == reference_node:
                        continue

                    node_data = self.graph.nodes[node]
                    entity_name_node = node_data.get('name', '')
                    entity_sources = node_data.get('sources', '')

                    if (entity_name_node not in seen_entities and
                            doc_source and doc_source in entity_sources):
                        seen_entities.add(entity_name_node)
                        related_entities.append({
                            'name': entity_name_node,
                            'type': node_data.get('type', 'unknown'),
                            'similarity_score': 1.0 - score,  # Convert distance to similarity
                            'sources': entity_sources,
                            'context_samples': node_data.get('context_samples', [])[:1],
                            'relationship_reason': 'Semantic context similarity'
                        })

                        if len(related_entities) >= limit:
                            break

                if len(related_entities) >= limit:
                    break

            # Sort by similarity
            related_entities.sort(key=lambda x: x['similarity_score'], reverse=True)
            return related_entities[:limit]

        except Exception as e:
            logger.error(f"Context-based entity search failed: {e}")
            return []
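
    # Illustrative usage sketch (not executed): a natural-language entity
    # lookup with the built-in fallback to text search. The query string and
    # threshold are arbitrary example values.
    #
    #     matches = kg.semantic_search_entities(
    #         "subsidiaries involved in pending litigation",
    #         limit=5,
    #         similarity_threshold=0.3,
    #     )
    #     for match in matches:
    #         st.write(f"{match['name']} ({match['type']}) - "
    #                  f"score {match['similarity_score']:.2f}")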

    def semantic_path_search(self, query: str, max_paths: int = 5) -> List[Dict[str, Any]]:
        """
        Find paths in the graph that are semantically relevant to a query.

        Args:
            query: Natural language description of what to find
            max_paths: Maximum number of paths to return

        Returns:
            List of paths with relevance scores
        """
        if not self.is_available():
            return []

        # First, find entities semantically related to the query
        relevant_entities = self.semantic_search_entities(query, limit=10)

        if len(relevant_entities) < 2:
            return []

        # Find interesting paths between the most relevant entities
        paths_found = []

        for i, entity1 in enumerate(relevant_entities[:5]):  # Limit to top 5 for performance
            for entity2 in relevant_entities[i + 1:]:
                try:
                    # Find paths between these entities
                    paths = self.find_paths(entity1['name'], entity2['name'], max_length=3)

                    for path in paths[:2]:  # Limit paths per pair
                        # Score the path by the mean similarity of its endpoint entities
                        path_score = (entity1['similarity_score'] + entity2['similarity_score']) / 2

                        paths_found.append({
                            'path': path,
                            'relevance_score': path_score,
                            'start_entity': entity1['name'],
                            'end_entity': entity2['name'],
                            'query_relevance': f"Related to: {query}",
                            'path_length': len(path) - 1
                        })

                        if len(paths_found) >= max_paths:
                            break

                except Exception as e:
                    logger.debug(f"Path finding failed between {entity1['name']} "
                                 f"and {entity2['name']}: {e}")
                    continue

                if len(paths_found) >= max_paths:
                    break

            if len(paths_found) >= max_paths:
                break

        # Sort by relevance score
        paths_found.sort(key=lambda x: x['relevance_score'], reverse=True)
        return paths_found[:max_paths]


@st.cache_data(ttl=3600)
def get_available_knowledge_graphs() -> List[str]:
    """
    Get the list of available knowledge graphs.

    Returns:
        List of store names that have knowledge graphs available
    """
    try:
        config = get_config()
        graphs_dir = config.paths['faiss_dir'] / 'knowledge_graphs'

        if not graphs_dir.exists():
            return []

        # Find all knowledge graph files
        graph_files = list(graphs_dir.glob("*_knowledge_graph.pkl"))
        store_names = [f.stem.replace('_knowledge_graph', '') for f in graph_files]

        return sorted(store_names)

    except Exception as e:
        logger.error(f"Error getting available knowledge graphs: {e}")
        return []


def create_knowledge_graph_manager(store_name: str) -> KnowledgeGraphManager:
    """
    Factory function to create a knowledge graph manager.

    Args:
        store_name: Company store name

    Returns:
        Configured KnowledgeGraphManager instance
    """
    return KnowledgeGraphManager(store_name)
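
# Illustrative usage sketch (not executed): wiring the helpers above into a
# Streamlit page. The widget labels and messages are assumptions; only the
# function and method names come from this module.
#
#     available = get_available_knowledge_graphs()
#     store = st.selectbox("Company", available)
#     if store:
#         kg = create_knowledge_graph_manager(store)
#         if kg.load_graph():
#             st.json(kg.get_summary_stats())
#         else:
#             st.info("No knowledge graph found for this company.")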