|
|
import networkx as nx |
|
|
import json |
|
|
import os |
|
|
import asyncio |
|
|
import nest_asyncio |
|
|
import tqdm |
|
|
|
|
|
import os.path |
|
|
import tempfile |
|
|
import subprocess |
|
|
from typing import List, Optional, Dict |
|
|
import logging |
|
|
import urllib.parse |
|
|
|
|
|
from .ModelService import create_model_service |
|
|
from .Node import Node, DirectoryNode, FileNode, ChunkNode, EntityNode |
|
|
from .CodeParser import CodeParser |
|
|
from .EntityExtractor import HybridEntityExtractor |
|
|
from .CodeIndex import CodeIndex |
|
|
from .utils.logger_utils import setup_logger |
|
|
from .utils.parsing_utils import read_directory_files_recursively, get_language_from_filename |
|
|
from .utils.path_utils import prepare_input_path, build_entity_alias_map, resolve_entity_call |
|
|
from .EntityChunkMapper import EntityChunkMapper |
|
|
|
|
|
LOGGER_NAME = 'REPO_KNOWLEDGE_GRAPH_LOGGER' |
|
|
|
|
|
MODEL_SERVICE_TYPES = ['openai', 'sentence-transformers'] |
|
|
|
|
|
|
|
|
|
|
|
class RepoKnowledgeGraph: |
|
|
""" |
|
|
RepoKnowledgeGraph builds a knowledge graph of a code repository. |
|
|
It parses source files, extracts code entities and relationships, and organizes them |
|
|
into a directed acyclic graph (DAG) with additional semantic edges. |
|
|
|
|
|
Use `from_path()` or `load_graph_from_file()` to create instances. |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
""" |
|
|
Private constructor. Use from_path() or load_graph_from_file() instead. |
|
|
""" |
|
|
raise RuntimeError( |
|
|
"Cannot instantiate RepoKnowledgeGraph directly. " |
|
|
"Use RepoKnowledgeGraph.from_path() or RepoKnowledgeGraph.load_graph_from_file() instead." |
|
|
) |
|
|
|
|
|
def _initialize(self, model_service_kwargs: dict, code_index_kwargs: Optional[dict] = None): |
|
|
"""Internal initialization method.""" |
|
|
setup_logger(LOGGER_NAME) |
|
|
self.logger = logging.getLogger(LOGGER_NAME) |
|
|
self.logger.info('Initializing RepoKnowledgeGraph instance.') |
|
|
self.code_parser = CodeParser() |
|
|
|
|
|
|
|
|
index_type = (code_index_kwargs or {}).get('index_type', 'hybrid') |
|
|
skip_embedder = index_type == 'keyword-only' |
|
|
if skip_embedder: |
|
|
self.logger.info('Using keyword-only index, skipping embedder initialization') |
|
|
|
|
|
self.model_service = create_model_service(skip_embedder=skip_embedder, **model_service_kwargs) |
|
|
self.entities = {} |
|
|
self.graph = nx.DiGraph() |
|
|
self.knowledge_graph = nx.DiGraph() |
|
|
self.code_index = None |
|
|
self.entity_extractor = HybridEntityExtractor() |
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
return (node_data['data'] for _, node_data in self.graph.nodes(data=True)) |
|
|
|
|
|
def __getitem__(self, node_id): |
|
|
return self.graph.nodes[node_id]['data'] |
|
|
|
|
|
|
|
|
@classmethod |
|
|
def from_path(cls, path: str, skip_dirs: Optional[list] = None, index_nodes: bool = True, describe_nodes=False, |
|
|
extract_entities: bool = False, model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None): |
|
|
if skip_dirs is None: |
|
|
skip_dirs = [] |
|
|
if model_service_kwargs is None: |
|
|
model_service_kwargs = {} |
|
|
""" |
|
|
Alternative constructor to build a RepoKnowledgeGraph from a path, with options to skip directories |
|
|
and control entity extraction and node description. |
|
|
|
|
|
Args: |
|
|
path (str): Path to the root of the code repository. |
|
|
skip_dirs (list): List of directory names to skip. |
|
|
index_nodes (bool): Whether to build a code index. |
|
|
describe_nodes (bool): Whether to generate descriptions for code chunks. |
|
|
extract_entities (bool): Whether to extract entities from code. |
|
|
|
|
|
Returns: |
|
|
RepoKnowledgeGraph: The constructed knowledge graph. |
|
|
""" |
|
|
instance = cls.__new__(cls) |
|
|
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
|
|
|
instance.logger.info(f"Preparing to build knowledge graph from path: {path}") |
|
|
|
|
|
prepared_path = prepare_input_path(path) |
|
|
instance.logger.debug(f"Prepared input path: {prepared_path}") |
|
|
|
|
|
|
|
|
try: |
|
|
loop = asyncio.get_running_loop() |
|
|
except RuntimeError: |
|
|
loop = None |
|
|
|
|
|
if loop and loop.is_running(): |
|
|
instance.logger.debug("Detected running event loop, applying nest_asyncio.") |
|
|
nest_asyncio.apply() |
|
|
task = instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes, |
|
|
describe_nodes=describe_nodes, extract_entities=extract_entities) |
|
|
loop.run_until_complete(task) |
|
|
else: |
|
|
instance.logger.debug("No running event loop, using asyncio.run.") |
|
|
asyncio.run(instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes, |
|
|
describe_nodes=describe_nodes, |
|
|
extract_entities=extract_entities)) |
|
|
|
|
|
instance.logger.info("Parsing files and building initial nodes...") |
|
|
instance.logger.info("Initial parse and node creation complete. Building relationships between nodes...") |
|
|
instance._build_relationships() |
|
|
|
|
|
if index_nodes: |
|
|
instance.logger.info("Building code index for all nodes in the graph...") |
|
|
instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **(code_index_kwargs or {})) |
|
|
|
|
|
instance.logger.info("Knowledge graph construction from path completed successfully.") |
|
|
return instance |
|
|
|
|
|
@classmethod |
|
|
def from_repo( |
|
|
cls, |
|
|
repo_url: str, |
|
|
skip_dirs: Optional[list] = None, |
|
|
index_nodes: bool = True, |
|
|
describe_nodes: bool = False, |
|
|
extract_entities: bool = False, |
|
|
model_service_kwargs: Optional[dict] = None, |
|
|
code_index_kwargs: Optional[dict]=None, |
|
|
github_token: Optional[str] = None, |
|
|
allow_unauthenticated_clone: bool = True, |
|
|
): |
|
|
""" |
|
|
Alternative constructor to build a RepoKnowledgeGraph from a remote git repository URL. |
|
|
|
|
|
Args: |
|
|
repo_url (str): Git repository URL (SSH or HTTPS). |
|
|
skip_dirs (list): List of directory names to skip. |
|
|
index_nodes (bool): Whether to build a code index. |
|
|
describe_nodes (bool): Whether to generate descriptions for code chunks. |
|
|
extract_entities (bool): Whether to extract entities from code. |
|
|
github_token (str, optional): Personal access token to access private GitHub repos. |
|
|
If not provided, the method will look for the `GITHUB_OAUTH_TOKEN` environment variable. |
|
|
allow_unauthenticated_clone (bool): If True, attempt to clone without a token when none is provided. |
|
|
If False, raise an error when no token is available. |
|
|
|
|
|
Returns: |
|
|
RepoKnowledgeGraph: The constructed knowledge graph. |
|
|
""" |
|
|
if skip_dirs is None: |
|
|
skip_dirs = [] |
|
|
if model_service_kwargs is None: |
|
|
model_service_kwargs = {} |
|
|
|
|
|
instance = cls.__new__(cls) |
|
|
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
|
|
|
instance.logger.info(f"Starting knowledge graph build from remote repository: {repo_url}") |
|
|
|
|
|
|
|
|
token = github_token or os.environ.get('GITHUB_OAUTH_TOKEN') |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmpdirname: |
|
|
clone_url = repo_url |
|
|
try: |
|
|
if repo_url.startswith('git@'): |
|
|
|
|
|
clone_url = repo_url.replace(':', '/').split('git@')[-1] |
|
|
clone_url = f'https://{clone_url}' |
|
|
|
|
|
if token and clone_url.startswith('https://'): |
|
|
encoded_token = urllib.parse.quote(token, safe='') |
|
|
clone_url = clone_url.replace('https://', f'https://{encoded_token}@') |
|
|
elif not token and not allow_unauthenticated_clone: |
|
|
raise ValueError( |
|
|
"GitHub token not provided and unauthenticated clone is disabled. " |
|
|
"Set allow_unauthenticated_clone=True or provide a token." |
|
|
) |
|
|
|
|
|
instance.logger.debug(f"Running git clone: {clone_url} -> {tmpdirname}") |
|
|
subprocess.run(['git', 'clone', clone_url, tmpdirname], check=True) |
|
|
|
|
|
except Exception as e: |
|
|
instance.logger.error(f"Failed to clone repository {repo_url} using URL {clone_url}: {e}") |
|
|
raise |
|
|
|
|
|
instance.logger.info(f"Repository successfully cloned to: {tmpdirname}") |
|
|
|
|
|
return cls.from_path( |
|
|
tmpdirname, |
|
|
skip_dirs=skip_dirs, |
|
|
index_nodes=index_nodes, |
|
|
describe_nodes=describe_nodes, |
|
|
extract_entities=extract_entities, |
|
|
model_service_kwargs=model_service_kwargs, |
|
|
code_index_kwargs=code_index_kwargs |
|
|
) |
|
|
|
|
|
async def _initial_parse_path_async(self, path: str, skip_dirs: list, index_nodes=True, describe_nodes=True, |
|
|
extract_entities: bool = True): |
|
|
self.logger.info(f"Beginning async parsing of repository at path: {path}") |
|
|
""" |
|
|
Orchestrates the parsing and graph construction process: |
|
|
1. Reads files and splits into chunks. |
|
|
2. Extracts entities and relationships. |
|
|
3. Builds chunk, file, directory, and root nodes. |
|
|
4. Aggregates entity information. |
|
|
|
|
|
Args: |
|
|
path (str): Root path to parse. |
|
|
skip_dirs (list): Directories to skip. |
|
|
index_nodes (bool): Whether to build code index. |
|
|
describe_nodes (bool): Whether to generate descriptions. |
|
|
extract_entities (bool): Whether to extract entities. |
|
|
""" |
|
|
|
|
|
|
|
|
level1_node_contents = read_directory_files_recursively( |
|
|
path, skip_dirs=skip_dirs, |
|
|
skip_pattern=r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)" |
|
|
) |
|
|
self.logger.debug(f"Found {len(level1_node_contents)} files to process.") |
|
|
self.logger.info("Chunk nodes creation step started.") |
|
|
chunk_info = await self._create_chunk_nodes( |
|
|
level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=path |
|
|
) |
|
|
self.logger.info("Chunk nodes creation step finished.") |
|
|
self.logger.info("File nodes creation step started.") |
|
|
file_info = self._create_file_nodes( |
|
|
chunk_info, level1_node_contents |
|
|
) |
|
|
self.logger.info("File nodes creation step finished.") |
|
|
self.logger.info("Directory nodes creation step started.") |
|
|
dir_agg = self._create_directory_nodes( |
|
|
file_info |
|
|
) |
|
|
self.logger.info("Directory nodes creation step finished.") |
|
|
self.logger.info("Aggregating all nodes to root node.") |
|
|
self._aggregate_to_root(dir_agg) |
|
|
self.logger.info("Async parse and node aggregation fully complete.") |
|
|
|
|
|
    async def _create_chunk_nodes(self, level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=None):
        """
        Create a ChunkNode for every parsed chunk of every file and populate
        ``self.entities`` with declared/called entity bookkeeping.

        Two passes:
        1. Per file: parse into chunks, optionally extract entities, and build
           one ChunkNode per chunk (concurrently via asyncio.gather).
        2. Over all chunks: resolve called entity names against the alias map
           built from the declared entities gathered in pass 1.

        Args:
            level1_node_contents (dict): Mapping file_path -> raw file content.
            extract_entities (bool): Whether to run entity extraction on code files.
            describe_nodes (bool): Whether to generate LLM descriptions per chunk.
            index_nodes (bool): Accepted for interface symmetry; not referenced
                in this method's body.
            root_path (str, optional): Repository root, joined onto relative
                file paths for the entity extractor.

        Returns:
            dict: Mapping file_path -> {'chunk_results', 'file_declared_entities',
            'file_called_entities'}.
        """
        self.logger.info(f"Starting chunk node creation for {len(level1_node_contents)} files.")
        # Only files with these extensions get entity extraction.
        accepted_extensions = {'.py', '.c', '.cpp', '.h', '.hpp', '.java', '.js', '.ts', '.jsx', '.tsx', '.rs', '.html'}
        chunk_info = {}
        entity_mapper = EntityChunkMapper()
        total_chunks = 0

        # ---- Pass 1: parse each file and build chunk nodes ----
        for file_path in tqdm.tqdm(level1_node_contents, desc="Processing files for chunk nodes"):
            self.logger.debug(f"Processing file for chunk nodes: {file_path}")
            full_path = os.path.normpath(file_path)
            parts = full_path.split(os.sep)  # NOTE(review): appears unused in this method
            _, ext = os.path.splitext(file_path)
            is_code_file = ext.lower() in accepted_extensions

            self.logger.debug(f"Parsing file: {file_path}")

            parsed_content = self.code_parser.parse(file_name=file_path, file_content=level1_node_contents[file_path])
            self.logger.debug(f"Parsed {len(parsed_content)} chunks from file: {file_path}")
            total_chunks += len(parsed_content)

            if extract_entities and is_code_file:
                self.logger.debug(f"Extracting entities from code file: {file_path}")
                try:
                    # Extractor expects an absolute-ish path when a root is known.
                    extraction_file_path = os.path.join(root_path, file_path) if root_path else file_path

                    file_declared_entities, file_called_entities = self.entity_extractor.extract_entities(
                        code=level1_node_contents[file_path], file_name=extraction_file_path)
                    self.logger.debug(f"Extracted {len(file_declared_entities)} declared and {len(file_called_entities)} called entities from file: {file_path}")

                    # Attribute each file-level entity to the chunk(s) it lives in.
                    chunk_declared_map, chunk_called_map = entity_mapper.map_entities_to_chunks(
                        file_declared_entities, file_called_entities, parsed_content, file_name=file_path)
                    self.logger.debug(f"Mapped entities to {len(parsed_content)} chunks for file: {file_path}")
                except Exception as e:
                    # Extraction is best-effort: fall back to empty maps on failure.
                    self.logger.error(f"Error extracting entities from {file_path}: {e}")
                    file_declared_entities, file_called_entities = [], []
                    chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                    chunk_called_map = {i: [] for i in range(len(parsed_content))}
            else:
                self.logger.debug(f"Skipping entity extraction for non-code file: {file_path}")
                file_declared_entities, file_called_entities = [], []
                chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                chunk_called_map = {i: [] for i in range(len(parsed_content))}

            chunk_tasks = []
            for i, chunk in enumerate(parsed_content):
                # Chunk ids are "<file_path>_<index>".
                chunk_id = f'{file_path}_{i}'
                self.logger.debug(f"Scheduling processing for chunk {chunk_id} of file {file_path}")

                # Defaults (i=i, ...) bind the loop variables at definition time,
                # avoiding the late-binding closure pitfall.
                async def process_chunk(i=i, chunk=chunk, chunk_id=chunk_id):
                    self.logger.debug(f"Creating chunk node: {chunk_id}")
                    declared_entities = chunk_declared_map.get(i, [])
                    called_entities = chunk_called_map.get(i, [])

                    # Fresh alias map so this chunk can dedupe against entities
                    # registered by earlier chunks.
                    temp_alias_map = build_entity_alias_map(self.entities)

                    for entity in declared_entities:
                        name = entity.get("name")
                        if not name:
                            continue

                        entity_aliases = entity.get("aliases", [])
                        canonical_name = None

                        # Prefer a direct name hit, then fall back to alias hits.
                        if name in temp_alias_map:
                            canonical_name = temp_alias_map[name]
                            self.logger.debug(f"Entity '{name}' already exists as '{canonical_name}'")
                        else:
                            for alias in entity_aliases:
                                if alias in temp_alias_map:
                                    canonical_name = temp_alias_map[alias]
                                    self.logger.debug(f"Entity '{name}' matches existing entity '{canonical_name}' via alias '{alias}'")
                                    break

                        if canonical_name:
                            entity_key = canonical_name
                        else:
                            # First sighting: register a fresh entity record.
                            entity_key = name
                            self.logger.debug(f"Registering new declared entity '{name}' in chunk {chunk_id}")
                            self.entities[entity_key] = {
                                "declaring_chunk_ids": [],
                                "calling_chunk_ids": [],
                                "type": [],
                                "dtype": None,
                                "aliases": []
                            }

                            temp_alias_map[entity_key] = entity_key

                        # Merge this declaration's details into the entity record.
                        if chunk_id not in self.entities[entity_key]["declaring_chunk_ids"]:
                            self.entities[entity_key]["declaring_chunk_ids"].append(chunk_id)
                        entity_type = entity.get("type")
                        if entity_type and entity_type not in self.entities[entity_key]["type"]:
                            self.entities[entity_key]["type"].append(entity_type)
                        dtype = entity.get("dtype")
                        if dtype:
                            self.entities[entity_key]["dtype"] = dtype

                        # Record the name itself plus any extractor-provided aliases.
                        for alias in [name] + entity_aliases:
                            if alias and alias not in self.entities[entity_key]["aliases"]:
                                self.entities[entity_key]["aliases"].append(alias)
                                temp_alias_map[alias] = entity_key
                        self.logger.debug(f"Declared entity '{name}' registered as '{entity_key}' in chunk {chunk_id} with aliases: {self.entities[entity_key]['aliases']}")

                    if describe_nodes:
                        self.logger.info(f"Generating description for chunk {chunk_id}")
                        try:
                            description = await self.model_service.query_async(
                                f'Summarize this {get_language_from_filename(file_path)} code chunk in a few sentences: {chunk}')
                        except Exception as e:
                            # Description is optional; fall back to empty string.
                            self.logger.error(f"Error generating description for chunk {chunk_id}: {e}")
                            description = ''
                    else:
                        self.logger.debug(f"No description requested for chunk {chunk_id}")
                        description = ''

                    chunk_node = ChunkNode(
                        id=chunk_id,
                        name=chunk_id,
                        path=file_path,
                        content=chunk,
                        order_in_file=i,
                        called_entities=called_entities,
                        declared_entities=declared_entities,
                        language=get_language_from_filename(file_path),
                        description=description,
                    )
                    self.logger.debug(f"Chunk node created: {chunk_id}")

                    # Embeddings are filled in later (by the code index), if at all.
                    chunk_node.embedding = None
                    return (chunk_id, chunk_node, declared_entities, called_entities)

                chunk_tasks.append(process_chunk())

            # Process this file's chunks concurrently.
            chunk_results = await asyncio.gather(*chunk_tasks)
            self.logger.debug(f"Finished processing {len(chunk_results)} chunks for file {file_path}.")
            chunk_info[file_path] = {
                'chunk_results': chunk_results,
                'file_declared_entities': file_declared_entities,
                'file_called_entities': file_called_entities
            }

        self.logger.info(f"Created {total_chunks} chunk nodes from {len(level1_node_contents)} files")

        # ---- Pass 2: resolve called entities against the full alias map ----
        self.logger.info("Starting second pass: resolving called entities using alias map...")
        alias_map = build_entity_alias_map(self.entities)
        self.logger.info(f"Built alias map with {len(alias_map)} entries for resolution")

        resolved_count = 0
        for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Resolving called entities"):
            chunk_results = file_data['chunk_results']
            for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
                for called_name in called_entities:
                    # Skip empty/whitespace-only names.
                    if not called_name or not called_name.strip():
                        continue

                    resolved_name = resolve_entity_call(called_name, alias_map)

                    # Resolution priority: resolver result > alias map > raw name.
                    if resolved_name:
                        entity_key = resolved_name
                    elif called_name in alias_map:
                        entity_key = alias_map[called_name]
                    else:
                        entity_key = called_name

                    if entity_key not in self.entities:
                        # Call to an entity never declared in this repo:
                        # register a declaration-less record for it.
                        self.logger.debug(f"Registering new called entity '{entity_key}' (called as '{called_name}') in chunk {chunk_id}")
                        self.entities[entity_key] = {
                            "declaring_chunk_ids": [],
                            "calling_chunk_ids": [],
                            "type": [],
                            "dtype": None,
                            "aliases": []
                        }

                        if called_name != entity_key:
                            self.entities[entity_key]["aliases"].append(called_name)
                            alias_map[called_name] = entity_key

                    if chunk_id not in self.entities[entity_key]["calling_chunk_ids"]:
                        self.entities[entity_key]["calling_chunk_ids"].append(chunk_id)

                    if resolved_name and resolved_name != called_name:
                        resolved_count += 1
                    self.logger.debug(f"Called entity '{called_name}' resolved to '{entity_key}' in chunk {chunk_id}")

        self.logger.info(f"Resolved {resolved_count} entity calls to existing declarations via aliases")
        self.logger.info("All chunk nodes have been created for all files.")
        return chunk_info
|
|
|
|
|
def _create_file_nodes(self, chunk_info, level1_node_contents): |
|
|
self.logger.info("Starting file node creation.") |
|
|
""" |
|
|
For each file, aggregate chunk information and create FileNode objects. |
|
|
This method remains mostly the same. |
|
|
""" |
|
|
|
|
|
def merge_entities(target, source): |
|
|
|
|
|
existing = set((e.get('name'), e.get('type')) for e in target) |
|
|
for e in source: |
|
|
k = (e.get('name'), e.get('type')) |
|
|
if k not in existing: |
|
|
target.append(e) |
|
|
existing.add(k) |
|
|
|
|
|
def merge_called_entities(target, source): |
|
|
|
|
|
existing = set(target) |
|
|
for e in source: |
|
|
if e not in existing: |
|
|
target.append(e) |
|
|
existing.add(e) |
|
|
|
|
|
file_info = {} |
|
|
for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Creating file nodes"): |
|
|
self.logger.info(f"Creating file node for: {file_path}") |
|
|
parts = os.path.normpath(file_path).split(os.sep) |
|
|
|
|
|
|
|
|
chunk_results = file_data['chunk_results'] |
|
|
file_declared_entities = list(file_data['file_declared_entities']) |
|
|
file_called_entities = list(file_data['file_called_entities']) |
|
|
chunk_ids = [] |
|
|
|
|
|
for chunk_id, chunk_node, declared_entities, called_entities in chunk_results: |
|
|
self.logger.info(f"Adding chunk node {chunk_id} to graph for file {file_path}") |
|
|
self.graph.add_node(chunk_id, data=chunk_node, level=2) |
|
|
chunk_ids.append(chunk_id) |
|
|
|
|
|
|
|
|
|
|
|
file_node = FileNode( |
|
|
id=file_path, |
|
|
name=parts[-1], |
|
|
path=file_path, |
|
|
node_type='file', |
|
|
content=level1_node_contents[file_path], |
|
|
declared_entities=file_declared_entities, |
|
|
called_entities=file_called_entities, |
|
|
language=get_language_from_filename(file_path), |
|
|
) |
|
|
|
|
|
self.logger.debug(f"Adding file node {file_path} to graph.") |
|
|
self.graph.add_node(file_path, data=file_node, level=1) |
|
|
for chunk_id in chunk_ids: |
|
|
self.graph.add_edge(file_path, chunk_id, relation='contains') |
|
|
|
|
|
file_info[file_path] = { |
|
|
'declared_entities': file_declared_entities, |
|
|
'called_entities': file_called_entities, |
|
|
'chunk_ids': chunk_ids, |
|
|
'parts': parts, |
|
|
} |
|
|
self.logger.info(f"File node {file_path} added to graph with {len(chunk_ids)} chunks.") |
|
|
|
|
|
self.logger.info("All file nodes have been created.") |
|
|
return file_info |
|
|
|
|
|
def _create_directory_nodes(self, file_info): |
|
|
self.logger.info("Starting directory node creation.") |
|
|
""" |
|
|
For each directory, aggregate file information and create DirectoryNode objects. |
|
|
|
|
|
Args: |
|
|
file_info (dict): Mapping file_path -> file info dict. |
|
|
|
|
|
Returns: |
|
|
dict: Mapping dir_path -> aggregated entity info. |
|
|
""" |
|
|
|
|
|
def merge_entities(target, source): |
|
|
|
|
|
existing = set((e.get('name'), e.get('type')) for e in target) |
|
|
for e in source: |
|
|
k = (e.get('name'), e.get('type')) |
|
|
if k not in existing: |
|
|
target.append(e) |
|
|
existing.add(k) |
|
|
|
|
|
def merge_called_entities(target, source): |
|
|
|
|
|
existing = set(target) |
|
|
for e in source: |
|
|
if e not in existing: |
|
|
target.append(e) |
|
|
existing.add(e) |
|
|
|
|
|
dir_agg = {} |
|
|
for file_path, info in tqdm.tqdm(file_info.items(), desc="Creating directory nodes"): |
|
|
self.logger.info(f"Processing directory nodes for file: {file_path}") |
|
|
parts = os.path.normpath(file_path).split(os.sep) |
|
|
file_declared_entities = info['declared_entities'] |
|
|
file_called_entities = info['called_entities'] |
|
|
current_parent = 'root' |
|
|
path_accum = '' |
|
|
for part in parts[:-1]: |
|
|
path_accum = os.path.join(path_accum, part) if path_accum else part |
|
|
if path_accum not in self.graph: |
|
|
self.logger.info(f"Adding new directory node: {path_accum}") |
|
|
dir_node = DirectoryNode(id=path_accum, name=part, path=path_accum) |
|
|
self.graph.add_node(path_accum, data=dir_node, level=1) |
|
|
self.graph.add_edge(current_parent, path_accum, relation='contains') |
|
|
if path_accum not in dir_agg: |
|
|
dir_agg[path_accum] = {'declared_entities': [], 'called_entities': []} |
|
|
merge_entities(dir_agg[path_accum]['declared_entities'], file_declared_entities) |
|
|
merge_called_entities(dir_agg[path_accum]['called_entities'], file_called_entities) |
|
|
current_parent = path_accum |
|
|
|
|
|
self.graph.add_edge(current_parent, file_path, relation='contains') |
|
|
self.logger.info("All directory nodes created.") |
|
|
return dir_agg |
|
|
|
|
|
def _aggregate_to_root(self, dir_agg): |
|
|
self.logger.info("Aggregating directory information to root node.") |
|
|
""" |
|
|
Aggregate all directory entity information to the root node. |
|
|
|
|
|
Args: |
|
|
dir_agg (dict): Mapping dir_path -> aggregated entity info. |
|
|
""" |
|
|
|
|
|
def merge_entities(target, source): |
|
|
|
|
|
existing = set((e.get('name'), e.get('type')) for e in target) |
|
|
for e in source: |
|
|
k = (e.get('name'), e.get('type')) |
|
|
if k not in existing: |
|
|
target.append(e) |
|
|
existing.add(k) |
|
|
|
|
|
def merge_called_entities(target, source): |
|
|
|
|
|
existing = set(target) |
|
|
for e in source: |
|
|
if e not in existing: |
|
|
target.append(e) |
|
|
existing.add(e) |
|
|
|
|
|
root_node = Node(id='root', name='root', node_type='root') |
|
|
self.graph.add_node('root', data=root_node, level=0) |
|
|
root_declared_entities = [] |
|
|
root_called_entities = [] |
|
|
for dir_path, agg in tqdm.tqdm(dir_agg.items(), desc="Aggregating to root"): |
|
|
node = self.graph.nodes[dir_path]['data'] |
|
|
if not hasattr(node, 'declared_entities'): |
|
|
node.declared_entities = [] |
|
|
if not hasattr(node, 'called_entities'): |
|
|
node.called_entities = [] |
|
|
merge_entities(node.declared_entities, agg['declared_entities']) |
|
|
merge_called_entities(node.called_entities, agg['called_entities']) |
|
|
merge_entities(root_declared_entities, agg['declared_entities']) |
|
|
merge_called_entities(root_called_entities, agg['called_entities']) |
|
|
if not hasattr(root_node, 'declared_entities'): |
|
|
root_node.declared_entities = [] |
|
|
if not hasattr(root_node, 'called_entities'): |
|
|
root_node.called_entities = [] |
|
|
merge_entities(root_node.declared_entities, root_declared_entities) |
|
|
merge_called_entities(root_node.called_entities, root_called_entities) |
|
|
self.logger.info("Aggregation to root node complete.") |
|
|
|
|
|
def _build_relationships(self): |
|
|
self.logger.info("Building relationships between chunk nodes based on entities.") |
|
|
""" |
|
|
Build relationships between chunk nodes and entity nodes based on self.entities. |
|
|
For each entity in self.entities: |
|
|
1. Create an EntityNode with entity_name as the id |
|
|
2. Create edges from declaring chunks to entity node (declares relationship) |
|
|
3. Create edges from entity node to calling chunks (called_by relationship) |
|
|
4. Resolve called entity names using aliases for better matching |
|
|
""" |
|
|
from .Node import EntityNode |
|
|
edges_created = 0 |
|
|
entity_nodes_created = 0 |
|
|
|
|
|
|
|
|
self.logger.info("Building entity alias map for call resolution...") |
|
|
alias_map = build_entity_alias_map(self.entities) |
|
|
self.logger.info(f"Built alias map with {len(alias_map)} entries") |
|
|
|
|
|
|
|
|
for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Creating entity nodes"): |
|
|
|
|
|
entity_types = info.get('type', []) |
|
|
entity_type = entity_types[0] if entity_types else '' |
|
|
declaring_chunks = info.get('declaring_chunk_ids', []) |
|
|
calling_chunks = info.get('calling_chunk_ids', []) |
|
|
aliases = info.get('aliases', []) |
|
|
|
|
|
|
|
|
entity_node = EntityNode( |
|
|
id=entity_name, |
|
|
name=entity_name, |
|
|
entity_type=entity_type, |
|
|
declaring_chunk_ids=declaring_chunks, |
|
|
calling_chunk_ids=calling_chunks, |
|
|
aliases=aliases |
|
|
) |
|
|
|
|
|
|
|
|
self.graph.add_node(entity_name, data=entity_node, level=3) |
|
|
entity_nodes_created += 1 |
|
|
|
|
|
|
|
|
if aliases: |
|
|
self.logger.debug(f"Created EntityNode '{entity_name}' with aliases: {aliases}") |
|
|
|
|
|
|
|
|
for declarer_id in declaring_chunks: |
|
|
if declarer_id in self.graph: |
|
|
self.graph.add_edge(declarer_id, entity_name, relation='declares') |
|
|
edges_created += 1 |
|
|
|
|
|
|
|
|
for caller_id in calling_chunks: |
|
|
if caller_id in self.graph and caller_id not in declaring_chunks: |
|
|
self.graph.add_edge(entity_name, caller_id, relation='called_by') |
|
|
edges_created += 1 |
|
|
|
|
|
|
|
|
self.logger.info("Resolving entity calls using alias matching...") |
|
|
resolved_calls = 0 |
|
|
|
|
|
for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Resolving entity calls"): |
|
|
|
|
|
if info.get('declaring_chunk_ids'): |
|
|
continue |
|
|
|
|
|
|
|
|
resolved_name = resolve_entity_call(entity_name, alias_map) |
|
|
|
|
|
if resolved_name and resolved_name != entity_name: |
|
|
|
|
|
calling_chunks = info.get('calling_chunk_ids', []) |
|
|
|
|
|
if resolved_name in self.entities: |
|
|
for caller_id in calling_chunks: |
|
|
if caller_id in self.graph: |
|
|
|
|
|
if not self.graph.has_edge(resolved_name, caller_id): |
|
|
self.graph.add_edge(resolved_name, caller_id, relation='called_by') |
|
|
edges_created += 1 |
|
|
resolved_calls += 1 |
|
|
self.logger.debug(f"Resolved call: '{entity_name}' -> '{resolved_name}' in chunk {caller_id}") |
|
|
|
|
|
self.logger.info(f"_build_relationships: Created {entity_nodes_created} entity nodes, " |
|
|
f"{edges_created} edges, and resolved {resolved_calls} entity calls using aliases.") |
|
|
|
|
|
def get_entity_by_alias(self, alias: str) -> Optional[str]: |
|
|
""" |
|
|
Get the canonical entity name for a given alias. |
|
|
|
|
|
Args: |
|
|
alias: An alias of an entity (e.g., 'MyClass' or 'module.MyClass') |
|
|
|
|
|
Returns: |
|
|
Canonical entity name if found, None otherwise |
|
|
""" |
|
|
alias_map = build_entity_alias_map(self.entities) |
|
|
return alias_map.get(alias) |
|
|
|
|
|
def resolve_entity_references(self) -> Dict[str, List[str]]: |
|
|
""" |
|
|
Resolve all entity references in the knowledge graph using aliases. |
|
|
Returns a mapping of unresolved entity calls to their potential matches. |
|
|
|
|
|
Returns: |
|
|
Dictionary mapping called entity names to list of potential canonical matches |
|
|
""" |
|
|
alias_map = build_entity_alias_map(self.entities) |
|
|
resolutions = {} |
|
|
|
|
|
for entity_name, info in self.entities.items(): |
|
|
|
|
|
if not info.get('declaring_chunk_ids') and info.get('calling_chunk_ids'): |
|
|
resolved = resolve_entity_call(entity_name, alias_map) |
|
|
if resolved: |
|
|
resolutions[entity_name] = resolved |
|
|
|
|
|
return resolutions |
|
|
|
|
|
    def print_tree(self, max_depth=None, start_node_id='root', level=0, prefix=""):
        """
        Print the repository tree structure using the graph with 'contains' edges.

        Args:
            max_depth (int, optional): Maximum depth to print. None = unlimited.
            start_node_id (str): ID of the node to start from. Default is 'root'.
            level (int): Internal use only (used for recursion).
            prefix (str): Internal use only (used for formatting output).
        """
        # Depth guard: stop recursing past max_depth.
        if max_depth is not None and level > max_depth:
            self.logger.debug(f"Max depth {max_depth} reached at node {start_node_id}.")
            return

        if start_node_id not in self.graph:
            self.logger.warning(f"Start node '{start_node_id}' not found in graph.")
            return

        try:
            # self[...] reads the 'data' payload via __getitem__.
            node_data = self[start_node_id]
        except KeyError as e:
            # Repair path: node exists but has no 'data' payload. Synthesize a
            # placeholder Node from the id's shape, store it, and return
            # without printing this subtree on the current call.
            self.logger.error(f"KeyError when accessing node {start_node_id}: {e}")
            self.logger.error(f"Available node attributes: {list(self.graph.nodes[start_node_id].keys())}")

            if 'data' not in self.graph.nodes[start_node_id]:
                self.logger.warning(f"Node {start_node_id} has no 'data' attribute, using node itself")

                if start_node_id == 'root':
                    node_data = Node(id='root', name='root', node_type='root')

                    self.graph.nodes[start_node_id]['data'] = node_data
                else:
                    # Infer the node kind heuristically from the id:
                    # trailing "_<digits>" => chunk, a dot in the last path
                    # component => file, otherwise directory.
                    name = start_node_id.split('/')[-1] if '/' in start_node_id else start_node_id
                    if '_' in start_node_id and start_node_id.split('_')[-1].isdigit():
                        node_data = ChunkNode(id=start_node_id, name=name, node_type='chunk')
                    elif '.' in name:
                        node_data = FileNode(id=start_node_id, name=name, node_type='file', path=start_node_id)
                    else:
                        node_data = DirectoryNode(id=start_node_id, name=name, node_type='directory',
                                                  path=start_node_id)

                    self.graph.nodes[start_node_id]['data'] = node_data
            return

        # Pick a symbol per node type.
        # NOTE(review): these "π"/"π¦" literals look like mojibake of emoji
        # symbols (likely a broken encoding round-trip) — confirm against the
        # original file before changing them.
        if node_data.node_type == 'file':
            node_symbol = "π"
        elif node_data.node_type == 'chunk':
            node_symbol = "π"
        elif node_data.node_type == 'root':
            node_symbol = "π"
        elif node_data.node_type == 'directory':
            node_symbol = "π"
        else:
            node_symbol = "π¦"

        # The root of the printed tree gets no branch prefix.
        if level == 0:
            print(f"{node_symbol} {node_data.name} ({node_data.node_type})")
        else:
            print(f"{prefix}βββ {node_symbol} {node_data.name} ({node_data.node_type})")

        # Recurse only along 'contains' edges (not declares/called_by).
        children = [
            child for child in self.graph.successors(start_node_id)
            if self.graph.edges[start_node_id, child].get('relation') == 'contains'
        ]

        child_count = len(children)
        for i, child_id in enumerate(children):
            is_last = i == child_count - 1
            new_prefix = prefix + ("    " if is_last else "β   ")
            self.print_tree(max_depth, start_node_id=child_id, level=level + 1, prefix=new_prefix)
|
|
|
|
|
def to_dict(self): |
|
|
self.logger.info("Serializing graph to dictionary.") |
|
|
from .Node import EntityNode |
|
|
graph_data = { |
|
|
'nodes': [], |
|
|
'edges': [] |
|
|
} |
|
|
|
|
|
for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes"): |
|
|
if 'data' not in node_attrs: |
|
|
self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping in serialization") |
|
|
continue |
|
|
|
|
|
node = node_attrs['data'] |
|
|
node_dict = { |
|
|
'id': node.id or node_id, |
|
|
'class': node.__class__.__name__, |
|
|
'data': { |
|
|
'id': node.id or node_id, |
|
|
'name': node.name, |
|
|
'node_type': node.node_type, |
|
|
'description': getattr(node, 'description', ''), |
|
|
'declared_entities': list(getattr(node, 'declared_entities', [])), |
|
|
'called_entities': list(getattr(node, 'called_entities', [])), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if isinstance(node, FileNode): |
|
|
node_dict['data']['path'] = node.path |
|
|
node_dict['data']['content'] = node.content |
|
|
node_dict['data']['language'] = getattr(node, 'language', '') |
|
|
|
|
|
|
|
|
if isinstance(node, ChunkNode): |
|
|
node_dict['data']['order_in_file'] = getattr(node, 'order_in_file', 0) |
|
|
node_dict['data']['embedding'] = getattr(node, 'embedding', None) |
|
|
|
|
|
|
|
|
if isinstance(node, EntityNode): |
|
|
node_dict['data']['entity_type'] = getattr(node, 'entity_type', '') |
|
|
node_dict['data']['declaring_chunk_ids'] = list(getattr(node, 'declaring_chunk_ids', [])) |
|
|
node_dict['data']['calling_chunk_ids'] = list(getattr(node, 'calling_chunk_ids', [])) |
|
|
node_dict['data']['aliases'] = list(getattr(node, 'aliases', [])) |
|
|
|
|
|
graph_data['nodes'].append(node_dict) |
|
|
|
|
|
for u, v, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges"): |
|
|
edge_data = { |
|
|
'source': u, |
|
|
'target': v, |
|
|
'relation': attrs.get('relation', '') |
|
|
} |
|
|
if 'entities' in attrs: |
|
|
edge_data['entities'] = list(attrs['entities']) |
|
|
graph_data['edges'].append(edge_data) |
|
|
|
|
|
self.logger.info("Serialization complete.") |
|
|
return graph_data |
|
|
|
|
|
@classmethod |
|
|
def from_dict(cls, data_dict, index_nodes: bool = True, use_embed: bool = True, |
|
|
model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None): |
|
|
|
|
|
instance = cls.__new__(cls) |
|
|
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
|
|
|
instance.logger.info("Deserializing graph from dictionary.") |
|
|
|
|
|
|
|
|
node_classes = { |
|
|
'Node': Node, |
|
|
'FileNode': FileNode, |
|
|
'ChunkNode': ChunkNode, |
|
|
'DirectoryNode': DirectoryNode, |
|
|
'EntityNode': EntityNode, |
|
|
} |
|
|
|
|
|
|
|
|
root_found = any(node_data['id'] == 'root' for node_data in data_dict['nodes']) |
|
|
if not root_found: |
|
|
instance.logger.warning("Root node not found in the data, creating one") |
|
|
root_node = Node(id='root', name='root', node_type='root') |
|
|
instance.graph.add_node('root', data=root_node, level=0) |
|
|
|
|
|
|
|
|
for node_data in tqdm.tqdm(data_dict['nodes'], desc="Rebuilding nodes"): |
|
|
cls_name = node_data['class'] |
|
|
node_cls = node_classes.get(cls_name, Node) |
|
|
kwargs = node_data['data'] |
|
|
|
|
|
|
|
|
if not kwargs.get('id'): |
|
|
kwargs['id'] = node_data['id'] |
|
|
|
|
|
|
|
|
kwargs['declared_entities'] = list(kwargs.get('declared_entities', [])) |
|
|
kwargs['called_entities'] = list(kwargs.get('called_entities', [])) |
|
|
|
|
|
|
|
|
if node_cls in (FileNode, ChunkNode): |
|
|
kwargs.setdefault('path', '') |
|
|
kwargs.setdefault('content', '') |
|
|
kwargs.setdefault('language', '') |
|
|
if node_cls == ChunkNode: |
|
|
kwargs.setdefault('order_in_file', 0) |
|
|
kwargs.setdefault('embedding', []) |
|
|
|
|
|
if node_cls == EntityNode: |
|
|
kwargs.setdefault('entity_type', '') |
|
|
kwargs.setdefault('declaring_chunk_ids', []) |
|
|
kwargs.setdefault('calling_chunk_ids', []) |
|
|
kwargs.setdefault('aliases', []) |
|
|
|
|
|
node_instance = node_cls(**kwargs) |
|
|
instance.graph.add_node(node_data['id'], data=node_instance, level=instance._infer_level(node_instance)) |
|
|
|
|
|
|
|
|
for edge in tqdm.tqdm(data_dict['edges'], desc="Rebuilding edges"): |
|
|
source = edge['source'] |
|
|
target = edge['target'] |
|
|
if source in instance.graph and target in instance.graph: |
|
|
edge_kwargs = {'relation': edge.get('relation', '')} |
|
|
if 'entities' in edge: |
|
|
edge_kwargs['entities'] = list(edge['entities']) |
|
|
instance.graph.add_edge(source, target, **edge_kwargs) |
|
|
else: |
|
|
instance.logger.warning(f"Cannot add edge {source} -> {target}, nodes don't exist") |
|
|
|
|
|
|
|
|
instance.entities = {} |
|
|
for node_id, node_attrs in tqdm.tqdm(instance.graph.nodes(data=True), desc="Rebuilding entities"): |
|
|
node = node_attrs['data'] |
|
|
declared_entities = getattr(node, 'declared_entities', []) |
|
|
called_entities = getattr(node, 'called_entities', []) |
|
|
for entity in declared_entities: |
|
|
if isinstance(entity, dict): |
|
|
name = entity.get('name') |
|
|
else: |
|
|
name = entity |
|
|
if not name: |
|
|
continue |
|
|
if name not in instance.entities: |
|
|
instance.entities[name] = { |
|
|
"declaring_chunk_ids": [], |
|
|
"calling_chunk_ids": [], |
|
|
"type": [], |
|
|
"dtype": None |
|
|
} |
|
|
|
|
|
if node_id not in instance.entities[name]["declaring_chunk_ids"]: |
|
|
if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode): |
|
|
instance.entities[name]["declaring_chunk_ids"].append(node_id) |
|
|
if isinstance(entity, dict): |
|
|
entity_type = entity.get("type") |
|
|
if entity_type and entity_type not in instance.entities[name]["type"]: |
|
|
instance.entities[name]["type"].append(entity_type) |
|
|
dtype = entity.get("dtype") |
|
|
if dtype: |
|
|
instance.entities[name]["dtype"] = dtype |
|
|
for called_name in called_entities: |
|
|
if not called_name: |
|
|
continue |
|
|
if called_name not in instance.entities: |
|
|
instance.entities[called_name] = { |
|
|
"declaring_chunk_ids": [], |
|
|
"calling_chunk_ids": [], |
|
|
"type": [], |
|
|
"dtype": None |
|
|
} |
|
|
if node_id not in instance.entities[called_name]["calling_chunk_ids"]: |
|
|
if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode): |
|
|
instance.entities[called_name]["calling_chunk_ids"].append(node_id) |
|
|
|
|
|
if index_nodes: |
|
|
instance.logger.info("Building code index after deserialization.") |
|
|
|
|
|
code_idx_kwargs = code_index_kwargs or {} |
|
|
if 'use_embed' not in code_idx_kwargs: |
|
|
code_idx_kwargs['use_embed'] = use_embed |
|
|
instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **code_idx_kwargs) |
|
|
|
|
|
instance.logger.info("Deserialization complete.") |
|
|
return instance |
|
|
|
|
|
def _infer_level(self, node): |
|
|
"""Infer the level of a node based on its type""" |
|
|
if node.node_type == 'root': |
|
|
return 0 |
|
|
elif node.node_type in ('file', 'directory'): |
|
|
return 1 |
|
|
elif node.node_type == 'chunk': |
|
|
return 2 |
|
|
return 1 |
|
|
|
|
|
def save_graph_to_file(self, filepath: str): |
|
|
self.logger.info(f"Saving graph to file: {filepath}") |
|
|
with open(filepath, 'w') as f: |
|
|
json.dump(self.to_dict(), f, indent=2) |
|
|
self.logger.info("Graph saved successfully.") |
|
|
|
|
|
@classmethod |
|
|
def load_graph_from_file(cls, filepath: str, index_nodes=True, use_embed: bool = True, |
|
|
model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None): |
|
|
if model_service_kwargs is None: |
|
|
model_service_kwargs = {} |
|
|
with open(filepath, 'r') as f: |
|
|
data = json.load(f) |
|
|
logging.getLogger(LOGGER_NAME).info(f"Loaded graph data from file: {filepath}") |
|
|
return cls.from_dict(data, use_embed=use_embed, index_nodes=index_nodes, |
|
|
model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
|
|
|
    def to_hf_dataset(
        self,
        repo_id: str,
        save_embeddings: bool = True,
        private: bool = False,
        token: Optional[str] = None,
        commit_message: Optional[str] = None,
    ):
        """
        Save the knowledge graph to a HuggingFace dataset on the Hub.

        The graph is serialized into two splits:
        - 'nodes': Contains all node data
        - 'edges': Contains all edge relationships

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            save_embeddings (bool): If True, saves embedding vectors for chunk nodes.
                If False, embeddings are excluded to reduce dataset size.
            private (bool): Whether the dataset should be private. Defaults to False.
            token (str, optional): HuggingFace API token. If not provided, uses the token
                from huggingface_hub login or HF_TOKEN environment variable.
            commit_message (str, optional): Custom commit message for the upload.

        Returns:
            str: URL of the uploaded dataset
        """
        # HF libraries are optional dependencies; fail with an actionable hint.
        try:
            from datasets import Dataset, DatasetDict
            from huggingface_hub import HfApi
        except ImportError:
            raise ImportError(
                "huggingface_hub and datasets are required for HuggingFace integration. "
                "Install them with: pip install huggingface_hub datasets"
            )

        self.logger.info(f"Preparing to save knowledge graph to HuggingFace dataset: {repo_id}")
        self.logger.info(f"save_embeddings={save_embeddings}")

        # Flatten every node into a flat record. Nested lists are JSON-encoded
        # into strings and every record carries the full column set (with
        # empty placeholders where not applicable) so the dataset schema is
        # uniform across node classes.
        nodes_data = []
        for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes for HF dataset"):
            if 'data' not in node_attrs:
                self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping")
                continue

            node = node_attrs['data']
            node_record = {
                'node_id': node.id or node_id,
                'node_class': node.__class__.__name__,
                'name': node.name,
                'node_type': node.node_type,
                'description': getattr(node, 'description', '') or '',
                'declared_entities': json.dumps(list(getattr(node, 'declared_entities', []))),
                'called_entities': json.dumps(list(getattr(node, 'called_entities', []))),
            }

            # File-specific columns (empty strings for non-file nodes).
            if isinstance(node, FileNode):
                node_record['path'] = node.path
                node_record['content'] = node.content
                node_record['language'] = getattr(node, 'language', '')
            else:
                node_record['path'] = ''
                node_record['content'] = ''
                node_record['language'] = ''

            # Chunk-specific columns; order_in_file is -1 for non-chunks as a
            # sentinel, and embeddings are serialized only when requested.
            if isinstance(node, ChunkNode):
                node_record['order_in_file'] = getattr(node, 'order_in_file', 0)
                if save_embeddings:
                    embedding = getattr(node, 'embedding', None)
                    node_record['embedding'] = json.dumps(embedding if embedding is not None else [])
                else:
                    node_record['embedding'] = json.dumps([])
            else:
                node_record['order_in_file'] = -1
                node_record['embedding'] = json.dumps([])

            # Entity-specific columns (empty placeholders for non-entities).
            if isinstance(node, EntityNode):
                node_record['entity_type'] = getattr(node, 'entity_type', '')
                node_record['declaring_chunk_ids'] = json.dumps(list(getattr(node, 'declaring_chunk_ids', [])))
                node_record['calling_chunk_ids'] = json.dumps(list(getattr(node, 'calling_chunk_ids', [])))
                node_record['aliases'] = json.dumps(list(getattr(node, 'aliases', [])))
            else:
                node_record['entity_type'] = ''
                node_record['declaring_chunk_ids'] = json.dumps([])
                node_record['calling_chunk_ids'] = json.dumps([])
                node_record['aliases'] = json.dumps([])

            nodes_data.append(node_record)

        # Edges become their own config; 'entities' is JSON-encoded like above.
        edges_data = []
        for source, target, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges for HF dataset"):
            edge_record = {
                'source': source,
                'target': target,
                'relation': attrs.get('relation', ''),
                'entities': json.dumps(list(attrs.get('entities', []))) if 'entities' in attrs else json.dumps([])
            }
            edges_data.append(edge_record)

        nodes_dataset = Dataset.from_list(nodes_data)
        edges_dataset = Dataset.from_list(edges_data)

        self.logger.info(f"Created dataset with {len(nodes_data)} nodes and {len(edges_data)} edges")

        # Build a default commit message (noting excluded embeddings) unless
        # the caller supplied one.
        if commit_message is None:
            base_commit_message = f"Upload knowledge graph ({len(nodes_data)} nodes, {len(edges_data)} edges)"
            if not save_embeddings:
                base_commit_message += " [embeddings excluded]"
        else:
            base_commit_message = commit_message

        # Two pushes to the same repo, distinguished by config_name, so nodes
        # and edges can be loaded independently.
        self.logger.info(f"Pushing nodes dataset to HuggingFace Hub: {repo_id}")
        nodes_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="nodes",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - nodes"
        )

        self.logger.info(f"Pushing edges dataset to HuggingFace Hub: {repo_id}")
        edges_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="edges",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - edges"
        )

        url = f"https://huggingface.co/datasets/{repo_id}"
        self.logger.info(f"Dataset successfully uploaded to: {url}")
        return url
|
|
|
|
|
    @classmethod
    def from_hf_dataset(
        cls,
        repo_id: str,
        index_nodes: bool = True,
        use_embed: bool = True,
        model_service_kwargs: Optional[dict] = None,
        code_index_kwargs: Optional[dict] = None,
        token: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        """
        Load a knowledge graph from a HuggingFace dataset on the Hub.

        Inverse of `to_hf_dataset`: downloads the 'nodes' and 'edges' configs,
        reassembles the `to_dict()`-style dictionary, then delegates to
        `from_dict` to rebuild the graph.

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            index_nodes (bool): Whether to build a code index after loading. Defaults to True.
            use_embed (bool): Whether to use existing embeddings from the dataset. Defaults to True.
            model_service_kwargs (dict, optional): Arguments for the model service.
            code_index_kwargs (dict, optional): Arguments for the code index.
            token (str, optional): HuggingFace API token for private datasets.
            revision (str, optional): Git revision (branch, tag, or commit) to load from.

        Returns:
            RepoKnowledgeGraph: The loaded knowledge graph instance.
        """
        # The datasets library is an optional dependency.
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError(
                "datasets library is required for HuggingFace integration. "
                "Install it with: pip install datasets"
            )

        if model_service_kwargs is None:
            model_service_kwargs = {}

        logger = logging.getLogger(LOGGER_NAME)
        logger.info(f"Loading knowledge graph from HuggingFace dataset: {repo_id}")

        # The two configs mirror the two push_to_hub calls in to_hf_dataset.
        logger.info("Loading nodes config...")
        nodes_dataset = load_dataset(repo_id, name="nodes", token=token, revision=revision)
        logger.info("Loading edges config...")
        edges_dataset = load_dataset(repo_id, name="edges", token=token, revision=revision)

        # push_to_hub stores everything under the default 'train' split.
        nodes_data = nodes_dataset['train']
        edges_data = edges_dataset['train']

        logger.info(f"Loaded {len(nodes_data)} nodes and {len(edges_data)} edges from dataset")

        # Reassemble the dictionary shape expected by from_dict().
        graph_data = {
            'nodes': [],
            'edges': []
        }

        for record in tqdm.tqdm(nodes_data, desc="Reconstructing nodes from HF dataset"):
            # JSON-encoded list columns are decoded back into Python lists.
            node_dict = {
                'id': record['node_id'],
                'class': record['node_class'],
                'data': {
                    'id': record['node_id'],
                    'name': record['name'],
                    'node_type': record['node_type'],
                    'description': record['description'],
                    'declared_entities': json.loads(record['declared_entities']),
                    'called_entities': json.loads(record['called_entities']),
                }
            }

            # File/chunk columns exist for every record but are only
            # meaningful for the matching node class.
            if record['node_class'] in ('FileNode', 'ChunkNode'):
                node_dict['data']['path'] = record['path']
                node_dict['data']['content'] = record['content']
                node_dict['data']['language'] = record['language']

            if record['node_class'] == 'ChunkNode':
                node_dict['data']['order_in_file'] = record['order_in_file']
                embedding = json.loads(record['embedding'])

                # Stored embeddings are honored only when use_embed is set
                # and the serialized list is non-empty.
                if use_embed and embedding:
                    node_dict['data']['embedding'] = embedding
                else:
                    node_dict['data']['embedding'] = []

            if record['node_class'] == 'EntityNode':
                node_dict['data']['entity_type'] = record['entity_type']
                node_dict['data']['declaring_chunk_ids'] = json.loads(record['declaring_chunk_ids'])
                node_dict['data']['calling_chunk_ids'] = json.loads(record['calling_chunk_ids'])
                node_dict['data']['aliases'] = json.loads(record['aliases'])

            graph_data['nodes'].append(node_dict)

        for record in tqdm.tqdm(edges_data, desc="Reconstructing edges from HF dataset"):
            edge_dict = {
                'source': record['source'],
                'target': record['target'],
                'relation': record['relation'],
            }
            # The 'entities' key is only restored when the list is non-empty,
            # matching to_dict()'s optional-key behavior.
            entities = json.loads(record['entities'])
            if entities:
                edge_dict['entities'] = entities

            graph_data['edges'].append(edge_dict)

        logger.info("Dataset reconstruction complete, building graph...")

        return cls.from_dict(
            graph_data,
            index_nodes=index_nodes,
            use_embed=use_embed,
            model_service_kwargs=model_service_kwargs,
            code_index_kwargs=code_index_kwargs
        )
|
|
|
|
|
def get_neighbors(self, node_id): |
|
|
self.logger.debug(f"Getting neighbors for node: {node_id}") |
|
|
|
|
|
neighbors = set() |
|
|
for n in self.graph.successors(node_id): |
|
|
neighbors.add(n) |
|
|
for n in self.graph.predecessors(node_id): |
|
|
neighbors.add(n) |
|
|
|
|
|
for u, v in self.graph.edges(node_id): |
|
|
if u == node_id: |
|
|
neighbors.add(v) |
|
|
else: |
|
|
neighbors.add(u) |
|
|
for u, v in self.graph.in_edges(node_id): |
|
|
if v == node_id: |
|
|
neighbors.add(u) |
|
|
else: |
|
|
neighbors.add(v) |
|
|
return [self.graph.nodes[n]['data'] for n in neighbors if 'data' in self.graph.nodes[n]] |
|
|
|
|
|
def get_previous_chunk(self, node_id: str) -> ChunkNode: |
|
|
self.logger.debug(f"Getting previous chunk for node: {node_id}") |
|
|
node = self[node_id] |
|
|
|
|
|
if not isinstance(node, ChunkNode): |
|
|
raise Exception(f'Cannot get previous chunk on node of type {type(node)}') |
|
|
|
|
|
if node.order_in_file == 0: |
|
|
self.logger.warning(f'Cannot get previous chunk for first node') |
|
|
return None |
|
|
|
|
|
file_path = node.path |
|
|
previous_chunk_id = f'{file_path}_{node.order_in_file - 1}' |
|
|
|
|
|
if previous_chunk_id not in self.graph: |
|
|
raise Exception(f'Previous chunk {previous_chunk_id} not found in graph') |
|
|
|
|
|
previous_chunk = self[previous_chunk_id] |
|
|
return previous_chunk |
|
|
|
|
|
def get_next_chunk(self, node_id: str) -> ChunkNode: |
|
|
self.logger.debug(f"Getting next chunk for node: {node_id}") |
|
|
node = self[node_id] |
|
|
|
|
|
if not isinstance(node, ChunkNode): |
|
|
raise Exception(f'Cannot get previous chunk on node of type {type(node)}') |
|
|
|
|
|
file_path = node.path |
|
|
next_chunk_id = f'{file_path}_{node.order_in_file + 1}' |
|
|
|
|
|
if next_chunk_id not in self.graph: |
|
|
self.logger.warning(f'Next chunk {next_chunk_id} not found in graph, it might be the last chunk') |
|
|
return None |
|
|
previous_chunk = self[next_chunk_id] |
|
|
return previous_chunk |
|
|
|
|
|
def get_all_chunks(self) -> List[ChunkNode]: |
|
|
self.logger.debug("Getting all chunk nodes.") |
|
|
chunk_nodes = [] |
|
|
for node in self: |
|
|
if isinstance(node, ChunkNode): |
|
|
chunk_nodes.append(node) |
|
|
return chunk_nodes |
|
|
|
|
|
def get_all_files(self) -> List[FileNode]: |
|
|
self.logger.debug("Getting all file nodes.") |
|
|
""" |
|
|
Get all FileNodes in the knowledge graph. |
|
|
|
|
|
Returns: |
|
|
List[FileNode]: A list of FileNodes in the graph. |
|
|
""" |
|
|
file_nodes = [] |
|
|
for node in self.graph.nodes(data=True): |
|
|
node_data = node[1]['data'] |
|
|
|
|
|
if isinstance(node_data, FileNode) and node_data.node_type == 'file': |
|
|
file_nodes.append(node_data) |
|
|
return file_nodes |
|
|
|
|
|
def get_chunks_of_file(self, file_node_id: str) -> List[ChunkNode]: |
|
|
self.logger.debug(f"Getting chunks for file node: {file_node_id}") |
|
|
""" |
|
|
Get all ChunkNodes associated with a specific FileNode. |
|
|
|
|
|
Args: |
|
|
file_node (FileNode): The file node to get chunks for. |
|
|
|
|
|
Returns: |
|
|
List[ChunkNode]: A list of ChunkNodes associated with the file. |
|
|
""" |
|
|
chunk_nodes = [] |
|
|
for node in self.graph.neighbors(file_node_id): |
|
|
|
|
|
edge_data = self.graph.get_edge_data(file_node_id, node) |
|
|
node_data = self.graph.nodes[node]['data'] |
|
|
if ( |
|
|
isinstance(node_data, ChunkNode) |
|
|
and node_data.node_type == 'chunk' |
|
|
and edge_data is not None |
|
|
and edge_data.get('relation') == 'contains' |
|
|
): |
|
|
chunk_nodes.append(node_data) |
|
|
return chunk_nodes |
|
|
|
|
|
    def find_path(self, source_id: str, target_id: str, max_depth: int = 5) -> dict:
        """
        Find the shortest path between two nodes in the knowledge graph.

        Args:
            source_id (str): The ID of the source node.
            target_id (str): The ID of the target node.
            max_depth (int): Maximum depth to search for a path. Defaults to 5.

        Returns:
            dict: A dictionary containing path information or error message.
        """
        self.logger.debug(f"Finding path from {source_id} to {target_id} with max_depth={max_depth}")
        g = self.graph

        # Validate endpoints up front so nx.shortest_path cannot raise
        # NodeNotFound; missing nodes produce an {'error': ...} dict instead.
        if source_id not in g:
            return {"error": f"Source node '{source_id}' not found."}
        if target_id not in g:
            return {"error": f"Target node '{target_id}' not found."}

        try:
            path = nx.shortest_path(g, source=source_id, target=target_id)

            # A path longer than max_depth is reported (with its true length)
            # but its steps are not expanded.
            if len(path) - 1 > max_depth:
                return {
                    "source_id": source_id,
                    "target_id": target_id,
                    "path": [],
                    "length": len(path) - 1,
                    "text": f"Path exists but exceeds max_depth of {max_depth} (actual length: {len(path) - 1})"
                }

            # Expand each hop into a record with name/type and the relation
            # label of the edge leading to the next hop.
            path_details = []
            for i, node_id in enumerate(path):
                node = g.nodes[node_id]['data']
                node_info = {
                    "node_id": node_id,
                    "name": getattr(node, 'name', 'Unknown'),
                    "type": getattr(node, 'node_type', 'Unknown'),
                    "step": i
                }

                # All but the final node carry the relation to their successor.
                if i < len(path) - 1:
                    next_node_id = path[i + 1]
                    edge_data = g.get_edge_data(node_id, next_node_id)
                    node_info["edge_to_next"] = edge_data.get('relation', 'Unknown') if edge_data else 'Unknown'

                path_details.append(node_info)

            # Human-readable rendering of the same path for display/logging.
            text = f"Path from '{source_id}' to '{target_id}' (length: {len(path) - 1}):\n\n"
            for i, node_info in enumerate(path_details):
                text += f"{i}. {node_info['name']} ({node_info['type']})\n"
                text += f"   Node ID: {node_info['node_id']}\n"
                if 'edge_to_next' in node_info:
                    text += f"   --[{node_info['edge_to_next']}]--> \n"

            return {
                "source_id": source_id,
                "target_id": target_id,
                "path": path_details,
                "length": len(path) - 1,
                "text": text
            }

        # Disconnected nodes: report via length == -1 rather than raising.
        except nx.NetworkXNoPath:
            return {
                "source_id": source_id,
                "target_id": target_id,
                "path": [],
                "length": -1,
                "text": f"No path found between '{source_id}' and '{target_id}'"
            }
        # Boundary catch-all: log and surface the error as data.
        except Exception as e:
            self.logger.error(f"Error finding path: {str(e)}")
            return {"error": f"Error finding path: {str(e)}"}
|
|
|
|
|
    def get_subgraph(self, node_id: str, depth: int = 2, edge_types: Optional[List[str]] = None) -> dict:
        """
        Extract a subgraph around a node up to a specified depth.

        Args:
            node_id (str): The ID of the central node.
            depth (int): The depth/radius of the subgraph to extract. Defaults to 2.
            edge_types (Optional[List[str]]): Optional list of edge types to include (e.g., ['calls', 'contains']).

        Returns:
            dict: A dictionary containing subgraph information or error message.
        """
        self.logger.debug(f"Getting subgraph for node {node_id} with depth={depth}, edge_types={edge_types}")
        g = self.graph

        if node_id not in g:
            return {"error": f"Node '{node_id}' not found."}

        # Breadth-first expansion: nodes_at_depth is the current frontier,
        # all_nodes accumulates everything visited so far.
        nodes_at_depth = {node_id}
        all_nodes = {node_id}

        for d in range(depth):
            next_level = set()
            for n in nodes_at_depth:
                # Outgoing neighbors, optionally filtered by relation type.
                for neighbor in g.successors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(n, neighbor)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

                # Incoming neighbors, with the same optional filter.
                for neighbor in g.predecessors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(neighbor, n)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

            # Next frontier excludes anything already visited.
            nodes_at_depth = next_level - all_nodes
            all_nodes.update(next_level)

        # Materialize an independent copy of the induced subgraph.
        subgraph = g.subgraph(all_nodes).copy()

        # Flatten nodes into id/name/type records for the result payload.
        nodes = []
        for n in subgraph.nodes():
            node = subgraph.nodes[n]['data']
            nodes.append({
                "node_id": n,
                "name": getattr(node, 'name', 'Unknown'),
                "type": getattr(node, 'node_type', 'Unknown')
            })

        # Flatten edges into source/target/relation records.
        edges = []
        for source, target, data in subgraph.edges(data=True):
            edges.append({
                "source": source,
                "target": target,
                "relation": data.get('relation', 'Unknown')
            })

        # Build the human-readable summary used by the 'text' field.
        text = f"Subgraph around '{node_id}' (depth: {depth}):\n"
        if edge_types:
            text += f"Edge types filter: {', '.join(edge_types)}\n"
        text += f"\nNodes: {len(nodes)}\n"
        text += f"Edges: {len(edges)}\n\n"

        # Group nodes by type for the summary (and the returned mapping).
        nodes_by_type = {}
        for node in nodes:
            node_type = node['type']
            if node_type not in nodes_by_type:
                nodes_by_type[node_type] = []
            nodes_by_type[node_type].append(node)

        # Show at most 5 example nodes per type in the text summary.
        for node_type, type_nodes in nodes_by_type.items():
            text += f"{node_type} ({len(type_nodes)}):\n"
            for node in type_nodes[:5]:
                text += f"  - {node['name']} ({node['node_id']})\n"
            if len(type_nodes) > 5:
                text += f"  ... and {len(type_nodes) - 5} more\n"
            text += "\n"

        # Tally edges per relation type for the summary.
        edge_by_relation = {}
        for edge in edges:
            relation = edge['relation']
            edge_by_relation[relation] = edge_by_relation.get(relation, 0) + 1

        if edge_by_relation:
            text += "Edge types:\n"
            for relation, count in edge_by_relation.items():
                text += f"  - {relation}: {count}\n"

        return {
            "center_node_id": node_id,
            "depth": depth,
            "edge_types_filter": edge_types,
            "node_count": len(nodes),
            "edge_count": len(edges),
            "nodes": nodes,
            "edges": edges,
            "nodes_by_type": nodes_by_type,
            "edge_by_relation": edge_by_relation,
            "text": text
        }
|
|
|