"""
Knowledge Graph Module

This module provides efficient loading and querying of pre-computed knowledge graphs
in Streamlit applications. It is designed to work with graphs generated by the
build_knowledge_graphs.py script.

Key features:

- Fast graph loading with caching
- Rich query interface for graph exploration
- Integration with the existing document processor workflow
- Memory-efficient graph operations
"""
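# A minimal usage sketch (illustrative only; "acme_corp" is a hypothetical store
# name whose graph was built by build_knowledge_graphs.py):
#
#     manager = create_knowledge_graph_manager("acme_corp")
#     if manager.load_graph():
#         stats = manager.get_summary_stats()
#         matches = manager.search_entities("acme", entity_type="companies")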

import json
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import streamlit as st

from app.core.config import get_config
from app.core.logging import logger

class KnowledgeGraphManager:
    """
    Manages loading and querying of knowledge graphs for due diligence analysis.

    This class provides a clean interface for working with pre-computed knowledge
    graphs in Streamlit applications, with efficient caching and query capabilities.
    """

    def __init__(self, store_name: str):
        """
        Initialize the knowledge graph manager for a specific company.

        Args:
            store_name: The company store name (matches the FAISS index name)
        """
        self.store_name = store_name
        self.graph: Optional[nx.MultiDiGraph] = None
        self.metadata: Optional[Dict[str, Any]] = None
        self.entities: Optional[Dict[str, List[Dict]]] = None
        self.document_processor = None
        self._config = get_config()

    def load_graph(self) -> bool:
        """
        Load the knowledge graph from disk with caching.

        Returns:
            bool: True if the graph was loaded successfully, False otherwise
        """
        try:
            loaded = self._read_graph_files(self.store_name)
            if loaded is None:
                return False
            self.graph, self.metadata, self.entities = loaded

            logger.info(f"Loaded knowledge graph for {self.store_name}: "
                        f"{self.graph.number_of_nodes()} nodes, {self.graph.number_of_edges()} edges")
            return True

        except Exception as e:
            logger.error(f"Failed to load knowledge graph for {self.store_name}: {e}")
            return False

    @staticmethod
    @st.cache_resource(ttl=3600)
    def _read_graph_files(store_name: str) -> Optional[Tuple[nx.MultiDiGraph, Optional[Dict], Optional[Dict]]]:
        """
        Read the pickled graph and its JSON sidecar files, cached across reruns.

        Caching is keyed on store_name so each store gets its own cache entry
        and the loaded objects survive Streamlit reruns (caching an instance
        method with an unhashed `_self` would share one entry across stores).
        """
        graphs_dir = get_config().paths['faiss_dir'] / 'knowledge_graphs'

        # Main graph file produced by build_knowledge_graphs.py
        graph_file = graphs_dir / f"{store_name}_knowledge_graph.pkl"
        if not graph_file.exists():
            logger.warning(f"Knowledge graph not found: {graph_file}")
            return None

        with open(graph_file, 'rb') as f:
            graph = pickle.load(f)

        # Optional metadata sidecar
        metadata = None
        metadata_file = graphs_dir / f"{store_name}_graph_metadata.json"
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)

        # Optional extracted-entities sidecar
        entities = None
        entities_file = graphs_dir / f"{store_name}_entities.json"
        if entities_file.exists():
            with open(entities_file, 'r') as f:
                entities = json.load(f)

        return graph, metadata, entities

    def is_available(self) -> bool:
        """Check whether the knowledge graph is loaded and non-empty."""
        return self.graph is not None and self.graph.number_of_nodes() > 0

    def get_summary_stats(self) -> Dict[str, Any]:
        """Get summary statistics about the knowledge graph."""
        if not self.is_available():
            return {}

        stats = {
            'num_entities': self.graph.number_of_nodes(),
            'num_relationships': self.graph.number_of_edges(),
            'entity_types': {},
            'relationship_types': {},
            'created_at': self.metadata.get('created_at') if self.metadata else None
        }

        # Count entities by type
        for node in self.graph.nodes():
            node_type = self.graph.nodes[node].get('type', 'unknown')
            stats['entity_types'][node_type] = stats['entity_types'].get(node_type, 0) + 1

        # Count relationships by type
        for _, _, edge_data in self.graph.edges(data=True):
            rel_type = edge_data.get('relationship', 'unknown')
            stats['relationship_types'][rel_type] = stats['relationship_types'].get(rel_type, 0) + 1

        return stats

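    # Illustrative return shape (all values hypothetical):
    #     {'num_entities': 120, 'num_relationships': 340,
    #      'entity_types': {'companies': 40, 'people': 80},
    #      'relationship_types': {'owns': 12, 'employs': 50},
    #      'created_at': '2024-01-01T00:00:00'}
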
    def search_entities(self, query: str, entity_type: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Search for entities by name or content.

        Args:
            query: Search query string
            entity_type: Filter by entity type (companies, people, etc.)
            limit: Maximum number of results

        Returns:
            List of matching entities with metadata
        """
        if not self.is_available():
            return []

        query_lower = query.lower()
        results = []

        for node in self.graph.nodes():
            node_data = self.graph.nodes[node]
            node_name = node_data.get('name', '').lower()
            node_type = node_data.get('type', '')

            # Apply the optional entity-type filter
            if entity_type and node_type != entity_type:
                continue

            # Substring match, with exact matches ranked above partial ones
            if query_lower in node_name:
                score = 1.0 if query_lower == node_name else 0.8

                results.append({
                    'node_id': node,
                    'name': node_data.get('name', ''),
                    'type': node_type,
                    'score': score,
                    'sources': node_data.get('sources', ''),
                    'document_type': node_data.get('document_type', 'unknown'),
                    'context_samples': node_data.get('context_samples', [])[:2]
                })

        # Best matches first
        results.sort(key=lambda x: x['score'], reverse=True)
        return results[:limit]

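    # For example (hypothetical data), search_entities("smith", entity_type="people")
    # ranks an entity named exactly "smith" (score 1.0) ahead of "john smith"
    # (score 0.8).
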
    def get_entity_relationships(self, entity_name: str) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get all relationships for a specific entity.

        Args:
            entity_name: Name of the entity to find relationships for

        Returns:
            Dictionary with 'incoming' and 'outgoing' relationship lists
        """
        if not self.is_available():
            return {'incoming': [], 'outgoing': []}

        # Find all nodes whose name contains the query (case-insensitive)
        matching_nodes = []
        for node in self.graph.nodes():
            if entity_name.lower() in self.graph.nodes[node].get('name', '').lower():
                matching_nodes.append(node)

        if not matching_nodes:
            return {'incoming': [], 'outgoing': []}

        relationships = {'incoming': [], 'outgoing': []}

        for node in matching_nodes:
            # Outgoing edges: this entity is the source
            for _, target, edge_data in self.graph.out_edges(node, data=True):
                relationships['outgoing'].append({
                    'target': self.graph.nodes[target].get('name', target),
                    'target_type': self.graph.nodes[target].get('type', 'unknown'),
                    'relationship': edge_data.get('relationship', 'unknown'),
                    'source_document': edge_data.get('source_document', ''),
                    'context': edge_data.get('context', '')[:200],
                    'confidence': edge_data.get('confidence', 0.0)
                })

            # Incoming edges: this entity is the target
            for source, _, edge_data in self.graph.in_edges(node, data=True):
                relationships['incoming'].append({
                    'source': self.graph.nodes[source].get('name', source),
                    'source_type': self.graph.nodes[source].get('type', 'unknown'),
                    'relationship': edge_data.get('relationship', 'unknown'),
                    'source_document': edge_data.get('source_document', ''),
                    'context': edge_data.get('context', '')[:200],
                    'confidence': edge_data.get('confidence', 0.0)
                })

        return relationships

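    # Illustrative entry in the 'outgoing' list (hypothetical data):
    #     {'target': 'Acme Subsidiary LLC', 'target_type': 'companies',
    #      'relationship': 'owns', 'source_document': 'annual_report.pdf',
    #      'context': '...', 'confidence': 0.9}
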
    def find_paths(self, source_entity: str, target_entity: str, max_length: int = 3) -> List[List[str]]:
        """
        Find paths between two entities in the knowledge graph.

        Args:
            source_entity: Starting entity name
            target_entity: Target entity name
            max_length: Maximum path length to search

        Returns:
            List of paths (each path is a list of entity names)
        """
        if not self.is_available():
            return []

        # Resolve both entity names to matching graph nodes
        source_nodes = [n for n in self.graph.nodes()
                        if source_entity.lower() in self.graph.nodes[n].get('name', '').lower()]
        target_nodes = [n for n in self.graph.nodes()
                        if target_entity.lower() in self.graph.nodes[n].get('name', '').lower()]

        if not source_nodes or not target_nodes:
            return []

        paths = []
        for source_node in source_nodes:
            for target_node in target_nodes:
                if source_node == target_node:
                    continue

                # all_simple_paths yields nothing (rather than raising) when
                # no path exists within the cutoff
                simple_paths = list(nx.all_simple_paths(
                    self.graph, source_node, target_node, cutoff=max_length
                ))

                # Keep at most 5 paths per node pair
                for path in simple_paths[:5]:
                    entity_path = [self.graph.nodes[node].get('name', node) for node in path]
                    paths.append(entity_path)

        return paths[:10]

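    # e.g. find_paths("Acme", "John Smith") might return (hypothetical data)
    #     [['Acme Corp', 'Acme Subsidiary LLC', 'John Smith']]
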
    def get_central_entities(self, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get the most central/important entities in the graph.

        Args:
            limit: Maximum number of entities to return

        Returns:
            List of entities with centrality scores
        """
        if not self.is_available() or self.graph.number_of_nodes() < 2:
            return []

        try:
            # Degree centrality: a node's degree divided by (n - 1)
            centrality = nx.degree_centrality(self.graph)

            # Highest-scoring entities first
            top_entities = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:limit]

            results = []
            for node, score in top_entities:
                node_data = self.graph.nodes[node]
                results.append({
                    'name': node_data.get('name', ''),
                    'type': node_data.get('type', 'unknown'),
                    'centrality_score': round(score, 3),
                    # neighbors() counts successors only in a directed graph
                    'num_connections': len(list(self.graph.neighbors(node))),
                    'sources': node_data.get('sources', '')
                })

            return results

        except Exception as e:
            logger.error(f"Error calculating centrality: {e}")
            return []

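    # Worked example: in a graph with 11 nodes, an entity connected to 5 others
    # has degree centrality 5 / (11 - 1) = 0.5.
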
    def get_entity_clusters(self) -> List[List[str]]:
        """
        Find clusters of related entities using connected components
        (a lightweight stand-in for full community detection).

        Returns:
            List of clusters (each cluster is a list of entity names)
        """
        if not self.is_available() or self.graph.number_of_nodes() < 3:
            return []

        try:
            # Connected components are defined on undirected graphs
            undirected = self.graph.to_undirected()

            components = list(nx.connected_components(undirected))

            clusters = []
            for component in components:
                if len(component) > 1:
                    cluster_names = [self.graph.nodes[node].get('name', node) for node in component]
                    clusters.append(cluster_names)

            # Largest clusters first; cap at 5
            clusters.sort(key=len, reverse=True)
            return clusters[:5]

        except Exception as e:
            logger.error(f"Error finding clusters: {e}")
            return []

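    # If genuine community detection is preferred over connected components,
    # networkx ships a modularity-based alternative (a sketch; same undirected
    # conversion as above):
    #
    #     from networkx.algorithms.community import greedy_modularity_communities
    #     communities = greedy_modularity_communities(self.graph.to_undirected())
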
    def export_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data for visualization or further analysis.

        Returns:
            Dictionary with nodes and edges data suitable for visualization
        """
        if not self.is_available():
            return {'nodes': [], 'edges': []}

        # Flatten node attributes into plain dictionaries
        nodes = []
        for node in self.graph.nodes():
            node_data = self.graph.nodes[node]
            nodes.append({
                'id': node,
                'name': node_data.get('name', ''),
                'type': node_data.get('type', 'unknown'),
                'sources': node_data.get('sources', ''),
                'document_type': node_data.get('document_type', 'unknown')
            })

        # Flatten edge attributes, one record per edge
        edges = []
        for source, target, edge_data in self.graph.edges(data=True):
            edges.append({
                'source': source,
                'target': target,
                'relationship': edge_data.get('relationship', 'unknown'),
                'source_document': edge_data.get('source_document', ''),
                'confidence': edge_data.get('confidence', 0.0)
            })

        return {
            'nodes': nodes,
            'edges': edges,
            'metadata': self.metadata or {}
        }

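    # The exported node/edge dictionaries map directly onto common graph
    # visualization inputs; for example (a sketch, not a hard dependency):
    #
    #     data = manager.export_graph_data()
    #     for n in data['nodes']:
    #         net.add_node(n['id'], label=n['name'])   # e.g. a pyvis Network
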
    def _load_document_processor(self):
        """Lazily load the document processor for semantic search capabilities."""
        if self.document_processor is None:
            try:
                from app.core.utils import create_document_processor
                self.document_processor = create_document_processor(store_name=self.store_name)
                if not self.document_processor.vector_store:
                    logger.warning(f"No FAISS vector store available for {self.store_name}")
                    self.document_processor = None
            except Exception as e:
                logger.error(f"Failed to load document processor for {self.store_name}: {e}")
                self.document_processor = None

    def semantic_search_entities(self, query: str, limit: int = 10, similarity_threshold: float = 0.3) -> List[Dict[str, Any]]:
        """
        Perform semantic search on entities using FAISS embeddings.

        This method finds entities whose source contexts are semantically similar
        to the query, providing more intelligent search than simple text matching.

        Args:
            query: Natural language query
            limit: Maximum number of results
            similarity_threshold: Minimum similarity score to include

        Returns:
            List of entities with similarity scores and context
        """
        if not self.is_available():
            return []

        # Semantic search needs the FAISS-backed document processor
        self._load_document_processor()
        if not self.document_processor or not self.document_processor.vector_store:
            logger.warning("Semantic search not available - falling back to text search")
            return self.search_entities(query, limit=limit)

        try:
            # Over-fetch candidate chunks so entity de-duplication still
            # leaves enough results
            relevant_docs = self.document_processor.vector_store.similarity_search_with_score(
                query, k=min(50, limit * 5)
            )

            entity_matches = []
            seen_entities = set()

            for doc, score in relevant_docs:
                # The FAISS store returns a distance where lower is more
                # similar; convert to a similarity before thresholding
                similarity = 1.0 - score
                if similarity < similarity_threshold:
                    continue

                # Identifiers used to link a chunk back to graph entities
                chunk_id = doc.metadata.get('chunk_id', '')
                doc_source = doc.metadata.get('source', '')

                # Match the chunk against entity sources and stored contexts
                for node in self.graph.nodes():
                    node_data = self.graph.nodes[node]
                    entity_sources = node_data.get('sources', '')

                    if (doc_source and doc_source in entity_sources) or (chunk_id and chunk_id in str(node_data.get('context_samples', []))):
                        entity_key = f"{node_data.get('name', '')}_{node_data.get('type', '')}"

                        if entity_key not in seen_entities:
                            seen_entities.add(entity_key)
                            entity_matches.append({
                                'node_id': node,
                                'name': node_data.get('name', ''),
                                'type': node_data.get('type', 'unknown'),
                                'similarity_score': similarity,
                                'sources': entity_sources,
                                'document_type': node_data.get('document_type', 'unknown'),
                                'context_samples': node_data.get('context_samples', [])[:2],
                                'matching_context': doc.page_content[:300]
                            })

                        if len(entity_matches) >= limit:
                            break

                if len(entity_matches) >= limit:
                    break

            # Most similar entities first
            entity_matches.sort(key=lambda x: x['similarity_score'], reverse=True)
            return entity_matches[:limit]

        except Exception as e:
            logger.error(f"Semantic search failed: {e}")
            # Degrade gracefully to plain text matching
            return self.search_entities(query, limit=limit)

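    # e.g. semantic_search_entities("litigation risk", limit=5) can surface
    # entities whose source passages discuss lawsuits even when the entity
    # names themselves never contain the word "litigation".
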
    def find_related_entities_by_context(self, entity_name: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Find entities related to the given entity based on semantic similarity of their contexts.

        Args:
            entity_name: Name of the reference entity
            limit: Maximum number of related entities to return

        Returns:
            List of related entities with similarity scores
        """
        if not self.is_available():
            return []

        # Resolve the reference entity to graph nodes
        reference_entities = [n for n in self.graph.nodes()
                              if entity_name.lower() in self.graph.nodes[n].get('name', '').lower()]

        if not reference_entities:
            return []

        # Context similarity requires the FAISS-backed document processor
        self._load_document_processor()
        if not self.document_processor or not self.document_processor.vector_store:
            return []

        try:
            # Use the first matching node's stored context as the query
            reference_node = reference_entities[0]
            reference_data = self.graph.nodes[reference_node]
            context_samples = reference_data.get('context_samples', [])

            if not context_samples:
                return []

            # Truncate to keep the embedding query short
            query_context = context_samples[0][:500]

            # Find chunks semantically close to the reference context
            similar_docs = self.document_processor.vector_store.similarity_search_with_score(
                query_context, k=20
            )

            related_entities = []
            seen_entities = {reference_data.get('name', '')}

            for doc, score in similar_docs:
                doc_source = doc.metadata.get('source', '')

                # Map each similar chunk back to other entities from the same source
                for node in self.graph.nodes():
                    if node == reference_node:
                        continue

                    node_data = self.graph.nodes[node]
                    entity_name_node = node_data.get('name', '')
                    entity_sources = node_data.get('sources', '')

                    if (entity_name_node not in seen_entities and
                            doc_source and doc_source in entity_sources):

                        seen_entities.add(entity_name_node)
                        related_entities.append({
                            'name': entity_name_node,
                            'type': node_data.get('type', 'unknown'),
                            # Convert the FAISS distance (lower is more
                            # similar) into a similarity score
                            'similarity_score': 1.0 - score,
                            'sources': entity_sources,
                            'context_samples': node_data.get('context_samples', [])[:1],
                            'relationship_reason': 'Semantic context similarity'
                        })

                    if len(related_entities) >= limit:
                        break

                if len(related_entities) >= limit:
                    break

            # Most similar first
            related_entities.sort(key=lambda x: x['similarity_score'], reverse=True)
            return related_entities[:limit]

        except Exception as e:
            logger.error(f"Context-based entity search failed: {e}")
            return []

    def semantic_path_search(self, query: str, max_paths: int = 5) -> List[Dict[str, Any]]:
        """
        Find paths in the graph that are semantically relevant to a query.

        Args:
            query: Natural language description of what to find
            max_paths: Maximum number of paths to return

        Returns:
            List of paths with relevance scores
        """
        if not self.is_available():
            return []

        # First, find the entities most relevant to the query
        relevant_entities = self.semantic_search_entities(query, limit=10)

        if len(relevant_entities) < 2:
            return []

        # Then look for graph paths between pairs of those entities
        paths_found = []

        for i, entity1 in enumerate(relevant_entities[:5]):
            for entity2 in relevant_entities[i + 1:]:
                try:
                    paths = self.find_paths(entity1['name'], entity2['name'], max_length=3)

                    for path in paths[:2]:
                        # Score each path by the mean relevance of its two
                        # endpoints; entities may come from the text-search
                        # fallback, which scores under 'score' instead
                        path_score = (entity1.get('similarity_score', entity1.get('score', 0.0))
                                      + entity2.get('similarity_score', entity2.get('score', 0.0))) / 2

                        paths_found.append({
                            'path': path,
                            'relevance_score': path_score,
                            'start_entity': entity1['name'],
                            'end_entity': entity2['name'],
                            'query_relevance': f"Related to: {query}",
                            'path_length': len(path) - 1
                        })

                        if len(paths_found) >= max_paths:
                            break

                except Exception as e:
                    logger.debug(f"Path finding failed between {entity1['name']} and {entity2['name']}: {e}")
                    continue

                if len(paths_found) >= max_paths:
                    break

            if len(paths_found) >= max_paths:
                break

        # Most relevant paths first
        paths_found.sort(key=lambda x: x['relevance_score'], reverse=True)
        return paths_found[:max_paths]

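    # e.g. semantic_path_search("supply chain exposure") pairs the top-scoring
    # entities for the query and returns up to max_paths connecting paths,
    # ranked by the endpoints' mean similarity.
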
@st.cache_data(ttl=3600)
def get_available_knowledge_graphs() -> List[str]:
    """
    Get the list of available knowledge graphs.

    Returns:
        List of store names that have knowledge graphs available
    """
    try:
        config = get_config()
        graphs_dir = config.paths['faiss_dir'] / 'knowledge_graphs'

        if not graphs_dir.exists():
            return []

        # Store names are recovered from the graph file naming convention
        graph_files = list(graphs_dir.glob("*_knowledge_graph.pkl"))
        store_names = [f.stem.replace('_knowledge_graph', '') for f in graph_files]

        return sorted(store_names)

    except Exception as e:
        logger.error(f"Error getting available knowledge graphs: {e}")
        return []

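# For example, a file named acme_corp_knowledge_graph.pkl (hypothetical) yields
# the store name "acme_corp".
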
def create_knowledge_graph_manager(store_name: str) -> KnowledgeGraphManager:
    """
    Factory function to create a knowledge graph manager.

    Args:
        store_name: Company store name

    Returns:
        Configured KnowledgeGraphManager instance
    """
    return KnowledgeGraphManager(store_name)