dd-poc / app /core /knowledge_graph.py
Juan Salas
Refactored code
12f0afd
#!/usr/bin/env python3
"""
Knowledge Graph Module
This module provides efficient loading and querying of pre-computed knowledge graphs
in Streamlit applications. It's designed to work with graphs generated by the
build_knowledge_graphs.py script.
Key features:
- Fast graph loading with caching
- Rich query interface for graph exploration
- Integration with existing document processor workflow
- Memory-efficient graph operations
"""
import json
import pickle
from collections import Counter
from datetime import datetime
from itertools import islice
from pathlib import Path
from typing import Dict, List, Any, Optional, Set, Tuple

import networkx as nx
import numpy as np
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity

from app.core.config import get_config
from app.core.logging import logger
class KnowledgeGraphManager:
    """
    Manages loading and querying of knowledge graphs for due diligence analysis.

    This class provides a clean interface for working with pre-computed knowledge
    graphs in Streamlit applications, with cached loading and query helpers.

    Graph files are read from ``<faiss_dir>/knowledge_graphs`` and named after
    the store: ``<store>_knowledge_graph.pkl`` (required),
    ``<store>_graph_metadata.json`` and ``<store>_entities.json`` (optional).
    """

    # Number of characters of edge context kept in relationship results.
    _CONTEXT_PREVIEW_CHARS = 200
    # Path-search caps used by find_paths.
    _MAX_PATHS_PER_PAIR = 5
    _MAX_PATHS_TOTAL = 10

    def __init__(self, store_name: str):
        """
        Initialize the knowledge graph manager for a specific company.

        Nothing is read from disk here; call ``load_graph()`` afterwards.

        Args:
            store_name: The company store name (matches FAISS index name)
        """
        self.store_name = store_name
        self.graph: Optional[nx.MultiDiGraph] = None
        self.metadata: Optional[Dict[str, Any]] = None
        self.entities: Optional[Dict[str, List[Dict]]] = None
        self.document_processor = None  # Will be loaded on-demand for semantic search
        self._config = get_config()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @staticmethod
    @st.cache_resource(ttl=3600)  # Cache parsed files for 1 hour, keyed on both arguments
    def _read_graph_files(store_name: str, graphs_dir: str) -> Tuple[Any, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Read the graph pickle and its optional JSON companion files from disk.

        Cached with ``st.cache_resource`` keyed on ``(store_name, graphs_dir)``,
        so every manager for the same store reuses one parsed graph. The
        returned objects are shared between callers and must be treated as
        read-only.

        Args:
            store_name: Store whose files should be read.
            graphs_dir: Directory containing the knowledge graph files.

        Returns:
            Tuple of (graph, metadata or None, entities or None); a JSON
            companion is None when its file does not exist.

        Raises:
            Any I/O, unpickling or JSON decoding error propagates to the caller.
        """
        base = Path(graphs_dir)

        def read_optional_json(suffix: str) -> Optional[Dict[str, Any]]:
            path = base / f"{store_name}_{suffix}.json"
            if not path.exists():
                return None
            with open(path, 'r', encoding='utf-8') as fh:
                return json.load(fh)

        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file. This is acceptable only because the pickle is produced locally
        # by build_knowledge_graphs.py; never point this at untrusted data.
        with open(base / f"{store_name}_knowledge_graph.pkl", 'rb') as fh:
            graph = pickle.load(fh)
        metadata = read_optional_json('graph_metadata')
        entities = read_optional_json('entities')
        return graph, metadata, entities

    def load_graph(self) -> bool:
        """
        Load the knowledge graph (plus optional metadata/entities) for this store.

        File parsing is cached by ``_read_graph_files``; the instance attributes
        are populated on every call. Caching is deliberately not applied to this
        method itself: a cached method skips its body on a cache hit, so
        ``self.graph`` would never be set on a fresh manager, and since ``self``
        cannot be part of the cache key every store would share one result.

        Returns:
            bool: True if graph was loaded successfully, False otherwise
        """
        try:
            graphs_dir = self._config.paths['faiss_dir'] / 'knowledge_graphs'
            graph_file = graphs_dir / f"{self.store_name}_knowledge_graph.pkl"
            # Checked outside the cached reader so a missing graph is never
            # cached and a graph built later is picked up on the next call.
            if not graph_file.exists():
                logger.warning(f"Knowledge graph not found: {graph_file}")
                return False
            graph, metadata, entities = self._read_graph_files(self.store_name, str(graphs_dir))
        except Exception as e:
            logger.error(f"Failed to load knowledge graph for {self.store_name}: {e}")
            return False

        # Assign only after every file was read, so a partial failure never
        # leaves the manager half-populated.
        self.graph, self.metadata, self.entities = graph, metadata, entities
        logger.info(f"Loaded knowledge graph for {self.store_name}: "
                    f"{self.graph.number_of_nodes()} nodes, {self.graph.number_of_edges()} edges")
        return True

    def is_available(self) -> bool:
        """Return True when a graph is loaded and has at least one node."""
        return self.graph is not None and self.graph.number_of_nodes() > 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _find_nodes_by_name(self, name: str) -> List[Any]:
        """Return ids of nodes whose ``name`` attribute contains *name*, case-insensitively."""
        needle = name.lower()
        return [node for node, data in self.graph.nodes(data=True)
                if needle in (data.get('name') or '').lower()]

    def _edge_summary(self, other_node: Any, edge_data: Dict[str, Any], role: str) -> Dict[str, Any]:
        """Describe one edge from the perspective of *role* ('source' or 'target')."""
        other = self.graph.nodes[other_node]
        return {
            role: other.get('name', other_node),
            f'{role}_type': other.get('type', 'unknown'),
            'relationship': edge_data.get('relationship', 'unknown'),
            'source_document': edge_data.get('source_document', ''),
            'context': (edge_data.get('context') or '')[:self._CONTEXT_PREVIEW_CHARS],  # Truncate context
            'confidence': edge_data.get('confidence', 0.0),
        }

    def _text_search_fallback(self, query: str, limit: int) -> List[Dict[str, Any]]:
        """Run ``search_entities`` and add ``similarity_score`` so results match the semantic-hit shape."""
        results = self.search_entities(query, limit=limit)
        for result in results:
            # Callers such as semantic_path_search read 'similarity_score'
            # regardless of which search produced the entity.
            result['similarity_score'] = result['score']
        return results

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_summary_stats(self) -> Dict[str, Any]:
        """
        Get summary statistics about the knowledge graph.

        Returns:
            Dict with node/edge counts, per-type counts for entities
            (``type`` node attribute) and relationships (``relationship`` edge
            attribute), and ``created_at`` from the metadata file if present.
            Empty dict when no graph is loaded.
        """
        if not self.is_available():
            return {}
        entity_types = Counter(data.get('type', 'unknown')
                               for _, data in self.graph.nodes(data=True))
        relationship_types = Counter(data.get('relationship', 'unknown')
                                     for _, _, data in self.graph.edges(data=True))
        return {
            'num_entities': self.graph.number_of_nodes(),
            'num_relationships': self.graph.number_of_edges(),
            'entity_types': dict(entity_types),
            'relationship_types': dict(relationship_types),
            'created_at': self.metadata.get('created_at') if self.metadata else None,
        }

    def search_entities(self, query: str, entity_type: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Search for entities whose name contains the query (case-insensitive).

        Exact name matches score 1.0, substring matches 0.8.

        Args:
            query: Search query string
            entity_type: Filter by entity type (companies, people, etc.)
            limit: Maximum number of results

        Returns:
            List of matching entities with metadata, best matches first
        """
        if not self.is_available():
            return []
        query_lower = query.lower()
        results = []
        for node, node_data in self.graph.nodes(data=True):
            node_type = node_data.get('type', '')
            # Filter by type if specified
            if entity_type and node_type != entity_type:
                continue
            display_name = node_data.get('name') or ''
            name_lower = display_name.lower()
            if query_lower not in name_lower:
                continue
            results.append({
                'node_id': node,
                'name': display_name,
                'type': node_type,
                'score': 1.0 if query_lower == name_lower else 0.8,
                'sources': node_data.get('sources', ''),
                'document_type': node_data.get('document_type', 'unknown'),
                'context_samples': node_data.get('context_samples', [])[:2],  # Limit context
            })
        results.sort(key=lambda r: r['score'], reverse=True)
        return results[:limit]

    def get_entity_relationships(self, entity_name: str) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get all relationships for every entity whose name contains *entity_name*.

        Args:
            entity_name: Name (or name fragment) of the entity to find relationships for

        Returns:
            Dictionary with 'incoming' and 'outgoing' relationship lists
        """
        relationships: Dict[str, List[Dict[str, Any]]] = {'incoming': [], 'outgoing': []}
        if not self.is_available():
            return relationships
        for node in self._find_nodes_by_name(entity_name):
            for _, target, edge_data in self.graph.out_edges(node, data=True):
                relationships['outgoing'].append(self._edge_summary(target, edge_data, 'target'))
            for source, _, edge_data in self.graph.in_edges(node, data=True):
                relationships['incoming'].append(self._edge_summary(source, edge_data, 'source'))
        return relationships

    def find_paths(self, source_entity: str, target_entity: str, max_length: int = 3) -> List[List[str]]:
        """
        Find directed simple paths between two entities in the knowledge graph.

        Every node whose name contains *source_entity* is paired with every node
        whose name contains *target_entity*; at most 5 paths are kept per pair
        and at most 10 in total.

        Args:
            source_entity: Starting entity name
            target_entity: Target entity name
            max_length: Maximum path length (number of edges) to search

        Returns:
            List of paths (each path is a list of entity names)
        """
        if not self.is_available():
            return []
        source_nodes = self._find_nodes_by_name(source_entity)
        target_nodes = self._find_nodes_by_name(target_entity)
        if not source_nodes or not target_nodes:
            return []

        paths: List[List[str]] = []
        for source_node in source_nodes:
            for target_node in target_nodes:
                if source_node == target_node:
                    continue
                try:
                    # all_simple_paths is lazy: islice stops enumeration after
                    # the per-pair cap instead of materializing every path.
                    pair_paths = islice(
                        nx.all_simple_paths(self.graph, source_node, target_node, cutoff=max_length),
                        self._MAX_PATHS_PER_PAIR,
                    )
                    for path in pair_paths:
                        paths.append([self.graph.nodes[node].get('name', node) for node in path])
                except nx.NetworkXNoPath:
                    continue
                if len(paths) >= self._MAX_PATHS_TOTAL:
                    return paths[:self._MAX_PATHS_TOTAL]
        return paths[:self._MAX_PATHS_TOTAL]

    def get_central_entities(self, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get the entities with the highest degree centrality.

        Args:
            limit: Maximum number of entities to return

        Returns:
            List of entities with centrality scores; empty on error or when the
            graph has fewer than two nodes
        """
        if not self.is_available() or self.graph.number_of_nodes() < 2:
            return []
        try:
            centrality = nx.degree_centrality(self.graph)
            ranked = sorted(centrality.items(), key=lambda item: item[1], reverse=True)[:limit]
            results = []
            for node, score in ranked:
                node_data = self.graph.nodes[node]
                results.append({
                    'name': node_data.get('name', ''),
                    'type': node_data.get('type', 'unknown'),
                    'centrality_score': round(score, 3),
                    # nx.all_neighbors yields predecessors and successors on a
                    # directed graph; Graph.neighbors() yields successors only,
                    # which disagreed with the in+out degree centrality above.
                    'num_connections': len(set(nx.all_neighbors(self.graph, node))),
                    'sources': node_data.get('sources', ''),
                })
            return results
        except Exception as e:
            logger.error(f"Error calculating centrality: {e}")
            return []

    def get_entity_clusters(self) -> List[List[str]]:
        """
        Group related entities using connected components of the undirected graph.

        Singleton components are dropped and only the 5 largest clusters are
        returned.

        Returns:
            List of clusters (each cluster is a list of entity names), largest first
        """
        if not self.is_available() or self.graph.number_of_nodes() < 3:
            return []
        try:
            # A view avoids copying the whole graph just to ignore edge direction.
            undirected = self.graph.to_undirected(as_view=True)
            clusters = [
                [self.graph.nodes[node].get('name', node) for node in component]
                for component in nx.connected_components(undirected)
                if len(component) > 1  # Only include clusters with multiple entities
            ]
            clusters.sort(key=len, reverse=True)
            return clusters[:5]
        except Exception as e:
            logger.error(f"Error finding clusters: {e}")
            return []

    def export_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data for visualization or further analysis.

        Returns:
            Dictionary with 'nodes', 'edges' and 'metadata' (the latter only when
            a graph is loaded)
        """
        if not self.is_available():
            return {'nodes': [], 'edges': []}
        nodes = [
            {
                'id': node,
                'name': node_data.get('name', ''),
                'type': node_data.get('type', 'unknown'),
                'sources': node_data.get('sources', ''),
                'document_type': node_data.get('document_type', 'unknown'),
            }
            for node, node_data in self.graph.nodes(data=True)
        ]
        edges = [
            {
                'source': source,
                'target': target,
                'relationship': edge_data.get('relationship', 'unknown'),
                'source_document': edge_data.get('source_document', ''),
                'confidence': edge_data.get('confidence', 0.0),
            }
            for source, target, edge_data in self.graph.edges(data=True)
        ]
        return {
            'nodes': nodes,
            'edges': edges,
            'metadata': self.metadata or {},
        }

    # ------------------------------------------------------------------
    # Semantic (FAISS-backed) queries
    # ------------------------------------------------------------------

    def _load_document_processor(self):
        """Lazily create the document processor; leave it None if no vector store is available."""
        if self.document_processor is not None:
            return
        try:
            from app.core.utils import create_document_processor
            self.document_processor = create_document_processor(store_name=self.store_name)
            if not self.document_processor.vector_store:
                logger.warning(f"No FAISS vector store available for {self.store_name}")
                self.document_processor = None
        except Exception as e:
            logger.error(f"Failed to load document processor for {self.store_name}: {e}")
            self.document_processor = None

    def semantic_search_entities(self, query: str, limit: int = 10, similarity_threshold: float = 0.3) -> List[Dict[str, Any]]:
        """
        Perform semantic search on entities using FAISS embeddings.

        Retrieves document chunks similar to the query and returns the entities
        whose ``sources`` mention the chunk's source document, or whose context
        samples mention the chunk id. Falls back to ``search_entities`` (with a
        ``similarity_score`` key added) when semantic search is unavailable or fails.

        Args:
            query: Natural language query
            limit: Maximum number of results
            similarity_threshold: Minimum similarity score (``1 - distance``) to include

        Returns:
            List of entities with similarity scores and context
        """
        if not self.is_available():
            return []
        self._load_document_processor()
        if not self.document_processor or not self.document_processor.vector_store:
            logger.warning("Semantic search not available - falling back to text search")
            return self._text_search_fallback(query, limit)

        try:
            candidates = self.document_processor.vector_store.similarity_search_with_score(
                query, k=min(50, limit * 5)  # Get more candidates for filtering
            )
            # Precompute per-node lookup strings once instead of per candidate chunk.
            node_index = [
                (node, data, data.get('sources', ''), str(data.get('context_samples', [])))
                for node, data in self.graph.nodes(data=True)
            ]
            matches: List[Dict[str, Any]] = []
            seen_entities = set()
            for doc, distance in candidates:
                # The store returns a distance (lower = closer). Convert before
                # applying the *similarity* threshold; comparing the raw distance
                # against it discarded the closest chunks.
                similarity = 1.0 - distance
                if similarity < similarity_threshold:
                    continue
                chunk_id = doc.metadata.get('chunk_id', '')
                doc_source = doc.metadata.get('source', '')
                for node, node_data, entity_sources, context_text in node_index:
                    from_document = bool(doc_source) and doc_source in entity_sources
                    from_chunk = bool(chunk_id) and chunk_id in context_text
                    if not (from_document or from_chunk):
                        continue
                    entity_key = f"{node_data.get('name', '')}_{node_data.get('type', '')}"
                    if entity_key in seen_entities:
                        continue
                    seen_entities.add(entity_key)
                    matches.append({
                        'node_id': node,
                        'name': node_data.get('name', ''),
                        'type': node_data.get('type', 'unknown'),
                        'similarity_score': similarity,
                        'sources': entity_sources,
                        'document_type': node_data.get('document_type', 'unknown'),
                        'context_samples': node_data.get('context_samples', [])[:2],
                        'matching_context': doc.page_content[:300],  # Show relevant context
                    })
                    if len(matches) >= limit:
                        break
                if len(matches) >= limit:
                    break
            matches.sort(key=lambda m: m['similarity_score'], reverse=True)
            return matches[:limit]
        except Exception as e:
            logger.error(f"Semantic search failed: {e}")
            return self._text_search_fallback(query, limit)

    def find_related_entities_by_context(self, entity_name: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Find entities related to the given entity via semantically similar contexts.

        The first context sample of the first node whose name contains
        *entity_name* is used as the vector-store query; entities whose
        ``sources`` mention a returned chunk's source document are reported.

        Args:
            entity_name: Name of the reference entity
            limit: Maximum number of related entities to return

        Returns:
            List of related entities with similarity scores (``1 - distance``);
            empty when semantic search is unavailable or fails
        """
        if not self.is_available():
            return []
        reference_nodes = self._find_nodes_by_name(entity_name)
        if not reference_nodes:
            return []
        self._load_document_processor()
        if not self.document_processor or not self.document_processor.vector_store:
            return []

        try:
            reference_node = reference_nodes[0]
            reference_data = self.graph.nodes[reference_node]
            context_samples = reference_data.get('context_samples', [])
            if not context_samples:
                return []
            query_context = context_samples[0][:500]  # Limit context length
            similar_docs = self.document_processor.vector_store.similarity_search_with_score(
                query_context, k=20
            )
            other_nodes = [(node, data) for node, data in self.graph.nodes(data=True)
                           if node != reference_node]
            related: List[Dict[str, Any]] = []
            seen_names = {reference_data.get('name', '')}
            for doc, distance in similar_docs:
                doc_source = doc.metadata.get('source', '')
                if not doc_source:
                    continue
                for node, node_data in other_nodes:
                    candidate_name = node_data.get('name', '')
                    entity_sources = node_data.get('sources', '')
                    if candidate_name in seen_names or doc_source not in entity_sources:
                        continue
                    seen_names.add(candidate_name)
                    related.append({
                        'name': candidate_name,
                        'type': node_data.get('type', 'unknown'),
                        'similarity_score': 1.0 - distance,
                        'sources': entity_sources,
                        'context_samples': node_data.get('context_samples', [])[:1],
                        'relationship_reason': 'Semantic context similarity',
                    })
                    if len(related) >= limit:
                        break
                if len(related) >= limit:
                    break
            related.sort(key=lambda r: r['similarity_score'], reverse=True)
            return related[:limit]
        except Exception as e:
            logger.error(f"Context-based entity search failed: {e}")
            return []

    def semantic_path_search(self, query: str, max_paths: int = 5) -> List[Dict[str, Any]]:
        """
        Find paths in the graph that connect entities semantically relevant to a query.

        Up to 10 entities come from ``semantic_search_entities``; each of the top
        5 is paired with every later entity, and at most 2 paths per pair are
        kept. A path's relevance is the mean similarity of its two endpoints.

        Args:
            query: Natural language description of what to find
            max_paths: Maximum number of paths to return

        Returns:
            List of paths with relevance scores, most relevant first
        """
        if not self.is_available():
            return []
        relevant_entities = self.semantic_search_entities(query, limit=10)
        if len(relevant_entities) < 2:
            return []

        def candidate_pairs():
            # Only the top 5 entities start a pair, for performance.
            for i, first in enumerate(relevant_entities[:5]):
                for second in relevant_entities[i + 1:]:
                    yield first, second

        paths_found: List[Dict[str, Any]] = []
        for first, second in candidate_pairs():
            try:
                pair_paths = self.find_paths(first['name'], second['name'], max_length=3)
            except Exception as e:
                logger.debug(f"Path finding failed between {first['name']} and {second['name']}: {e}")
                continue
            path_score = (first.get('similarity_score', 0.0) + second.get('similarity_score', 0.0)) / 2
            for path in pair_paths[:2]:  # Limit paths per pair
                paths_found.append({
                    'path': path,
                    'relevance_score': path_score,
                    'start_entity': first['name'],
                    'end_entity': second['name'],
                    'query_relevance': f"Related to: {query}",
                    'path_length': len(path) - 1,
                })
                if len(paths_found) >= max_paths:
                    break
            if len(paths_found) >= max_paths:
                break
        paths_found.sort(key=lambda p: p['relevance_score'], reverse=True)
        return paths_found[:max_paths]
@st.cache_data(ttl=3600)
def get_available_knowledge_graphs() -> List[str]:
    """
    Get list of available knowledge graphs.

    Scans ``<faiss_dir>/knowledge_graphs`` for ``*_knowledge_graph.pkl`` files.
    The result is cached by Streamlit for one hour, so graphs built after the
    first call may not be listed until the cache expires.

    Returns:
        Sorted list of store names that have knowledge graphs available; an
        empty list if the directory is missing or cannot be scanned.
    """
    suffix = '_knowledge_graph'
    try:
        graphs_dir = get_config().paths['faiss_dir'] / 'knowledge_graphs'
        if not graphs_dir.exists():
            return []
        # Strip only the trailing suffix; str.replace would also delete the
        # substring if it occurred inside the store name itself.
        return sorted(
            graph_file.stem[:-len(suffix)]
            for graph_file in graphs_dir.glob(f"*{suffix}.pkl")
        )
    except Exception as e:
        logger.error(f"Error getting available knowledge graphs: {e}")
        return []
def create_knowledge_graph_manager(store_name: str) -> KnowledgeGraphManager:
    """
    Build a KnowledgeGraphManager bound to *store_name*.

    The graph is not read here; call ``load_graph()`` on the returned manager.

    Args:
        store_name: Company store name (the same name as its FAISS index).

    Returns:
        A newly constructed KnowledgeGraphManager.
    """
    manager = KnowledgeGraphManager(store_name=store_name)
    return manager