| """ |
| Knowledge Graph Implementation |
| |
| Graph-based knowledge representation using networkx. |
| """ |
|
|
| from typing import List, Dict, Optional, Any, Set, Tuple |
| import networkx as nx |
| import numpy as np |
| from datetime import datetime |
| import json |
| from pathlib import Path |
|
|
|
|
| class KnowledgeGraph: |
| """Graph-based knowledge representation with entities and relationships.""" |
|
|
| def __init__( |
| self, |
| max_nodes: int = 10000, |
| max_edges: int = 50000, |
| ): |
| """ |
| Initialize the knowledge graph. |
| |
| Args: |
| max_nodes: Maximum number of nodes |
| max_edges: Maximum number of edges |
| """ |
| self.max_nodes = max_nodes |
| self.max_edges = max_edges |
| self.graph = nx.MultiDiGraph() |
| self._node_attributes: Dict[str, Dict[str, Any]] = {} |
| self._edge_attributes: Dict[Tuple[str, str], Dict[str, Any]] = {} |
|
|
| def add_entity( |
| self, |
| entity_id: str, |
| entity_type: str, |
| properties: Optional[Dict[str, Any]] = None, |
| ) -> bool: |
| """ |
| Add an entity to the knowledge graph. |
| |
| Args: |
| entity_id: Unique identifier for the entity |
| entity_type: Type of entity (e.g., 'person', 'concept', 'code') |
| properties: Additional properties |
| |
| Returns: |
| True if added, False if limit reached |
| """ |
| if self.graph.number_of_nodes() >= self.max_nodes: |
| return False |
|
|
| if entity_id not in self.graph: |
| self.graph.add_node(entity_id, type=entity_type) |
|
|
| self._node_attributes[entity_id] = { |
| "type": entity_type, |
| "created_at": datetime.now().isoformat(), |
| "properties": properties or {}, |
| } |
|
|
| return True |
|
|
| def add_relationship( |
| self, |
| source_id: str, |
| target_id: str, |
| relationship_type: str, |
| properties: Optional[Dict[str, Any]] = None, |
| ) -> bool: |
| """ |
| Add a relationship between entities. |
| |
| Args: |
| source_id: Source entity ID |
| target_id: Target entity ID |
| relationship_type: Type of relationship |
| properties: Additional properties |
| |
| Returns: |
| True if added, False if limit reached or entities don't exist |
| """ |
| if self.graph.number_of_edges() >= self.max_edges: |
| return False |
|
|
| |
| if source_id not in self.graph: |
| self.add_entity(source_id, "unknown") |
| if target_id not in self.graph: |
| self.add_entity(target_id, "unknown") |
|
|
| self.graph.add_edge(source_id, target_id, type=relationship_type) |
|
|
| edge_key = (source_id, target_id) |
| self._edge_attributes[edge_key] = { |
| "type": relationship_type, |
| "created_at": datetime.now().isoformat(), |
| "properties": properties or {}, |
| } |
|
|
| return True |
|
|
| def get_entity(self, entity_id: str) -> Optional[Dict[str, Any]]: |
| """Get entity information.""" |
| if entity_id not in self.graph: |
| return None |
|
|
| return { |
| "id": entity_id, |
| **self._node_attributes.get(entity_id, {}), |
| } |
|
|
| def get_relationships( |
| self, |
| entity_id: str, |
| relationship_type: Optional[str] = None, |
| ) -> List[Dict[str, Any]]: |
| """Get relationships for an entity.""" |
| if entity_id not in self.graph: |
| return [] |
|
|
| relationships = [] |
| for source, target, data in self.graph.edges(data=True): |
| if source == entity_id or target == entity_id: |
| rel_type = data.get("type", "unknown") |
| if relationship_type and rel_type != relationship_type: |
| continue |
|
|
| relationships.append({ |
| "source": source, |
| "target": target, |
| "type": rel_type, |
| }) |
|
|
| return relationships |
|
|
| def find_similar_entities( |
| self, |
| entity_id: str, |
| max_results: int = 5, |
| ) -> List[Tuple[str, float]]: |
| """ |
| Find similar entities using graph-based similarity. |
| |
| Args: |
| entity_id: Entity to find similar |
| max_results: Maximum number of results |
| |
| Returns: |
| List of (entity_id, similarity_score) tuples |
| """ |
| if entity_id not in self.graph: |
| return [] |
|
|
| |
| neighbors = set(self.graph.neighbors(entity_id)) |
| scores = [] |
|
|
| for node in self.graph.nodes(): |
| if node == entity_id: |
| continue |
|
|
| node_neighbors = set(self.graph.neighbors(node)) |
| common = len(neighbors & node_neighbors) |
|
|
| if common > 0: |
| |
| union = len(neighbors | node_neighbors) |
| score = common / union if union > 0 else 0 |
| scores.append((node, score)) |
|
|
| scores.sort(key=lambda x: -x[1]) |
| return scores[:max_results] |
|
|
| def search_entities( |
| self, |
| entity_type: Optional[str] = None, |
| property_filter: Optional[Dict[str, Any]] = None, |
| ) -> List[str]: |
| """ |
| Search for entities. |
| |
| Args: |
| entity_type: Filter by entity type |
| property_filter: Filter by properties |
| |
| Returns: |
| List of matching entity IDs |
| """ |
| results = [] |
|
|
| for node in self.graph.nodes(): |
| attrs = self._node_attributes.get(node, {}) |
|
|
| |
| if entity_type and attrs.get("type") != entity_type: |
| continue |
|
|
| |
| if property_filter: |
| props = attrs.get("properties", {}) |
| if not all(props.get(k) == v for k, v in property_filter.items()): |
| continue |
|
|
| results.append(node) |
|
|
| return results |
|
|
| def get_subgraph( |
| self, |
| entity_ids: List[str], |
| depth: int = 1, |
| ) -> nx.MultiDiGraph: |
| """ |
| Get a subgraph around specified entities. |
| |
| Args: |
| entity_ids: Center entities |
| depth: How many hops to include |
| |
| Returns: |
| Subgraph |
| """ |
| nodes = set(entity_ids) |
|
|
| for _ in range(depth): |
| for entity in list(nodes): |
| nodes.update(self.graph.neighbors(entity)) |
|
|
| return self.graph.subgraph(nodes).copy() |
|
|
| def export_json(self, filepath: str) -> None: |
| """Export graph to JSON.""" |
| data = { |
| "nodes": [ |
| { |
| "id": node, |
| **self._node_attributes.get(node, {}), |
| } |
| for node in self.graph.nodes() |
| ], |
| "edges": [ |
| { |
| "source": source, |
| "target": target, |
| "type": data.get("type", "unknown"), |
| } |
| for source, target, data in self.graph.edges(data=True) |
| ], |
| } |
|
|
| Path(filepath).write_text(json.dumps(data, indent=2)) |
|
|
| def import_json(self, filepath: str) -> None: |
| """Import graph from JSON.""" |
| data = json.loads(Path(filepath).read_text()) |
|
|
| for node_data in data.get("nodes", []): |
| node_id = node_data.pop("id") |
| self.add_entity(node_id, node_data.get("type", "unknown"), node_data.get("properties")) |
|
|
| for edge_data in data.get("edges", []): |
| self.add_relationship( |
| edge_data["source"], |
| edge_data["target"], |
| edge_data.get("type", "unknown"), |
| ) |
|
|
| def get_stats(self) -> Dict[str, Any]: |
| """Get graph statistics.""" |
| return { |
| "num_nodes": self.graph.number_of_nodes(), |
| "num_edges": self.graph.number_of_edges(), |
| "num_node_types": len(set( |
| attrs.get("type") for attrs in self._node_attributes.values() |
| )), |
| "num_edge_types": len(set( |
| data.get("type") for _, _, data in self.graph.edges(data=True) |
| )), |
| } |
|
|
| def __repr__(self) -> str: |
| stats = self.get_stats() |
| return f"KnowledgeGraph(nodes={stats['num_nodes']}, edges={stats['num_edges']})" |