upload repoknowledgegraphlib
- RepoKnowledgeGraphLib/CodeIndex.py +571 -0
- RepoKnowledgeGraphLib/CodeParser.py +70 -0
- RepoKnowledgeGraphLib/Entity.py +188 -0
- RepoKnowledgeGraphLib/EntityChunkMapper.py +517 -0
- RepoKnowledgeGraphLib/EntityExtractor.py +2032 -0
- RepoKnowledgeGraphLib/KnowledgeGraphMCPServer.py +1107 -0
- RepoKnowledgeGraphLib/ModelService.py +424 -0
- RepoKnowledgeGraphLib/Node.py +63 -0
- RepoKnowledgeGraphLib/QuestionMaker.py +538 -0
- RepoKnowledgeGraphLib/RepoKnowledgeGraph.py +1608 -0
- RepoKnowledgeGraphLib/__init__.py +5 -0
- RepoKnowledgeGraphLib/utils/__init__.py +0 -0
- RepoKnowledgeGraphLib/utils/chunk_utils.py +88 -0
- RepoKnowledgeGraphLib/utils/data_utils.py +18 -0
- RepoKnowledgeGraphLib/utils/logger_utils.py +74 -0
- RepoKnowledgeGraphLib/utils/parsing_utils.py +65 -0
- RepoKnowledgeGraphLib/utils/path_utils.py +308 -0
RepoKnowledgeGraphLib/CodeIndex.py
ADDED
@@ -0,0 +1,571 @@
import logging
from tqdm import tqdm
import uuid
from typing import Literal
from abc import ABC, abstractmethod
import os
import numpy as np
import weaviate
from weaviate.classes.config import Configure, Property, DataType
from weaviate.classes.query import MetadataQuery

# Import lancedb inside the guard so LANCEDB_AVAILABLE is actually False when
# the package is missing (a bare flag assignment could never raise ImportError).
try:
    import lancedb
    LANCEDB_AVAILABLE = True
except ImportError:
    LANCEDB_AVAILABLE = False

from .utils.logger_utils import setup_logger

LOGGER_NAME = 'CODE_INDEX_LOGGER'
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
MODEL_ID = os.getenv("MODEL_ID")
MAX_TOKENS = int(os.getenv('MAX_TOKENS', 2048))
TEMPERATURE = float(os.getenv('TEMPERATURE', 0.2))
TOP_P = float(os.getenv('TOP_P', 0.95))
FREQUENCY_PENALTY = 0
PRESENCE_PENALTY = 0
STOP = None
EMBEDDING_MODEL_URL = os.getenv('EMBEDDING_MODEL_URL')
EMBEDDING_MODEL_API_KEY = os.getenv('EMBEDDING_MODEL_API_KEY', "no_need")
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv('EMBEDDING_NUMBER_DIMENSIONS', 1024))

WEAVIATE_HOST = os.getenv('WEAVIATE_HOST', "localhost")
WEAVIATE_PORT = int(os.getenv('WEAVIATE_PORT', 8080))
WEAVIATE_GRPC_PORT = int(os.getenv('WEAVIATE_GRPC_PORT', 50051))
ALPHA_SEARCH_VALUE = float(os.getenv('ALPHA_SEARCH_VALUE', 0.8))
LANCEDB_PATH = os.getenv('LANCEDB_PATH', './local_code_index_db')


class BaseCodeIndex(ABC):
    """Abstract base class for code indexing implementations"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 64, use_embed: bool = True):
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)
        self.model_service = model_service
        self.index_type = index_type
        # Use a larger batch size by default for better throughput
        self.embedding_batch_size = int(os.getenv('EMBEDDING_BATCH_SIZE', embedding_batch_size))
        self.use_embed = use_embed
        self.logger.info(f"CodeIndex initialized with batch_size={self.embedding_batch_size}, index_type={index_type}")

    @abstractmethod
    def query(self, query: str, n_results: int = 10) -> dict:
        """Query the index and return results"""
        pass

    @abstractmethod
    def __del__(self):
        """Clean up resources"""
        pass


class WeaviateCodeIndex(BaseCodeIndex):
    """Weaviate-based code index implementation"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 20, use_embed: bool = True,
                 host: str = None, port: int = None, grpc_port: int = None):
        super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed)

        # Use provided parameters or fall back to environment variables
        weaviate_host = host or WEAVIATE_HOST
        weaviate_port = port or WEAVIATE_PORT
        weaviate_grpc_port = grpc_port or WEAVIATE_GRPC_PORT

        # Connect to Weaviate
        self.weaviate_client = weaviate.connect_to_local(
            host=weaviate_host,
            port=weaviate_port,
            grpc_port=weaviate_grpc_port
        )

        # Create a unique collection name
        self.collection_name = f"CodeChunks_{str(uuid.uuid4()).replace('-', '_')}"

        # Create the collection with a schema using the v4 API. We provide our
        # own vectors via the modern vector_config=Configure.Vectors.self_provided().
        self.collection = self.weaviate_client.collections.create(
            name=self.collection_name,
            properties=[
                Property(name="node_id", data_type=DataType.TEXT),
                Property(name="name", data_type=DataType.TEXT),
                Property(name="content", data_type=DataType.TEXT),
                Property(name="description", data_type=DataType.TEXT),
                Property(name="path", data_type=DataType.TEXT),
                Property(name="language", data_type=DataType.TEXT),
                Property(name="node_type", data_type=DataType.TEXT),
                Property(name="order_in_file", data_type=DataType.INT),
                Property(name="declared_entities", data_type=DataType.TEXT),
                Property(name="called_entities", data_type=DataType.TEXT),
            ],
            vector_config=Configure.Vectors.self_provided(),
        )

        chunk_nodes = [node for node in nodes if node.node_type == 'chunk']
        self.logger.info(f"Weaviate indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}")

        # Pre-generate embeddings in batches for better performance
        if self.index_type != 'keyword-only':
            # Identify nodes that need embeddings
            nodes_needing_embeddings = [
                node for node in chunk_nodes
                if node.embedding is None
                or (isinstance(node.embedding, (list,)) and len(node.embedding) == 0)
                or not use_embed
            ]

            if nodes_needing_embeddings:
                total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size
                self.logger.info(f'Batch embedding {len(nodes_needing_embeddings)} nodes in {total_batches} batches')

                # Process in batches
                for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size),
                              desc="Batch embedding nodes"):
                    batch_nodes = nodes_needing_embeddings[i:i + self.embedding_batch_size]
                    texts_to_embed = [node.get_field_to_embed() for node in batch_nodes]

                    # Batch embed all texts
                    embeddings = self.model_service.embed_chunk_code_batch(texts_to_embed)

                    # Assign embeddings back to nodes
                    for node, embedding in zip(batch_nodes, embeddings):
                        node.embedding = embedding

                    # Log progress every 10 batches
                    batch_num = i // self.embedding_batch_size + 1
                    if batch_num % 10 == 0:
                        self.logger.info(f"Completed batch {batch_num}/{total_batches}")

                self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} nodes")
            else:
                self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} nodes")

        # Batch insert data into Weaviate
        with self.collection.batch.dynamic() as batch:
            for node in tqdm(chunk_nodes, desc="Indexing nodes"):
                self.logger.debug(f'Indexing node: {node.id}')

                # Use the pre-computed embedding
                embedding = None
                if self.index_type != 'keyword-only':
                    embedding = node.embedding

                # Prepare properties
                properties = {
                    "node_id": node.id,
                    "name": node.name,
                    "content": node.content,
                    "description": node.description or "",
                    "path": node.path,
                    "language": node.language,
                    "node_type": node.node_type,
                    "order_in_file": node.order_in_file,
                    "declared_entities": str(node.declared_entities),
                    "called_entities": str(node.called_entities),
                }

                # Add the object with or without a vector based on index_type
                if self.index_type == 'keyword-only':
                    # No vector needed for keyword-only search
                    batch.add_object(properties=properties)
                else:
                    # Add with a vector for embedding-only and hybrid modes
                    batch.add_object(
                        properties=properties,
                        vector=embedding
                    )

    def query(self, query: str, n_results: int = 10) -> dict:
        """
        Perform search based on index_type:
        - 'embedding-only': pure vector search
        - 'keyword-only': pure keyword search (BM25)
        - 'hybrid': hybrid search combining both (alpha controls weighting)

        Weaviate's hybrid search uses:
        - alpha=0: pure keyword search (BM25)
        - alpha=1: pure vector search
        - alpha=0.5-0.8: balanced hybrid search (recommended)
        """
        try:
            # Execute search based on index_type
            if self.index_type == 'keyword-only':
                # Pure BM25 keyword search
                response = self.collection.query.bm25(
                    query=query,
                    limit=n_results,
                    return_metadata=MetadataQuery(score=True)
                )
            elif self.index_type == 'embedding-only':
                # Pure vector search
                embedding = self.model_service.embed_query(query)
                response = self.collection.query.near_vector(
                    near_vector=embedding,
                    limit=n_results,
                    return_metadata=MetadataQuery(distance=True)
                )
            else:  # 'hybrid'
                # Hybrid search combining keyword and vector
                embedding = self.model_service.embed_query(query)
                response = self.collection.query.hybrid(
                    query=query,
                    vector=embedding,
                    limit=n_results,
                    alpha=ALPHA_SEARCH_VALUE,
                    return_metadata=MetadataQuery(distance=True, score=True)
                )

            # Convert to a ChromaDB-like format for compatibility
            results = {
                'ids': [[]],
                'distances': [[]],
                'metadatas': [[]],
                'documents': [[]]
            }

            for obj in response.objects:
                results['ids'][0].append(obj.properties['node_id'])
                results['distances'][0].append(obj.metadata.distance if obj.metadata.distance else 0.0)
                results['metadatas'][0].append({
                    'id': obj.properties['node_id'],
                    'name': obj.properties['name'],
                    'content': obj.properties['content'],
                    'description': obj.properties['description'],
                    'path': obj.properties['path'],
                    'language': obj.properties['language'],
                    'node_type': obj.properties['node_type'],
                    'order_in_file': str(obj.properties['order_in_file']),
                    'declared_entities': obj.properties['declared_entities'],
                    'called_entities': obj.properties['called_entities'],
                })
                results['documents'][0].append(obj.properties['content'])

            return results

        except Exception as e:
            self.logger.error(f'Failed to query: {e}', exc_info=True)
            raise

    def __del__(self):
        """Clean up the Weaviate connection"""
        if hasattr(self, 'weaviate_client'):
            try:
                self.weaviate_client.close()
            except Exception:
                pass


class LanceDBCodeIndex(BaseCodeIndex):
    """LanceDB-based code index implementation"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 20, use_embed: bool = True, db_path: str = None):
        super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed)

        if not LANCEDB_AVAILABLE:
            raise ImportError("LanceDB is not available. Please install it with: pip install lancedb")

        # Embedded DB
        self.db_path = db_path or LANCEDB_PATH
        self.db = lancedb.connect(self.db_path)
        self.table_name = f"code_chunks_{uuid.uuid4().hex}"
        self.table = None

        chunk_nodes = [node for node in nodes if node.node_type == "chunk"]
        self.logger.info(f"LanceDB indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}")

        # -----------------------------------------------------------
        # Create embeddings IF using vector search
        # -----------------------------------------------------------
        if self.index_type != "keyword-only":
            # Find nodes that need embeddings.
            # use_embed=True means we should USE existing embeddings if available;
            # use_embed=False means we should regenerate all embeddings.
            nodes_needing_embeddings = []
            for node in chunk_nodes:
                needs_embedding = False
                if not use_embed:
                    # Regenerate all embeddings
                    needs_embedding = True
                elif node.embedding is None:
                    needs_embedding = True
                elif isinstance(node.embedding, (list, np.ndarray)) and len(node.embedding) == 0:
                    needs_embedding = True

                if needs_embedding:
                    nodes_needing_embeddings.append(node)

            if nodes_needing_embeddings:
                total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size
                self.logger.info(f"Embedding {len(nodes_needing_embeddings)} chunks in {total_batches} batches (batch_size={self.embedding_batch_size})...")

                for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size),
                              desc="Batch embedding nodes"):
                    batch = nodes_needing_embeddings[i:i + self.embedding_batch_size]
                    texts = [n.get_field_to_embed() for n in batch]
                    batch_embeds = self.model_service.embed_chunk_code_batch(texts)

                    for n, emb in zip(batch, batch_embeds):
                        n.embedding = np.array(emb, dtype=np.float32)

                    # Log progress every 10 batches
                    batch_num = i // self.embedding_batch_size + 1
                    if batch_num % 10 == 0:
                        self.logger.info(f"Completed batch {batch_num}/{total_batches}")

                self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} chunks")
            else:
                self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} chunks")

        # -----------------------------------------------------------
        # Prepare rows (only include the vector column when allowed)
        # -----------------------------------------------------------
        rows = []
        for node in chunk_nodes:
            row = {
                "node_id": node.id,
                "name": node.name,
                "content": node.content,
                "description": node.description or "",
                "path": node.path,
                "language": node.language,
                "node_type": node.node_type,
                "order_in_file": node.order_in_file,
                "declared_entities": str(node.declared_entities),
                "called_entities": str(node.called_entities),
            }

            # Add embeddings only for hybrid/embedding-only
            if self.index_type != "keyword-only":
                row["vector"] = node.embedding

            rows.append(row)

        # Create the table
        self.table = self.db.create_table(self.table_name, data=rows)
        self.logger.info(f"Created LanceDB table: {self.table_name}")

        # Create a full-text search index for keyword and hybrid search.
        # LanceDB requires an INVERTED index for full-text search.
        self._create_fts_indexes()

    def _create_fts_indexes(self):
        """
        Create full-text search indexes on text columns.

        LanceDB 0.25.x uses create_fts_index() with use_tantivy=True to support
        multiple columns. Requires the tantivy package: pip install tantivy
        """
        fts_columns = ["content", "name", "description"]

        try:
            # use_tantivy=True is required to support multiple field names as a list
            self.table.create_fts_index(fts_columns, replace=True, use_tantivy=True)
            self.logger.info(f"Created FTS index (Tantivy) on columns: {fts_columns}")
        except Exception as e:
            self.logger.warning(f"Failed to create FTS index: {e}")
            self.logger.warning(
                "Full-text search will fall back to scanning. "
                "Ensure tantivy is installed: pip install tantivy"
            )

    def query(self, query: str, n_results: int = 10) -> dict:
        """
        Perform search based on index_type:
        - 'embedding-only': pure vector search
        - 'keyword-only': full-text search using LanceDB's native FTS
        - 'hybrid': combines vector similarity and full-text search with reranking
        """
        try:
            # ---------------------- KEYWORD ONLY ----------------------
            if self.index_type == "keyword-only":
                # Use LanceDB full-text search (requires an FTS index on the table)
                try:
                    # Try full-text search first
                    df = self.table.search(query, query_type="fts").limit(n_results).to_pandas()
                except Exception as fts_error:
                    self.logger.warning(f"FTS search failed, falling back to scan: {fts_error}")
                    # Fallback: scan all rows and filter in Python
                    all_df = self.table.to_pandas()
                    query_lower = query.lower()
                    # Split the query into words for more flexible matching
                    query_words = query_lower.split()

                    def matches_query(row):
                        text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower()
                        # Match if any query word is found
                        return any(word in text for word in query_words)

                    mask = all_df.apply(matches_query, axis=1)
                    df = all_df[mask].head(n_results)
                    # Add a dummy distance column
                    df = df.copy()
                    df['_distance'] = 0.0

            # ---------------------- VECTOR ONLY -----------------------
            elif self.index_type == "embedding-only":
                emb = np.array(self.model_service.embed_query(query), dtype=np.float32)
                df = self.table.search(
                    emb,
                    vector_column_name="vector"
                ).limit(n_results).to_pandas()

            # ---------------------- HYBRID ----------------------------
            else:
                # For hybrid search, we do vector search and optionally boost results
                # that also match keywords. This is more flexible than requiring both.
                emb = np.array(self.model_service.embed_query(query), dtype=np.float32)

                # Get more results from vector search to allow for reranking
                vector_limit = min(n_results * 3, 100)  # 3x results for reranking
                df = self.table.search(
                    emb,
                    vector_column_name="vector"
                ).limit(vector_limit).to_pandas()

                if not df.empty:
                    # Rerank results based on keyword matches
                    query_lower = query.lower()
                    query_words = query_lower.split()

                    def compute_keyword_score(row):
                        """Compute a keyword match score (higher is better)"""
                        text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower()
                        score = 0
                        # An exact phrase match gets the highest score
                        if query_lower in text:
                            score += 10
                        # Word matches
                        for word in query_words:
                            if word in text:
                                score += 1
                            # Bonus for a word in the name (more relevant)
                            if word in str(row.get('name', '')).lower():
                                score += 2
                        return score

                    # Add keyword scores
                    df = df.copy()
                    df['_keyword_score'] = df.apply(compute_keyword_score, axis=1)

                    # Normalize distance to a similarity score (lower distance = higher similarity)
                    max_dist = df['_distance'].max() if df['_distance'].max() > 0 else 1.0
                    df['_vector_score'] = 1.0 - (df['_distance'] / max_dist)

                    # Combined score: weighted sum of vector similarity and keyword score.
                    # Alpha controls the balance (higher alpha = more weight on vector search).
                    alpha = 0.7  # 70% vector, 30% keyword
                    max_keyword = df['_keyword_score'].max() if df['_keyword_score'].max() > 0 else 1.0
                    df['_combined_score'] = (
                        alpha * df['_vector_score'] +
                        (1 - alpha) * (df['_keyword_score'] / max_keyword)
                    )

                    # Sort by combined score (descending) and take the top n_results
                    df = df.sort_values('_combined_score', ascending=False).head(n_results)

            # Build the result format (ChromaDB-like, for compatibility)
            results = {
                "ids": [[]],
                "distances": [[]],
                "metadatas": [[]],
                "documents": [[]],
            }

            for _, row in df.iterrows():
                results["ids"][0].append(row["node_id"])
                results["documents"][0].append(row["content"])
                results["distances"][0].append(float(row.get("_distance", 0)))

                results["metadatas"][0].append({
                    "id": row["node_id"],
                    "name": row["name"],
                    "content": row["content"],
                    "description": row["description"],
                    "path": row["path"],
                    "language": row["language"],
                    "node_type": row["node_type"],
                    "order_in_file": str(row["order_in_file"]),
                    "declared_entities": row["declared_entities"],
                    "called_entities": row["called_entities"],
                })

            return results

        except Exception as e:
            self.logger.error(f"Query failed: {e}", exc_info=True)
            raise

    def __del__(self):
        """Clean up resources"""
        pass


# Factory function to create the appropriate CodeIndex
def CodeIndex(
    nodes: list,
    model_service,
    index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
    embedding_batch_size: int = 20,
    use_embed: bool = True,
    backend: Literal['weaviate', 'lancedb'] = 'weaviate',
    db_path: str = None,
    host: str = None,
    port: int = None,
    grpc_port: int = None
) -> BaseCodeIndex:
    """
    Factory function to create a CodeIndex instance.

    Args:
        nodes: List of nodes to index
        model_service: Service for embedding generation
        index_type: Type of search ('embedding-only', 'keyword-only', or 'hybrid')
        embedding_batch_size: Batch size for embedding generation
        use_embed: Whether to use pre-computed embeddings
        backend: Which backend to use ('weaviate' or 'lancedb')
        db_path: Path for LanceDB (only used with the 'lancedb' backend)
        host: Weaviate host (only used with the 'weaviate' backend)
        port: Weaviate port (only used with the 'weaviate' backend)
        grpc_port: Weaviate gRPC port (only used with the 'weaviate' backend)

    Returns:
        BaseCodeIndex: Either a WeaviateCodeIndex or a LanceDBCodeIndex instance
    """
    if backend == 'lancedb':
        return LanceDBCodeIndex(
            nodes=nodes,
            model_service=model_service,
            index_type=index_type,
            embedding_batch_size=embedding_batch_size,
            use_embed=use_embed,
            db_path=db_path
        )
    # Default to Weaviate for 'weaviate' or any unrecognized backend value
    return WeaviateCodeIndex(
        nodes=nodes,
        model_service=model_service,
        index_type=index_type,
        embedding_batch_size=embedding_batch_size,
        use_embed=use_embed,
        host=host,
        port=port,
        grpc_port=grpc_port
    )
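
A minimal usage sketch of the factory above, assuming lancedb is installed; FakeNode and FakeModelService are hypothetical stand-ins that expose only the attributes and methods this module actually touches (id, node_type, embedding, get_field_to_embed, embed_chunk_code_batch, embed_query):

class FakeNode:
    """Hypothetical chunk node shaped like the nodes this module indexes."""
    def __init__(self, node_id, content):
        self.id = node_id
        self.name = node_id
        self.content = content
        self.description = ""
        self.path = "demo.py"
        self.language = "python"
        self.node_type = "chunk"
        self.order_in_file = 0
        self.declared_entities = []
        self.called_entities = []
        self.embedding = None

    def get_field_to_embed(self):
        return self.content

class FakeModelService:
    """Hypothetical embedding service returning toy 2-d vectors."""
    def embed_chunk_code_batch(self, texts):
        return [[float(len(t)), 0.0] for t in texts]

    def embed_query(self, query):
        return [float(len(query)), 0.0]

index = CodeIndex(
    nodes=[FakeNode("n1", "def add(a, b): return a + b")],
    model_service=FakeModelService(),
    index_type="hybrid",
    backend="lancedb",          # avoids needing a running Weaviate instance
    db_path="/tmp/demo_index",  # illustrative path
)
hits = index.query("add two numbers", n_results=3)
print(hits["ids"][0])

The same call with backend='weaviate' would additionally need a Weaviate instance reachable at WEAVIATE_HOST:WEAVIATE_PORT.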
RepoKnowledgeGraphLib/CodeParser.py
ADDED
@@ -0,0 +1,70 @@
import logging
import os
from dotenv import load_dotenv
from langchain_text_splitters import (
    Language,
    RecursiveCharacterTextSplitter,
)

from .utils.logger_utils import setup_logger
load_dotenv()


LOGGER_NAME = 'CODE_PARSER_LOGGER'
CODE_CHUNK_OVERLAP = int(os.getenv('CODE_CHUNK_OVERLAP', 0))
CODE_CHUNK_SIZE = int(os.getenv('CODE_CHUNK_SIZE', 2000))


class CodeParser:
    def __init__(self):
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)

        self.extension_mapping = {
            'c': Language.C,
            'h': Language.C,
            'cpp': Language.CPP,
            'cc': Language.CPP,
            'cxx': Language.CPP,
            'hpp': Language.CPP,
            'hh': Language.CPP,
            'hxx': Language.CPP,
            'go': Language.GO,
            'java': Language.JAVA,
            'py': Language.PYTHON,
            'pyw': Language.PYTHON,
            'js': Language.JS,
            'mjs': Language.JS,
            'cjs': Language.JS,
            'md': Language.MARKDOWN,
            'markdown': Language.MARKDOWN,
            'html': Language.HTML,
        }

    def parse(self, file_name: str, file_content: str) -> list:
        file_extension = file_name.split('.')[-1]
        # Initialize so a parsing failure returns an empty list instead of
        # raising NameError on the final return.
        docs = []

        try:
            self.logger.debug(f'Parsing file: {file_name}')
            if file_extension not in self.extension_mapping:
                self.logger.debug(f'File extension not supported: {file_extension}')
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=CODE_CHUNK_SIZE,
                    chunk_overlap=CODE_CHUNK_OVERLAP,
                    length_function=len,
                    is_separator_regex=False,
                )
                docs = text_splitter.create_documents([file_content])
            else:
                self.logger.debug(f'File extension supported: {file_extension}')
                code_splitter = RecursiveCharacterTextSplitter.from_language(
                    language=self.extension_mapping[file_extension],
                    chunk_size=CODE_CHUNK_SIZE,
                    chunk_overlap=CODE_CHUNK_OVERLAP,
                )
                docs = code_splitter.create_documents([file_content])
        except Exception as e:
            self.logger.error(f'Error when parsing code: {e}')
        return [doc.page_content for doc in docs]
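
A short sketch of how CodeParser.parse splits a file into chunks; the sample source string is illustrative only:

parser = CodeParser()
chunks = parser.parse(
    "example.py",
    "def greet(name):\n    return f'hello {name}'\n\nclass Greeter:\n    pass\n",
)
for i, chunk in enumerate(chunks):
    print(i, len(chunk))  # each chunk is at most CODE_CHUNK_SIZE characters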
RepoKnowledgeGraphLib/Entity.py
ADDED
@@ -0,0 +1,188 @@
from typing import Optional, Dict, List, Type, Any
from dataclasses import dataclass, field, asdict, fields, is_dataclass

# Helper for dynamic class lookup
ENTITY_TYPE_MAP = {}

def register_entity(cls):
    ENTITY_TYPE_MAP[cls.__name__] = cls
    return cls

def _entity_to_dict(obj):
    if isinstance(obj, list):
        return [_entity_to_dict(item) for item in obj]
    elif isinstance(obj, dict):
        return {(_entity_to_dict(k) if isinstance(k, Entity) else k): _entity_to_dict(v) for k, v in obj.items()}
    elif isinstance(obj, Entity):
        return obj.to_dict()
    elif hasattr(obj, 'to_dict'):
        return obj.to_dict()
    else:
        return obj

def _entity_from_dict(data):
    if isinstance(data, list):
        return [_entity_from_dict(item) for item in data]
    elif isinstance(data, dict) and 'entity_type' in data:
        # Prefer the stored class name: multi-word types such as FunctionCall
        # serialize entity_type as 'function_call', which capitalize() alone
        # would not map back to the class name.
        cls = ENTITY_TYPE_MAP.get(data.get('__class__') or data['entity_type'].capitalize(), Entity)
        return cls.from_dict(data)
    else:
        return data

@register_entity
@dataclass
class Entity:
    entity_type: str
    entity_name: str
    defined_chunk_id: str
    entity_dtype: str

    def to_dict(self):
        d = asdict(self)
        d['entity_type'] = self.entity_type
        d['__class__'] = self.__class__.__name__
        return d

    @classmethod
    def from_dict(cls, data):
        # Remove __class__ if present
        data = dict(data)
        data.pop('__class__', None)
        return cls(**data)

@register_entity
@dataclass
class Variable(Entity):
    entity_type = 'variable'

    def to_dict(self):
        d = super().to_dict()
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        return super().from_dict(data)

@register_entity
@dataclass
class Parameter(Entity):
    entity_type = 'parameter'
    entity_dtype: str

    def to_dict(self):
        d = super().to_dict()
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        return super().from_dict(data)

@register_entity
@dataclass
class Method(Entity):
    entity_type = 'method'
    parameters: List['Parameter'] = field(default_factory=list)
    associated_class: Optional['Class'] = None

    def to_dict(self):
        d = super().to_dict()
        d['parameters'] = _entity_to_dict(self.parameters)
        d['associated_class'] = self.associated_class.to_dict() if self.associated_class else None
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        params = [_entity_from_dict(p) for p in data.get('parameters', [])]
        assoc_cls = Class.from_dict(data['associated_class']) if data.get('associated_class') else None
        # '__class__' is serialization-only metadata; drop it before constructing
        base = {k: v for k, v in data.items() if k not in ['parameters', 'parameters_pairs', 'associated_class', '__class__']}
        return cls(parameters=params, associated_class=assoc_cls, **base)

@register_entity
@dataclass
class Class(Entity):
    entity_type = 'class'
    defined_methods: List['Method'] = field(default_factory=list)

    def to_dict(self):
        d = super().to_dict()
        d['defined_methods'] = _entity_to_dict(self.defined_methods)
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        methods = [_entity_from_dict(m) for m in data.get('defined_methods', [])]
        base = {k: v for k, v in data.items() if k not in ['defined_methods', '__class__']}
        return cls(defined_methods=methods, **base)

@register_entity
@dataclass
class Function(Entity):
    entity_type = 'function'
    parameters: List[Parameter] = field(default_factory=list)
    parameters_pairs: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)

    def to_dict(self):
        d = super().to_dict()
        d['parameters'] = _entity_to_dict(self.parameters)
        d['parameters_pairs'] = [(p.to_dict(), v.to_dict()) for p, v in self.parameters_pairs]
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        params = [_entity_from_dict(p) for p in data.get('parameters', [])]
        parameters_pairs = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('parameters_pairs', [])]
        base = {k: v for k, v in data.items() if k not in ['parameters', 'parameters_pairs', '__class__']}
        return cls(parameters=params, parameters_pairs=parameters_pairs, **base)

@register_entity
@dataclass
class FunctionCall(Entity):
    entity_type: str = 'function_call'
    entity_name: str = ''
    defined_chunk_id: str = ''
    entity_dtype: str = ''
    arguments: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)
    # Typed as a list to match the default_factory (was annotated Optional[Function])
    associated_functions: List[Function] = field(default_factory=list)

    def to_dict(self):
        d = super().to_dict()
        d['arguments'] = [(p.to_dict(), v.to_dict()) for p, v in self.arguments]
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        arguments = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('arguments', [])]
        base = {k: v for k, v in data.items() if k not in ['arguments', '__class__']}
        return cls(arguments=arguments, **base)

@register_entity
@dataclass
class MethodCall(Entity):
    entity_type: str = 'method_call'
    entity_name: str = ''
    defined_chunk_id: str = ''
    entity_dtype: str = ''
    arguments: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)
    associated_class: Optional[Class] = None
    associated_method: Optional[Method] = None

    def to_dict(self):
        d = super().to_dict()
        d['arguments'] = [(p.to_dict(), v.to_dict()) for p, v in self.arguments]
        d['associated_class'] = self.associated_class.to_dict() if self.associated_class else None
        # Round-trip associated_method the same way as associated_class
        d['associated_method'] = self.associated_method.to_dict() if self.associated_method else None
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        arguments = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('arguments', [])]
        assoc_cls = Class.from_dict(data['associated_class']) if data.get('associated_class') else None
        assoc_method = Method.from_dict(data['associated_method']) if data.get('associated_method') else None
        base = {k: v for k, v in data.items() if k not in ['arguments', 'associated_class', 'associated_method', '__class__']}
        return cls(arguments=arguments, associated_class=assoc_cls, associated_method=assoc_method, **base)
RepoKnowledgeGraphLib/EntityChunkMapper.py
ADDED
@@ -0,0 +1,517 @@
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Tuple, Dict, Any, Set, Optional
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Language(Enum):
|
| 8 |
+
"""Supported programming languages"""
|
| 9 |
+
PYTHON = "python"
|
| 10 |
+
C = "c"
|
| 11 |
+
CPP = "cpp"
|
| 12 |
+
JAVA = "java"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EntityChunkMapper:
|
| 16 |
+
"""Maps entities from file-level extraction back to their respective chunks"""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.logger = logging.getLogger("ENTITY_CHUNK_MAPPER")
|
| 20 |
+
self.extension_to_language = {
|
| 21 |
+
'py': Language.PYTHON,
|
| 22 |
+
'pyw': Language.PYTHON,
|
| 23 |
+
'c': Language.C,
|
| 24 |
+
'h': Language.C,
|
| 25 |
+
'cpp': Language.CPP,
|
| 26 |
+
'cc': Language.CPP,
|
| 27 |
+
'cxx': Language.CPP,
|
| 28 |
+
'hpp': Language.CPP,
|
| 29 |
+
'hh': Language.CPP,
|
| 30 |
+
'hxx': Language.CPP,
|
| 31 |
+
'java': Language.JAVA,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def _detect_language(self, file_name: Optional[str] = None) -> Language:
|
| 35 |
+
"""
|
| 36 |
+
Detect the programming language from file extension
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
file_name: Name of the file (optional)
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Language enum value, defaults to PYTHON if not detected
|
| 43 |
+
"""
|
| 44 |
+
if file_name:
|
| 45 |
+
extension = file_name.split('.')[-1].lower()
|
| 46 |
+
return self.extension_to_language.get(extension, Language.PYTHON)
|
| 47 |
+
return Language.PYTHON
|
| 48 |
+
|
| 49 |
+
def _is_comment_or_docstring(self, line: str, in_docstring: bool, language: Language) -> Tuple[bool, bool]:
|
| 50 |
+
"""
|
| 51 |
+
Check if a line is a comment or part of a docstring/multi-line comment
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
line: The line to check
|
| 55 |
+
in_docstring: Whether we're currently inside a docstring/multi-line comment
|
| 56 |
+
language: The programming language
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
Tuple of (is_comment_or_docstring, new_in_docstring_state)
|
| 60 |
+
"""
|
| 61 |
+
stripped = line.strip()
|
| 62 |
+
|
| 63 |
+
if language == Language.PYTHON:
|
| 64 |
+
# Check for single-line comments
|
| 65 |
+
if stripped.startswith('#'):
|
| 66 |
+
return True, in_docstring
|
| 67 |
+
|
| 68 |
+
# Check for docstring delimiters (""" or ''')
|
| 69 |
+
triple_double = '"""'
|
| 70 |
+
triple_single = "'''"
|
| 71 |
+
|
| 72 |
+
# Count occurrences of triple quotes
|
| 73 |
+
if triple_double in stripped or triple_single in stripped:
|
| 74 |
+
# Check if it's a single-line docstring
|
| 75 |
+
if (stripped.count(triple_double) >= 2 or
|
| 76 |
+
stripped.count(triple_single) >= 2):
|
| 77 |
+
# Single-line docstring
|
| 78 |
+
return True, in_docstring
|
| 79 |
+
else:
|
| 80 |
+
# Toggle docstring state
|
| 81 |
+
return True, not in_docstring
|
| 82 |
+
|
| 83 |
+
# If we're in a docstring, this line is part of it
|
| 84 |
+
if in_docstring:
|
| 85 |
+
return True, in_docstring
|
| 86 |
+
|
| 87 |
+
elif language in [Language.C, Language.CPP, Language.JAVA]:
|
| 88 |
+
# Check for single-line comments
|
| 89 |
+
if stripped.startswith('//'):
|
| 90 |
+
return True, in_docstring
|
| 91 |
+
|
| 92 |
+
# Check for multi-line comment delimiters /* */
|
| 93 |
+
if '/*' in line and '*/' in line:
|
| 94 |
+
# Single-line multi-line comment
|
| 95 |
+
return True, in_docstring
|
| 96 |
+
elif '/*' in line:
|
| 97 |
+
# Start of multi-line comment
|
| 98 |
+
return True, True
|
| 99 |
+
elif '*/' in line:
|
| 100 |
+
# End of multi-line comment
|
| 101 |
+
return True, False
|
| 102 |
+
|
| 103 |
+
# If we're in a multi-line comment
|
| 104 |
+
if in_docstring:
|
| 105 |
+
return True, in_docstring
|
| 106 |
+
|
| 107 |
+
return False, in_docstring
|
| 108 |
+
|
| 109 |
+
def _get_code_lines(self, chunk_lines: List[str], language: Language) -> List[str]:
|
| 110 |
+
"""
|
| 111 |
+
Filter out comments and docstrings from chunk lines
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
chunk_lines: List of lines in the chunk
|
| 115 |
+
language: The programming language
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
List of lines that are actual code (not comments or docstrings)
|
| 119 |
+
"""
|
| 120 |
+
code_lines = []
|
| 121 |
+
in_docstring = False
|
| 122 |
+
|
| 123 |
+
for line in chunk_lines:
|
| 124 |
+
is_doc, in_docstring = self._is_comment_or_docstring(line, in_docstring, language)
|
| 125 |
+
if not is_doc:
|
| 126 |
+
code_lines.append(line)
|
| 127 |
+
|
| 128 |
+
return code_lines
|
| 129 |
+
|
| 130 |
+
def _is_valid_identifier_match(self, text: str, identifier: str, position: int) -> bool:
|
| 131 |
+
"""
|
| 132 |
+
Check if an identifier match at a position is valid (not part of another word)
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
text: The text containing the identifier
|
| 136 |
+
identifier: The identifier to check
|
| 137 |
+
position: The position where the identifier was found
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
True if this is a valid standalone identifier match
|
| 141 |
+
"""
|
| 142 |
+
# Check character before (if exists)
|
| 143 |
+
if position > 0:
|
| 144 |
+
char_before = text[position - 1]
|
| 145 |
+
if char_before.isalnum() or char_before == '_':
|
| 146 |
+
return False
|
| 147 |
+
|
| 148 |
+
# Check character after (if exists)
|
| 149 |
+
end_pos = position + len(identifier)
|
| 150 |
+
if end_pos < len(text):
|
| 151 |
+
char_after = text[end_pos]
|
| 152 |
+
if char_after.isalnum() or char_after == '_':
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
def _contains_identifier(self, line: str, identifier: str) -> bool:
|
| 158 |
+
"""
|
| 159 |
+
Check if a line contains an identifier as a standalone word (not part of another word)
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
line: The line to check
|
| 163 |
+
identifier: The identifier to find
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
True if the identifier appears as a standalone word
|
| 167 |
+
"""
|
| 168 |
+
# Use word boundary regex for precise matching
|
| 169 |
+
pattern = r'\b' + re.escape(identifier) + r'\b'
|
| 170 |
+
return bool(re.search(pattern, line))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def find_entity_in_chunks(self, entity_name: str, chunks: List[str], entity_type: str = None,
|
| 174 |
+
file_name: Optional[str] = None) -> Set[int]:
|
| 175 |
+
"""
|
| 176 |
+
Find which chunks contain a specific entity declaration or call
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
entity_name: Name of the entity to find
|
| 180 |
+
chunks: List of code chunks
|
| 181 |
+
entity_type: Type of entity (class, function, method, variable)
|
| 182 |
+
file_name: Name of the file to detect language (optional)
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Set of chunk indices that contain this entity
|
| 186 |
+
"""
|
| 187 |
+
matching_chunks = set()
|
| 188 |
+
language = self._detect_language(file_name)
|
| 189 |
+
|
| 190 |
+
# Split the entity name to handle nested entities like "ClassName.method"
|
| 191 |
+
# For Java/C++, also handle :: separator
|
| 192 |
+
if '::' in entity_name:
|
| 193 |
+
parts = entity_name.split('::')
|
| 194 |
+
else:
|
| 195 |
+
parts = entity_name.split('.')
|
| 196 |
+
base_name = parts[-1] # The actual identifier
|
| 197 |
+
|
| 198 |
+
for chunk_idx, chunk in enumerate(chunks):
|
| 199 |
+
chunk_lines = chunk.strip().split('\n')
|
| 200 |
+
|
| 201 |
+
# Look for different patterns based on entity type
|
| 202 |
+
if self._entity_appears_in_chunk(entity_name, base_name, chunk, chunk_lines, entity_type, language):
|
| 203 |
+
matching_chunks.add(chunk_idx)
|
| 204 |
+
|
| 205 |
+
return matching_chunks
|
| 206 |
+
|
| 207 |
+
def _entity_appears_in_chunk(self, full_name: str, base_name: str, chunk: str, chunk_lines: List[str],
|
| 208 |
+
entity_type: str, language: Language) -> bool:
|
| 209 |
+
"""Check if an entity appears in a specific chunk (excluding comments and docstrings)"""
|
| 210 |
+
|
| 211 |
+
# Filter out comments and docstrings
|
| 212 |
+
code_lines = self._get_code_lines(chunk_lines, language)
|
| 213 |
+
|
| 214 |
+
# If no code lines remain, entity doesn't appear in actual code
|
| 215 |
+
if not code_lines:
|
| 216 |
+
return False
|
| 217 |
+
|
| 218 |
+
# Language-specific entity matching
|
| 219 |
+
if language == Language.PYTHON:
|
| 220 |
+
return self._entity_appears_in_python(full_name, base_name, code_lines, entity_type)
|
| 221 |
+
elif language in [Language.C, Language.CPP]:
|
| 222 |
+
return self._entity_appears_in_c_cpp(full_name, base_name, code_lines, entity_type)
|
| 223 |
+
elif language == Language.JAVA:
|
| 224 |
+
return self._entity_appears_in_java(full_name, base_name, code_lines, entity_type)
|
| 225 |
+
|
| 226 |
+
return False
|
| 227 |
+
|
| 228 |
+
    def _entity_appears_in_python(self, full_name: str, base_name: str, code_lines: List[str],
                                  entity_type: str) -> bool:
        """Check if entity appears in Python code"""

        if entity_type == "class":
            # Look for class definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'class\s+{re.escape(base_name)}[\s:(]', stripped):
                    return True

        elif entity_type == "api_endpoint":
            # Look for API endpoint definition - the function decorated with @app.get, @app.post, etc.
            # We look for the function definition itself
            for line in code_lines:
                stripped = line.strip()
                # Match the function definition with the endpoint name
                if re.match(rf'(async\s+)?def\s+{re.escape(base_name)}\s*\(', stripped):
                    return True
                # Also check for decorators that might reference the endpoint
                if re.search(rf'@\w+\.(get|post|put|delete|patch|options|head)\s*\(', stripped):
                    return True

        elif entity_type == "function":
            # Look for function definition (not method)
            for line in code_lines:
                stripped = line.strip()
                # Check it's not indented (not a method)
                if not line.startswith(" ") and not line.startswith("\t"):
                    if re.match(rf'(async\s+)?def\s+{re.escape(base_name)}\s*\(', stripped):
                        return True

        elif entity_type == "method":
            # Look for method definition (indented def)
            method_name = full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                # Check it's indented (is a method)
                if line.startswith(" ") or line.startswith("\t"):
                    if re.match(rf'(async\s+)?def\s+{re.escape(method_name)}\s*\(', stripped):
                        return True

        elif entity_type == "variable":
            # Look for variable assignment or usage
            if "." in full_name:
                parts = full_name.split('.')
                attr_name = parts[-1]
                for line in code_lines:
                    if re.search(rf'\.\s*{re.escape(attr_name)}\b', line):
                        return True
            else:
                for line in code_lines:
                    stripped = line.strip()
                    if re.match(rf'{re.escape(base_name)}\s*[=:]', stripped):
                        return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        # General usage as identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    return True

        return False

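Below is a minimal sketch of how these line-based heuristics behave; the chunk text and names are invented for illustration and are not part of the library:

import re

# Invented sample chunk; the branches above match on stripped lines.
chunk = '''
class Greeter:
    async def greet(self, name):
        return f"hi {name}"
'''.splitlines()

# Same pattern as the "class" branch: the name must be followed by
# whitespace, ':' or '('.
assert any(re.match(r'class\s+Greeter[\s:(]', l.strip()) for l in chunk)

# The "method" branch only accepts indented defs, so top-level functions
# are not mistaken for methods.
assert any(l.startswith(" ") and re.match(r'(async\s+)?def\s+greet\s*\(', l.strip())
           for l in chunk)
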
    def _extract_using_namespace_directives(self, code_lines: List[str]) -> List[str]:
        """
        Extract using namespace directives from C++ code.
        Returns a list of namespace names that are being imported.
        """
        namespaces = []
        for line in code_lines:
            stripped = line.strip()
            # Match "using namespace <name>;"
            match = re.match(r'using\s+namespace\s+([a-zA-Z_][a-zA-Z0-9_:]*)\s*;', stripped)
            if match:
                namespaces.append(match.group(1))
        return namespaces

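A short sketch of the directive scan in isolation, on an invented C++ snippet:

import re

cpp_lines = [
    "#include <cmath>",
    "using namespace math;",
    "using namespace std::chrono;",
    "Calculator c;",
]

namespaces = []
for line in cpp_lines:
    m = re.match(r'using\s+namespace\s+([a-zA-Z_][a-zA-Z0-9_:]*)\s*;', line.strip())
    if m:
        namespaces.append(m.group(1))

print(namespaces)  # ['math', 'std::chrono']
# With "math" imported, _entity_appears_in_c_cpp lets the qualified name
# "math::Calculator" match the bare identifier "Calculator" above.
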
    def _entity_appears_in_c_cpp(self, full_name: str, base_name: str, code_lines: List[str],
                                 entity_type: str) -> bool:
        """Check if entity appears in C/C++ code"""

        # Extract using namespace directives
        using_namespaces = self._extract_using_namespace_directives(code_lines)

        # Check if the full_name matches any imported namespace + base_name
        # e.g., if full_name is "math::Calculator" and we have "using namespace math",
        # then "Calculator" in code should match
        namespace_match = False
        if '::' in full_name:
            for ns in using_namespaces:
                # Check if full_name starts with this namespace
                if full_name.startswith(ns + '::'):
                    namespace_match = True
                    break

        if entity_type == "class":
            # Look for class/struct definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'(class|struct)\s+{re.escape(base_name)}[\s:{{]', stripped):
                    return True

        elif entity_type == "function":
            # Look for function definition or declaration
            for line in code_lines:
                stripped = line.strip()
                # Match function patterns: return_type function_name(
                # Also handle constructors and destructors
                if (re.search(rf'\b{re.escape(base_name)}\s*\(', stripped) and
                        not stripped.startswith('//')):
                    # Additional check: likely a function if followed by parameters
                    return True

        elif entity_type == "method":
            # Look for method definition (with class scope)
            method_name = full_name.split('::')[-1] if '::' in full_name else full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                # Match ClassName::methodName( or just methodName( inside class
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True

        elif entity_type == "variable":
            # Look for variable declaration or usage
            for line in code_lines:
                stripped = line.strip()
                # Match variable declarations and assignments
                if re.search(rf'\b{re.escape(base_name)}\b', stripped):
                    return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            # Look for instantiation or usage
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\b', line):
                    # If we found base_name and there's a namespace match, this is a match
                    if namespace_match:
                        return True
                    # If full_name doesn't have a namespace, it's a direct match
                    if '::' not in full_name:
                        return True

        # General usage as identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    # If we found base_name and there's a namespace match, this is a match
                    if namespace_match:
                        return True
                    # If full_name doesn't have a namespace, it's a direct match
                    if '::' not in full_name:
                        return True

        return False

    def _entity_appears_in_java(self, full_name: str, base_name: str, code_lines: List[str],
                                entity_type: str) -> bool:
        """Check if entity appears in Java code"""

        if entity_type == "class":
            # Look for class/interface/enum definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'(public|private|protected)?\s*(class|interface|enum)\s+{re.escape(base_name)}[\s<{{]', stripped):
                    return True
                # Without modifier
                if re.match(rf'(class|interface|enum)\s+{re.escape(base_name)}[\s<{{]', stripped):
                    return True

        elif entity_type == "api_endpoint":
            # Look for API endpoint definition - the method with Spring annotations
            # Extract just the method name from the full qualified name (e.g., "com.example.Controller::method" -> "method")
            method_name = base_name.split('::')[-1] if '::' in base_name else base_name
            for line in code_lines:
                stripped = line.strip()
                # Match the method definition
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True
                # Also check for Spring annotations
                if re.search(r'@(GetMapping|PostMapping|PutMapping|DeleteMapping|PatchMapping|RequestMapping)', stripped):
                    return True

        elif entity_type == "function":
            # In Java, functions are methods
            for line in code_lines:
                stripped = line.strip()
                # Match method signature patterns
                if re.search(rf'\b{re.escape(base_name)}\s*\(', stripped):
                    return True

        elif entity_type == "method":
            # Look for method definition
            method_name = full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True

        elif entity_type == "variable":
            # Look for variable declaration or usage
            for line in code_lines:
                stripped = line.strip()
                if re.search(rf'\b{re.escape(base_name)}\b', stripped):
                    return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            # Look for instantiation (new ClassName) or usage
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\b', line):
                    return True

        # General usage as identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    return True

        return False

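The Java class heuristic can be exercised directly; a tiny invented check:

import re

java_lines = [
    "public class OrderService {",
    "    void process() { repo.save(order); }",
    "}",
]
pattern = r'(public|private|protected)?\s*(class|interface|enum)\s+OrderService[\s<{]'
assert any(re.match(pattern, l.strip()) for l in java_lines)
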
    def map_entities_to_chunks(self, declared_entities: List[Dict[str, Any]],
                               called_entities: List[str],
                               chunks: List[str],
                               file_name: Optional[str] = None) -> Tuple[Dict[int, List[Dict[str, Any]]],
                                                                         Dict[int, List[str]]]:
        """
        Map file-level entities back to their respective chunks

        Args:
            declared_entities: List of declared entities from file-level extraction
            called_entities: List of called entities from file-level extraction
            chunks: List of code chunks
            file_name: Name of the file to detect language (optional)

        Returns:
            Tuple of (chunk_declared_entities, chunk_called_entities)
            - chunk_declared_entities: Dict mapping chunk_index -> list of declared entities
            - chunk_called_entities: Dict mapping chunk_index -> list of called entities
        """
        chunk_declared = {}
        chunk_called = {}

        # Initialize empty lists for all chunks
        for i in range(len(chunks)):
            chunk_declared[i] = []
            chunk_called[i] = []

        # Map declared entities to chunks
        for entity in declared_entities:
            entity_name = entity.get("name", "")
            entity_type = entity.get("type", "")

            matching_chunks = self.find_entity_in_chunks(entity_name, chunks, entity_type, file_name)

            # Add entity to matching chunks
            for chunk_idx in matching_chunks:
                chunk_declared[chunk_idx].append(entity)

        # Map called entities to chunks
        for called_entity in called_entities:
            matching_chunks = self.find_entity_in_chunks(called_entity, chunks, None, file_name)

            # Add called entity to matching chunks
            for chunk_idx in matching_chunks:
                if called_entity not in chunk_called[chunk_idx]:
                    chunk_called[chunk_idx].append(called_entity)

        return chunk_declared, chunk_called
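A usage sketch of the chunk mapping, assuming an initialized instance of this class in a hypothetical variable `index`; the inputs are invented:

chunks = [
    "def load(path):\n    return open(path).read()",
    "class Loader:\n    def run(self):\n        load('x.txt')",
]
declared = [{"name": "load", "type": "function"},
            {"name": "Loader", "type": "class"}]
called = ["load"]

chunk_declared, chunk_called = index.map_entities_to_chunks(
    declared, called, chunks, file_name="loader.py")

# Both dicts are keyed by chunk index and every index is present, even when
# empty. Note the usage fallback in the appearance checks: a declared
# function can also map to chunks that merely call it.
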
RepoKnowledgeGraphLib/EntityExtractor.py
ADDED
@@ -0,0 +1,2032 @@
import ast
import os
import logging
import tempfile
from typing import List, Dict, Any, Tuple, Optional
from clang import cindex
import javalang
import javalang.tree as T
import esprima
from bs4 import BeautifulSoup
import tree_sitter_rust as ts_rust
from tree_sitter import Language, Parser
import re
from .utils.path_utils import generate_entity_aliases


LOGGER_NAME = "AST_ENTITY_EXTRACTOR"
logger = logging.getLogger(LOGGER_NAME)

class BaseASTEntityExtractor:
    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from source code.

        Args:
            code: Source code as string
            file_path: Optional path to the source file (for better context and include resolution)

        Returns:
            Tuple of (declared_entities, called_entities)
        """
        raise NotImplementedError

    # Add a reset contract so extractors can be reused safely
    def reset(self) -> None:
        """
        Reset internal state so the extractor instance can be reused.
        Concrete extractors should override this to clear their buffers.
        """
        raise NotImplementedError

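As an illustration of this contract, a minimal hypothetical subclass (not part of the library) that satisfies both methods with a regex instead of a real AST:

class RegexEntityExtractor(BaseASTEntityExtractor):
    """Toy extractor: records 'def <name>' occurrences as functions."""

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        # Clear buffers so the instance can be reused across files.
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        self.reset()
        for match in re.finditer(r'def\s+(\w+)', code):
            self.declared_entities.append({"name": match.group(1), "type": "function"})
        return self.declared_entities, self.called_entities
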
class HTMLEntityExtractor(BaseASTEntityExtractor):
    """
    Hybrid HTML AST-based entity extractor.

    Responsibilities:
    • Parse HTML into a tree
    • Extract declared DOM entities (ids, names, classes)
    • Extract JavaScript calls from inline event handlers
    • Extract JS entities from <script> tags
    • Integrate cleanly with the hybrid AST graph linker
    """

    EVENT_ATTR_PREFIX = "on"  # e.g., onclick, onsubmit, etc.

    def __init__(self):
        self.js_extractor = JavaScriptEntityExtractor()
        self.reset()

    # --------------------------------------
    # Core interface
    # --------------------------------------
    def reset(self):
        self.declared_entities: List[Dict[str, str]] = []
        self.called_entities: List[str] = []

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, str]], List[str]]:
        """Main entry point: parse HTML and extract entities."""
        self.reset()
        try:
            soup = BeautifulSoup(code, "html.parser")
        except Exception as e:
            print(f"[HTMLEntityExtractor] Parsing error: {e}")
            return [], []

        # --- DOM element declarations ---
        for tag in soup.find_all(True):
            self._handle_tag_declaration(tag)
            self._handle_event_attributes(tag)

        # --- <script> tags (inline + external) ---
        for script in soup.find_all("script"):
            self._handle_script(script)

        # --- Deduplication ---
        self.declared_entities = self._deduplicate_dicts(self.declared_entities)
        self.called_entities = self._deduplicate_list(self.called_entities)

        return self.declared_entities, self.called_entities

    # --------------------------------------
    # Tag & attribute handlers
    # --------------------------------------
    def _handle_tag_declaration(self, tag):
        """Extract declared DOM elements (id, name, class)."""
        if tag.has_attr("id"):
            self.declared_entities.append({"name": tag["id"], "type": "element"})

        if tag.has_attr("name"):
            self.declared_entities.append({"name": tag["name"], "type": "element"})

        if tag.has_attr("class"):
            classes = tag["class"]
            if isinstance(classes, list):
                for c in classes:
                    self.declared_entities.append({"name": c, "type": "class"})
            elif isinstance(classes, str):
                self.declared_entities.append({"name": classes, "type": "class"})

    def _handle_event_attributes(self, tag):
        """Extract JS calls from inline event attributes."""
        if not self.js_extractor:
            return
        for attr, value in tag.attrs.items():
            if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str):
                try:
                    _, called = self.js_extractor.extract_entities(value)
                    self.called_entities.extend(called)
                except Exception as e:
                    print(f"[HTMLEntityExtractor] JS parse error in {attr}: {e}")

    def _handle_script(self, script):
        """Extract JS entities from <script> blocks or src attributes."""
        if script.has_attr("src"):
            src = script["src"]
            self.called_entities.append(src)
            return

        if not self.js_extractor:
            return

        js_code = (script.string or "").strip()
        if js_code:
            try:
                declared, called = self.js_extractor.extract_entities(js_code)
                self.declared_entities.extend(declared)
                self.called_entities.extend(called)
            except Exception as e:
                print(f"[HTMLEntityExtractor] JS parse error in <script>: {e}")

    # --------------------------------------
    # Helpers
    # --------------------------------------
    @staticmethod
    def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]:
        seen = set()
        result = []
        for d in dicts:
            key = tuple(sorted(d.items()))
            if key not in seen:
                seen.add(key)
                result.append(d)
        return result

    @staticmethod
    def _deduplicate_list(items: List[str]) -> List[str]:
        seen = set()
        result = []
        for i in items:
            if i not in seen:
                seen.add(i)
                result.append(i)
        return result

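A quick usage sketch on an invented snippet; the inline onclick handler is routed through the JS extractor:

html = '''
<div id="panel" class="card wide">
  <button onclick="openPanel()">Open</button>
  <script src="app.js"></script>
</div>
'''
extractor = HTMLEntityExtractor()
declared, called = extractor.extract_entities(html)
# declared includes {'name': 'panel', 'type': 'element'} plus the 'card'
# and 'wide' class entries; called includes 'openPanel' and 'app.js'.
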
class JavaEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Java code using javalang.
    Produces the same (declared_entities, called_entities) structure as other extractors.
    """

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_package: Optional[str] = None
        self.scope_stack: List[str] = []
        self.api_endpoints: List[Dict[str, Any]] = []  # Track API endpoint definitions
        self.current_class_base_path: Optional[str] = None  # For @RequestMapping on class

    # -----------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------

    def _qualified(self, name: str) -> str:
        if not name:
            return ""
        scope = "::".join(self.scope_stack)
        return f"{scope}::{name}" if scope else name

    def _walk_type(self, t):
        """Return string representation of a type node."""
        if not t:
            return "unknown"
        if isinstance(t, str):
            return t
        if hasattr(t, "name"):
            name = t.name
            if getattr(t, "arguments", None):
                args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")]
                name += "<" + ", ".join(args) + ">"
            return name
        return "unknown"

    # -----------------------------------------------------------
    # Main AST traversal
    # -----------------------------------------------------------

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        self.reset()

        try:
            tree = javalang.parse.parse(code)
        except javalang.parser.JavaSyntaxError as e:
            logger.error(f"Syntax error in Java code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Java code: {e}", exc_info=True)
            return [], []

        # --- Package ---
        if tree.package:
            self.current_package = tree.package.name

        # --- Imports ---
        for imp in tree.imports:
            self.called_entities.append(imp.path)

        # --- Types (classes, interfaces, enums) ---
        for type_decl in tree.types:
            self._visit_type(type_decl)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype"))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called

    # -----------------------------------------------------------
    # Visitors for different node types
    # -----------------------------------------------------------

    def _visit_type(self, node):
        if isinstance(node, javalang.tree.ClassDeclaration):
            self._visit_class(node)
        elif isinstance(node, javalang.tree.InterfaceDeclaration):
            self._visit_interface(node)
        elif isinstance(node, javalang.tree.EnumDeclaration):
            self._visit_enum(node)

    def _visit_class(self, node):
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)

        self.declared_entities.append({"name": qualified, "type": "class"})

        # Check for REST controller annotations and extract base path
        old_base_path = self.current_class_base_path
        if node.annotations:
            for annotation in node.annotations:
                if annotation.name in {'RestController', 'Controller'}:
                    # Mark as REST controller
                    pass
                elif annotation.name == 'RequestMapping':
                    # Extract base path from class-level @RequestMapping
                    self.current_class_base_path = self._extract_path_from_annotation(annotation)

        # Inheritance
        if node.extends:
            self.called_entities.append(self._walk_type(node.extends))
        for impl in node.implements or []:
            self.called_entities.append(self._walk_type(impl))

        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

        # Restore the previous base path
        self.current_class_base_path = old_base_path

    def _visit_interface(self, node):
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)
        self.declared_entities.append({"name": qualified, "type": "interface"})

        for impl in node.extends or []:
            self.called_entities.append(self._walk_type(impl))

        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

    def _visit_enum(self, node):
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)
        self.declared_entities.append({"name": qualified, "type": "enum"})

    def _visit_member(self, node):

        # --- Method ---
        if isinstance(node, T.MethodDeclaration):
            method_name = self._qualified(node.name)

            # Check for API endpoint annotations
            api_info = self._extract_api_endpoint_from_annotations(node)
            if api_info:
                self.declared_entities.append({
                    "name": method_name,
                    "type": "api_endpoint",
                    "endpoint": api_info.get("endpoint"),
                    "methods": api_info.get("methods")
                })
                self.api_endpoints.append({**api_info, "function": method_name})
            else:
                self.declared_entities.append({"name": method_name, "type": "method"})

            for param in node.parameters:
                ptype = self._walk_type(param.type)
                pname = f"{method_name}.{param.name}"
                self.declared_entities.append({
                    "name": pname,
                    "type": "variable",
                    "dtype": ptype
                })

            # Look for method calls in the body
            if node.body:
                self._find_calls(node.body)

        # --- Constructor ---
        elif isinstance(node, T.ConstructorDeclaration):
            ctor_name = self._qualified(node.name)
            self.declared_entities.append({"name": ctor_name, "type": "constructor"})
            for param in node.parameters:
                ptype = self._walk_type(param.type)
                pname = f"{ctor_name}.{param.name}"
                self.declared_entities.append({
                    "name": pname,
                    "type": "variable",
                    "dtype": ptype
                })
            if node.body:
                self._find_calls(node.body)

        # --- Field ---
        elif isinstance(node, T.FieldDeclaration):
            dtype = self._walk_type(node.type)
            for decl in node.declarators:
                var_name = self._qualified(decl.name)
                self.declared_entities.append({
                    "name": var_name,
                    "type": "variable",
                    "dtype": dtype
                })

        # --- Nested class/interface ---
        elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration)):
            self._visit_type(node)

    # -----------------------------------------------------------
    # API Endpoint Detection
    # -----------------------------------------------------------

    def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Spring Boot method annotations.
        Handles: @GetMapping, @PostMapping, @RequestMapping, etc.
        """
        if not method.annotations:
            return None

        for annotation in method.annotations:
            annotation_name = annotation.name

            if annotation_name in {'GetMapping', 'PostMapping', 'PutMapping', 'PatchMapping', 'DeleteMapping'}:
                # Extract HTTP method from annotation name
                http_method = annotation_name.replace('Mapping', '').upper()
                path = self._extract_path_from_annotation(annotation)

                if path:
                    # Combine with class-level base path if present
                    full_path = self._combine_paths(self.current_class_base_path, path)
                    return {
                        "endpoint": full_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

            elif annotation_name == 'RequestMapping':
                # @RequestMapping can specify multiple methods
                path = self._extract_path_from_annotation(annotation)
                methods = self._extract_methods_from_annotation(annotation)

                if path:
                    full_path = self._combine_paths(self.current_class_base_path, path)
                    return {
                        "endpoint": full_path,
                        "methods": methods if methods else ['GET'],  # Default to GET
                        "type": "api_endpoint_definition"
                    }

        return None

    def _extract_path_from_annotation(self, annotation) -> Optional[str]:
        """Extract path/value from Spring annotation."""
        if not annotation.element:
            return None

        # Handle @GetMapping("/path") - single value
        if isinstance(annotation.element, T.Literal):
            return annotation.element.value.strip('"')

        # Handle @RequestMapping(value = "/path") or @RequestMapping(path = "/path")
        if isinstance(annotation.element, list):
            for elem in annotation.element:
                if isinstance(elem, T.ElementValuePair):
                    if elem.name in {'value', 'path'}:
                        if isinstance(elem.value, T.Literal):
                            return elem.value.value.strip('"')
                        elif isinstance(elem.value, T.ElementArrayValue):
                            # Handle array: value = {"/path1", "/path2"}
                            if elem.value.values:
                                first_val = elem.value.values[0]
                                if isinstance(first_val, T.Literal):
                                    return first_val.value.strip('"')

        return None

    def _extract_methods_from_annotation(self, annotation) -> List[str]:
        """Extract HTTP methods from @RequestMapping annotation."""
        methods = []

        if isinstance(annotation.element, list):
            for elem in annotation.element:
                if isinstance(elem, T.ElementValuePair):
                    if elem.name == 'method':
                        # Handle method = RequestMethod.GET or method = {RequestMethod.GET, RequestMethod.POST}
                        if hasattr(elem.value, 'member'):
                            # Single method: RequestMethod.GET
                            methods.append(elem.value.member)
                        elif isinstance(elem.value, T.ElementArrayValue):
                            # Multiple methods: {RequestMethod.GET, RequestMethod.POST}
                            for val in elem.value.values:
                                if hasattr(val, 'member'):
                                    methods.append(val.member)

        return methods

    def _combine_paths(self, base_path: Optional[str], path: str) -> str:
        """Combine base path from class annotation with method path."""
        if not base_path:
            return path

        # Normalize paths
        base = base_path.rstrip('/')
        path = path.lstrip('/')

        return f"{base}/{path}" if path else base

    # -----------------------------------------------------------
    # Find method invocations
    # -----------------------------------------------------------

    def _find_calls(self, statements):
        """Recursively find method and constructor calls inside Java AST nodes."""

        def _recurse(node):
            if isinstance(node, T.MethodInvocation):
                if node.qualifier:
                    self.called_entities.append(f"{node.qualifier}.{node.member}")
                else:
                    self.called_entities.append(node.member)
            elif isinstance(node, T.ClassCreator):
                self.called_entities.append(self._walk_type(node.type))

            # Recurse into all children
            if hasattr(node, '__dict__'):
                for attr, val in vars(node).items():
                    if isinstance(val, list):
                        for child in val:
                            if isinstance(child, T.Node):
                                _recurse(child)
                    elif isinstance(val, T.Node):
                        _recurse(val)

        if not statements:
            return

        if isinstance(statements, list):
            for stmt in statements:
                _recurse(stmt)
        else:
            _recurse(statements)

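A usage sketch for the Java extractor, assuming javalang accepts this invented snippet:

java_src = '''
package demo;

import java.util.List;

@RestController
@RequestMapping("/api")
public class UserController {
    @GetMapping("/users")
    public List<String> listUsers() { return helper.fetchAll(); }
}
'''
jx = JavaEntityExtractor()
declared, called = jx.extract_entities(java_src)
# declared contains the class 'demo.UserController' and an 'api_endpoint'
# entry for listUsers with endpoint '/api/users' and methods ['GET'];
# called picks up 'java.util.List' and 'helper.fetchAll'.
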
class JavaScriptEntityExtractor(BaseASTEntityExtractor):
|
| 516 |
+
"""
|
| 517 |
+
Extract declared and called entities from JavaScript code using esprima.
|
| 518 |
+
Handles ES6+ syntax including classes, arrow functions, imports/exports.
|
| 519 |
+
Also detects API endpoint calls (fetch, axios, etc.).
|
| 520 |
+
"""
|
| 521 |
+
|
| 522 |
+
# Common HTTP methods to detect
|
| 523 |
+
HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}
|
| 524 |
+
|
| 525 |
+
# API call patterns to detect
|
| 526 |
+
API_PATTERNS = {
|
| 527 |
+
'fetch', # fetch('/api/users')
|
| 528 |
+
'axios', # axios.get('/api/users')
|
| 529 |
+
'$http', # Angular $http
|
| 530 |
+
'request', # request library
|
| 531 |
+
'superagent', # superagent library
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
def __init__(self):
|
| 535 |
+
self.reset()
|
| 536 |
+
|
| 537 |
+
def reset(self) -> None:
|
| 538 |
+
self.declared_entities: List[Dict[str, Any]] = []
|
| 539 |
+
self.called_entities: List[str] = []
|
| 540 |
+
self.scope_stack: List[str] = []
|
| 541 |
+
self.api_calls: List[Dict[str, Any]] = [] # Track API endpoint calls
|
| 542 |
+
|
| 543 |
+
def _qualified(self, name: str) -> str:
|
| 544 |
+
"""Return fully qualified name using current scope stack."""
|
| 545 |
+
if not name:
|
| 546 |
+
return ""
|
| 547 |
+
scope = ".".join(self.scope_stack)
|
| 548 |
+
return f"{scope}.{name}" if scope else name
|
| 549 |
+
|
| 550 |
+
def _get_function_name(self, node) -> Optional[str]:
|
| 551 |
+
"""Extract function name from various function node types."""
|
| 552 |
+
if hasattr(node, 'id') and node.id:
|
| 553 |
+
return node.id.name
|
| 554 |
+
return None
|
| 555 |
+
|
| 556 |
+
def _walk_node(self, node):
|
| 557 |
+
"""Recursively walk the AST and extract entities."""
|
| 558 |
+
if not node or not hasattr(node, 'type'):
|
| 559 |
+
return
|
| 560 |
+
|
| 561 |
+
node_type = node.type
|
| 562 |
+
|
| 563 |
+
# --- Function Declaration ---
|
| 564 |
+
if node_type == 'FunctionDeclaration':
|
| 565 |
+
func_name = self._get_function_name(node)
|
| 566 |
+
if func_name:
|
| 567 |
+
qualified = self._qualified(func_name)
|
| 568 |
+
self.declared_entities.append({"name": qualified, "type": "function"})
|
| 569 |
+
|
| 570 |
+
# Extract parameters
|
| 571 |
+
if hasattr(node, 'params'):
|
| 572 |
+
for param in node.params:
|
| 573 |
+
param_name = self._extract_pattern_name(param)
|
| 574 |
+
if param_name:
|
| 575 |
+
self.declared_entities.append({
|
| 576 |
+
"name": f"{qualified}.{param_name}",
|
| 577 |
+
"type": "variable",
|
| 578 |
+
"dtype": "unknown"
|
| 579 |
+
})
|
| 580 |
+
|
| 581 |
+
self.scope_stack.append(func_name)
|
| 582 |
+
if hasattr(node, 'body'):
|
| 583 |
+
self._walk_node(node.body)
|
| 584 |
+
self.scope_stack.pop()
|
| 585 |
+
|
| 586 |
+
# --- Arrow Function Expression ---
|
| 587 |
+
elif node_type == 'ArrowFunctionExpression':
|
| 588 |
+
# Arrow functions are typically assigned, handle in VariableDeclarator
|
| 589 |
+
if hasattr(node, 'params'):
|
| 590 |
+
for param in node.params:
|
| 591 |
+
param_name = self._extract_pattern_name(param)
|
| 592 |
+
# Note: can't fully qualify without parent context
|
| 593 |
+
if hasattr(node, 'body'):
|
| 594 |
+
self._walk_node(node.body)
|
| 595 |
+
|
| 596 |
+
# --- Function Expression ---
|
| 597 |
+
elif node_type == 'FunctionExpression':
|
| 598 |
+
func_name = self._get_function_name(node)
|
| 599 |
+
if func_name:
|
| 600 |
+
qualified = self._qualified(func_name)
|
| 601 |
+
self.declared_entities.append({"name": qualified, "type": "function"})
|
| 602 |
+
self.scope_stack.append(func_name)
|
| 603 |
+
|
| 604 |
+
if hasattr(node, 'params'):
|
| 605 |
+
for param in node.params:
|
| 606 |
+
param_name = self._extract_pattern_name(param)
|
| 607 |
+
if param_name and func_name:
|
| 608 |
+
self.declared_entities.append({
|
| 609 |
+
"name": f"{self._qualified(func_name)}.{param_name}",
|
| 610 |
+
"type": "variable",
|
| 611 |
+
"dtype": "unknown"
|
| 612 |
+
})
|
| 613 |
+
|
| 614 |
+
if hasattr(node, 'body'):
|
| 615 |
+
self._walk_node(node.body)
|
| 616 |
+
|
| 617 |
+
if func_name:
|
| 618 |
+
self.scope_stack.pop()
|
| 619 |
+
|
| 620 |
+
# --- Class Declaration ---
|
| 621 |
+
elif node_type == 'ClassDeclaration':
|
| 622 |
+
class_name = node.id.name if hasattr(node, 'id') and node.id else None
|
| 623 |
+
if class_name:
|
| 624 |
+
qualified = self._qualified(class_name)
|
| 625 |
+
self.declared_entities.append({"name": qualified, "type": "class"})
|
| 626 |
+
|
| 627 |
+
# Handle inheritance
|
| 628 |
+
if hasattr(node, 'superClass') and node.superClass:
|
| 629 |
+
if hasattr(node.superClass, 'name'):
|
| 630 |
+
self.called_entities.append(node.superClass.name)
|
| 631 |
+
|
| 632 |
+
self.scope_stack.append(class_name)
|
| 633 |
+
if hasattr(node, 'body') and hasattr(node.body, 'body'):
|
| 634 |
+
for method in node.body.body:
|
| 635 |
+
self._walk_node(method)
|
| 636 |
+
self.scope_stack.pop()
|
| 637 |
+
|
| 638 |
+
# --- Method Definition ---
|
| 639 |
+
elif node_type == 'MethodDefinition':
|
| 640 |
+
method_name = node.key.name if hasattr(node, 'key') and hasattr(node.key, 'name') else None
|
| 641 |
+
if method_name:
|
| 642 |
+
qualified = self._qualified(method_name)
|
| 643 |
+
self.declared_entities.append({"name": qualified, "type": "method"})
|
| 644 |
+
|
| 645 |
+
if hasattr(node, 'value') and hasattr(node.value, 'params'):
|
| 646 |
+
for param in node.value.params:
|
| 647 |
+
param_name = self._extract_pattern_name(param)
|
| 648 |
+
if param_name:
|
| 649 |
+
self.declared_entities.append({
|
| 650 |
+
"name": f"{qualified}.{param_name}",
|
| 651 |
+
"type": "variable",
|
| 652 |
+
"dtype": "unknown"
|
| 653 |
+
})
|
| 654 |
+
|
| 655 |
+
if hasattr(node, 'value'):
|
| 656 |
+
self._walk_node(node.value)
|
| 657 |
+
|
| 658 |
+
# --- Variable Declaration ---
|
| 659 |
+
elif node_type == 'VariableDeclaration':
|
| 660 |
+
if hasattr(node, 'declarations'):
|
| 661 |
+
for decl in node.declarations:
|
| 662 |
+
self._walk_node(decl)
|
| 663 |
+
|
| 664 |
+
# --- Variable Declarator ---
|
| 665 |
+
elif node_type == 'VariableDeclarator':
|
| 666 |
+
var_name = self._extract_pattern_name(node.id) if hasattr(node, 'id') else None
|
| 667 |
+
if var_name:
|
| 668 |
+
qualified = self._qualified(var_name)
|
| 669 |
+
|
| 670 |
+
# Check if it's a function assignment
|
| 671 |
+
if hasattr(node, 'init') and node.init:
|
| 672 |
+
if node.init.type in ('FunctionExpression', 'ArrowFunctionExpression'):
|
| 673 |
+
self.declared_entities.append({"name": qualified, "type": "function"})
|
| 674 |
+
self.scope_stack.append(var_name)
|
| 675 |
+
self._walk_node(node.init)
|
| 676 |
+
self.scope_stack.pop()
|
| 677 |
+
else:
|
| 678 |
+
self.declared_entities.append({
|
| 679 |
+
"name": qualified,
|
| 680 |
+
"type": "variable",
|
| 681 |
+
"dtype": "unknown"
|
| 682 |
+
})
|
| 683 |
+
self._walk_node(node.init)
|
| 684 |
+
else:
|
| 685 |
+
self.declared_entities.append({
|
| 686 |
+
"name": qualified,
|
| 687 |
+
"type": "variable",
|
| 688 |
+
"dtype": "unknown"
|
| 689 |
+
})
|
| 690 |
+
|
| 691 |
+
# --- Call Expression ---
|
| 692 |
+
elif node_type == 'CallExpression':
|
| 693 |
+
callee_name = self._extract_callee_name(node.callee) if hasattr(node, 'callee') else None
|
| 694 |
+
if callee_name:
|
| 695 |
+
self.called_entities.append(callee_name)
|
| 696 |
+
|
| 697 |
+
# Detect API endpoint calls
|
| 698 |
+
self._detect_api_call(node, callee_name)
|
| 699 |
+
|
| 700 |
+
# Walk arguments
|
| 701 |
+
if hasattr(node, 'arguments'):
|
| 702 |
+
for arg in node.arguments:
|
| 703 |
+
self._walk_node(arg)
|
| 704 |
+
|
| 705 |
+
# --- Member Expression ---
|
| 706 |
+
elif node_type == 'MemberExpression':
|
| 707 |
+
# Don't record as call, just traverse
|
| 708 |
+
if hasattr(node, 'object'):
|
| 709 |
+
self._walk_node(node.object)
|
| 710 |
+
if hasattr(node, 'property'):
|
| 711 |
+
self._walk_node(node.property)
|
| 712 |
+
|
| 713 |
+
# --- Import/Export ---
|
| 714 |
+
elif node_type == 'ImportDeclaration':
|
| 715 |
+
if hasattr(node, 'source') and hasattr(node.source, 'value'):
|
| 716 |
+
self.called_entities.append(node.source.value)
|
| 717 |
+
|
| 718 |
+
elif node_type == 'ExportNamedDeclaration':
|
| 719 |
+
if hasattr(node, 'declaration'):
|
| 720 |
+
self._walk_node(node.declaration)
|
| 721 |
+
|
| 722 |
+
elif node_type == 'ExportDefaultDeclaration':
|
| 723 |
+
if hasattr(node, 'declaration'):
|
| 724 |
+
self._walk_node(node.declaration)
|
| 725 |
+
|
| 726 |
+
# --- Recursive traversal for other nodes ---
|
| 727 |
+
else:
|
| 728 |
+
if hasattr(node, '__dict__'):
|
| 729 |
+
for attr, val in vars(node).items():
|
| 730 |
+
if isinstance(val, list):
|
| 731 |
+
for item in val:
|
| 732 |
+
if hasattr(item, 'type'):
|
| 733 |
+
self._walk_node(item)
|
| 734 |
+
elif hasattr(val, 'type'):
|
| 735 |
+
self._walk_node(val)
|
| 736 |
+
|
| 737 |
+
def _extract_pattern_name(self, pattern) -> Optional[str]:
|
| 738 |
+
"""Extract name from various pattern types (Identifier, ObjectPattern, etc.)."""
|
| 739 |
+
if not pattern:
|
| 740 |
+
return None
|
| 741 |
+
if hasattr(pattern, 'type'):
|
| 742 |
+
if pattern.type == 'Identifier':
|
| 743 |
+
return pattern.name if hasattr(pattern, 'name') else None
|
| 744 |
+
elif pattern.type == 'RestElement':
|
| 745 |
+
return self._extract_pattern_name(pattern.argument) if hasattr(pattern, 'argument') else None
|
| 746 |
+
return None
|
| 747 |
+
|
| 748 |
+
def _extract_callee_name(self, callee) -> Optional[str]:
|
| 749 |
+
"""Extract the name of the function being called."""
|
| 750 |
+
if not callee:
|
| 751 |
+
return None
|
| 752 |
+
|
| 753 |
+
if hasattr(callee, 'type'):
|
| 754 |
+
if callee.type == 'Identifier':
|
| 755 |
+
return callee.name if hasattr(callee, 'name') else None
|
| 756 |
+
elif callee.type == 'MemberExpression':
|
| 757 |
+
obj = self._extract_callee_name(callee.object) if hasattr(callee, 'object') else ""
|
| 758 |
+
prop = callee.property.name if hasattr(callee, 'property') and hasattr(callee.property, 'name') else ""
|
| 759 |
+
if obj and prop:
|
| 760 |
+
return f"{obj}.{prop}"
|
| 761 |
+
return prop or obj
|
| 762 |
+
return None
|
| 763 |
+
|
| 764 |
+
def _detect_api_call(self, call_node, callee_name: str):
|
| 765 |
+
"""
|
| 766 |
+
Detect API endpoint calls in JavaScript code.
|
| 767 |
+
Handles patterns like:
|
| 768 |
+
- fetch('/api/users')
|
| 769 |
+
- axios.get('/api/users')
|
| 770 |
+
- axios.post('/api/users', data)
|
| 771 |
+
- request.get('/api/users')
|
        """
        if not callee_name or not hasattr(call_node, 'arguments'):
            return

        # Split the callee name to check for patterns
        parts = callee_name.split('.')
        base = parts[0]
        method = parts[-1].lower() if len(parts) > 1 else None

        # Check whether this is an API call
        is_api_call = False
        http_method = 'unknown'

        # Pattern 1: fetch('/api/...')
        if base == 'fetch':
            is_api_call = True
            http_method = 'GET'  # Default for fetch

        # Pattern 2: axios.get('/api/...'), request.post(...), etc.
        elif base in self.API_PATTERNS and method in self.HTTP_METHODS:
            is_api_call = True
            http_method = method.upper()

        # Pattern 3: axios('/api/...', {method: 'POST'})
        elif base in self.API_PATTERNS and method is None:
            is_api_call = True
            http_method = 'GET'  # Default

        if not is_api_call:
            return

        # Extract the endpoint URL from the arguments
        if call_node.arguments:
            first_arg = call_node.arguments[0]
            endpoint = self._extract_string_literal(first_arg)

            if endpoint:
                # Store as a called entity with a special type
                self.called_entities.append(f"API:{http_method}:{endpoint}")

                # Also track in api_calls for easier filtering
                self.api_calls.append({
                    "endpoint": endpoint,
                    "method": http_method,
                    "type": "api_call"
                })

    def _extract_string_literal(self, node) -> Optional[str]:
        """Extract a string value from a Literal/TemplateLiteral node."""
        if not node or not hasattr(node, 'type'):
            return None

        if node.type == 'Literal' and isinstance(node.value, str):
            return node.value
        elif node.type == 'TemplateLiteral':
            # For template literals, extract the quasi (literal) parts,
            # e.g. `/api/${version}/users` -> /api/{param}/users
            if hasattr(node, 'quasis'):
                parts = []
                for i, quasi in enumerate(node.quasis):
                    if hasattr(quasi, 'value') and hasattr(quasi.value, 'raw'):
                        parts.append(quasi.value.raw)
                    if i < len(node.quasis) - 1:
                        parts.append('{param}')
                return ''.join(parts)

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        self.reset()

        try:
            tree = esprima.parseScript(code, {'tolerant': True, 'loc': False})
        except Exception:
            # Try parsing as a module if script parsing fails
            try:
                tree = esprima.parseModule(code, {'tolerant': True, 'loc': False})
            except Exception as e2:
                logger.error(f"Failed to parse JavaScript code: {e2}")
                return [], []

        if hasattr(tree, 'body'):
            for node in tree.body:
                self._walk_node(node)

        # Deduplicate while preserving order
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype"))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called
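
# Usage sketch for the JavaScript extractor above (hypothetical snippet; it
# assumes the esprima package imported by this module is installed):
#
#     js_extractor = JavaScriptEntityExtractor()
#     declared, called = js_extractor.extract_entities(
#         "function loadUsers() { return fetch('/api/users'); }"
#     )
#     # With the walker defined earlier in the class, `called` is expected to
#     # contain "API:GET:/api/users", recorded by _extract_api_call above.
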

class CEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from C code using clang.cindex (libclang),
    with filtering to ignore system headers.
    """

    def __init__(self):
        self.index = cindex.Index.create()

    def reset(self) -> None:
        """No persistent state to reset, but the method is provided for interface consistency."""
        pass

    def _walk_cursor(self, cursor, declared, called, source_file):
        """Recursively walk a clang Cursor, restricted to the main file."""
        for c in cursor.get_children():
            # --- Include directives ---
            # Note: INCLUSION_DIRECTIVE nodes are at the root level and need special handling
            if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
                # Get the included file name
                included_file = c.displayname
                if included_file:
                    called.append(included_file)
                continue

            loc = c.location
            if not loc.file or not source_file:
                continue

            # Skip system / external headers for other nodes
            if os.path.abspath(loc.file.name) != os.path.abspath(source_file):
                continue

            # --- Declarations ---
            if c.kind.is_declaration():
                if c.kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE):
                    name = c.spelling or c.displayname
                    declared.append({"name": name, "type": "function"})
                    for p in c.get_arguments():
                        declared.append({
                            "name": f"{name}.{p.spelling}",
                            "type": "variable",
                            "dtype": p.type.spelling
                        })
                elif c.kind == cindex.CursorKind.VAR_DECL:
                    declared.append({
                        "name": c.spelling,
                        "type": "variable",
                        "dtype": c.type.spelling
                    })

                    # Add the variable's type to the called entities;
                    # this captures struct references like "struct Point p;"
                    if c.type.spelling:
                        # Extract the base type name (remove const, &, *, the struct keyword, etc.)
                        type_name = c.type.spelling.strip()
                        # Remove common qualifiers and keywords
                        type_name = type_name.replace('const', '').replace('&', '').replace('*', '').replace('struct', '').strip()
                        if type_name and type_name not in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed', 'size_t']:
                            called.append(type_name)
                elif c.kind == cindex.CursorKind.STRUCT_DECL:
                    declared.append({"name": c.spelling or c.displayname, "type": "struct"})
                elif c.kind == cindex.CursorKind.TYPEDEF_DECL:
                    declared.append({"name": c.spelling, "type": "typedef"})

            # --- Calls ---
            if c.kind == cindex.CursorKind.CALL_EXPR:
                callee = None
                for child in c.get_children():
                    if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
                        callee = child.spelling
                        break
                if callee:
                    called.append(callee)
                else:
                    called.append(c.displayname or c.spelling)

            # --- Recurse ---
            self._walk_cursor(c, declared, called, source_file)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        declared, called = [], []

        # If file_path is provided, use it directly for better include resolution;
        # otherwise, create a temporary file
        tf_name = None
        temp_file = False

        if file_path and os.path.exists(file_path):
            tf_name = file_path
            temp_file = False
        else:
            with tempfile.NamedTemporaryFile(suffix=".c", mode="w+", delete=False) as tf:
                tf_name = tf.name
                tf.write(code)
                tf.flush()
                temp_file = True

        # Use the directory containing the file for include paths
        include_dir = os.path.dirname(tf_name) if tf_name else None
        args = ['-std=c11']
        if include_dir:
            args.append(f'-I{include_dir}')

        try:
            tu = self.index.parse(
                tf_name,
                args=args,
                options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
            )
        except Exception as e:
            raise RuntimeError(f"libclang failed to parse translation unit: {e}")

        self._walk_cursor(tu.cursor, declared, called, tf_name)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in declared:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(called))

        # Only delete the file if we created a temporary one
        if temp_file:
            try:
                os.unlink(tf_name)
            except Exception:
                pass

        return unique_declared, unique_called
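
# Usage sketch for CEntityExtractor (hypothetical snippet; it assumes
# libclang is installed and discoverable by clang.cindex):
#
#     c_extractor = CEntityExtractor()
#     declared, called = c_extractor.extract_entities(
#         "#include <stdio.h>\nint add(int a, int b) { return a + b; }"
#     )
#     # declared holds {"name": "add", "type": "function"} plus parameter
#     # entries "add.a" and "add.b"; called holds "stdio.h" from the
#     # INCLUSION_DIRECTIVE branch of _walk_cursor.
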

class CppEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from C++ code using clang.cindex (libclang),
    including classes, namespaces, and methods.
    """

    def __init__(self):
        self.index = cindex.Index.create()
        self.reset()

    def reset(self) -> None:
        self.declared_entities = []
        self.called_entities = []
        self.scope_stack = []

    def _qualified(self, name: str) -> str:
        """Return the fully qualified name using the current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _walk_cursor(self, cursor, source_file: str):
        for c in cursor.get_children():
            # --- Include directives ---
            # Note: INCLUSION_DIRECTIVE nodes are at the root level and need special handling
            if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
                # Get the included file name
                included_file = c.displayname
                if included_file:
                    self.called_entities.append(included_file)
                continue

            kind = c.kind

            # --- Namespace --- (process before the location check)
            if kind == cindex.CursorKind.NAMESPACE:
                if c.spelling:  # Only add non-empty namespace names
                    self.scope_stack.append(c.spelling)
                self._walk_cursor(c, source_file)
                if c.spelling:
                    self.scope_stack.pop()
                continue

            # Check the location for other node types
            loc = c.location
            # Skip nodes from other files, but allow nodes without location info
            if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file):
                continue

            # --- Class / Struct ---
            if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL):
                # Only process if it has a name
                if c.spelling:
                    # Check that it's a definition (not a forward declaration)
                    is_def = c.is_definition() if hasattr(c, 'is_definition') else True
                    if is_def:
                        full_name = self._qualified(c.spelling)
                        self.declared_entities.append({"name": full_name, "type": "class"})

                        # Handle base classes (inheritance)
                        for base in c.get_children():
                            if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER:
                                if base.spelling:
                                    self.called_entities.append(base.spelling)

                    self.scope_stack.append(c.spelling)
                    self._walk_cursor(c, source_file)
                    self.scope_stack.pop()
                continue

            # --- Methods ---
            if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR):
                if c.spelling:  # Only process if it has a name
                    full_name = self._qualified(c.spelling)
                    self.declared_entities.append({"name": full_name, "type": "method"})

                    for p in c.get_arguments():
                        if p.spelling:  # Only add parameters with names
                            self.declared_entities.append({
                                "name": f"{full_name}.{p.spelling}",
                                "type": "variable",
                                "dtype": p.type.spelling
                            })

                    self._walk_cursor(c, source_file)
                continue

            # --- Free functions ---
            if kind == cindex.CursorKind.FUNCTION_DECL:
                if c.spelling:  # Only process if it has a name
                    full_name = self._qualified(c.spelling)
                    self.declared_entities.append({"name": full_name, "type": "function"})
                    for p in c.get_arguments():
                        if p.spelling:  # Only add parameters with names
                            self.declared_entities.append({
                                "name": f"{full_name}.{p.spelling}",
                                "type": "variable",
                                "dtype": p.type.spelling
                            })
                    self._walk_cursor(c, source_file)
                continue

            # --- Variables ---
            if kind == cindex.CursorKind.VAR_DECL:
                full_name = self._qualified(c.spelling)
                self.declared_entities.append({
                    "name": full_name,
                    "type": "variable",
                    "dtype": c.type.spelling
                })

                # Look for TYPE_REF children, which explicitly reference the type.
                # This is more reliable than c.type.spelling when includes aren't resolved.
                type_ref_found = False
                for child in c.get_children():
                    if child.kind == cindex.CursorKind.TYPE_REF:
                        # TYPE_REF.spelling gives the fully qualified type name;
                        # it may carry a 'class ' or 'struct ' prefix, so strip it
                        if child.spelling:
                            type_name = child.spelling.replace('class ', '').replace('struct ', '').strip()
                            if type_name:
                                # TYPE_REF gives the canonical name from the definition,
                                # which includes namespace qualifiers if present.
                                # We only add this canonical name and rely on alias resolution
                                # to match unqualified usage (e.g., 'Calculator' -> 'math::Calculator')
                                self.called_entities.append(type_name)
                                type_ref_found = True
                                break

                # Fallback: use c.type.spelling if no TYPE_REF was found.
                # Note: c.type.spelling may give the name as written in the source,
                # which could be unqualified even if it refers to a namespaced type
                if not type_ref_found and c.type.spelling:
                    # Extract the base type name (remove const, &, *, etc.)
                    type_name = c.type.spelling.strip()
                    # Remove common qualifiers
                    type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip()
                    if type_name and type_name not in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed']:
                        # c.type.spelling might give an unqualified name even for
                        # namespaced types; add it and let alias resolution handle it
                        self.called_entities.append(type_name)

            # --- Calls ---
            if kind == cindex.CursorKind.CALL_EXPR:
                callee = None
                for child in c.get_children():
                    if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
                        callee = child.spelling
                        break
                if callee:
                    self.called_entities.append(callee)
                else:
                    self.called_entities.append(c.displayname or c.spelling)

            # Recurse
            self._walk_cursor(c, source_file)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        self.reset()

        # If file_path is provided, use it directly for better include resolution;
        # otherwise, create a temporary file
        tf_name = None
        temp_file = False

        if file_path and os.path.exists(file_path):
            tf_name = file_path
            temp_file = False
        else:
            with tempfile.NamedTemporaryFile(suffix=".cpp", mode="w+", delete=False) as tf:
                tf_name = tf.name
                tf.write(code)
                tf.flush()
                temp_file = True

        # Use the directory containing the file for include paths
        include_dir = os.path.dirname(tf_name) if tf_name else None
        args = ['-std=c++17', '-xc++']
        if include_dir:
            args.append(f'-I{include_dir}')

        try:
            tu = self.index.parse(
                tf_name,
                args=args,
                options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
            )
        except Exception as e:
            raise RuntimeError(f"libclang failed to parse C++ translation unit: {e}")

        self._walk_cursor(tu.cursor, tf_name)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))

        # Only delete the file if we created a temporary one
        if temp_file:
            try:
                os.unlink(tf_name)
            except Exception:
                pass

        return unique_declared, unique_called
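
# Qualified-name sketch for CppEntityExtractor (hypothetical snippet; the
# same libclang requirement applies):
#
#     cpp_extractor = CppEntityExtractor()
#     declared, _ = cpp_extractor.extract_entities(
#         "namespace math { class Calculator { public: int add(int a, int b); }; }"
#     )
#     # The NAMESPACE and CLASS_DECL branches push onto scope_stack, so
#     # declared contains "math::Calculator" (class) and
#     # "math::Calculator::add" (method) via _qualified.
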

class RustEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Rust code using tree-sitter.
    Handles structs, enums, traits, functions, methods, and modules.
    Also detects API endpoint definitions (Actix-web, Rocket, Axum, Warp).
    """

    # HTTP method route macros for Rust web frameworks
    ROUTE_MACROS = {
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',  # Actix-web, Rocket
        'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options',  # Alternative casing
    }

    # Route-related macros and functions
    ROUTE_PATTERNS = {
        'route',  # Generic route macro
        'web::get', 'web::post', 'web::put', 'web::delete',  # Actix-web with web::
    }

    def __init__(self):
        self.parser = Parser()
        self.parser.language = Language(ts_rust.language())
        self.reset()

    def reset(self) -> None:
        self.declared_entities = []
        self.called_entities = []
        self.scope_stack = []
        self.api_endpoints: List[Dict[str, Any]] = []  # Track API endpoint definitions

    def _qualified(self, name: str) -> str:
        """Return the fully qualified name using the current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _get_node_text(self, node, code_bytes: bytes) -> str:
        """Extract the text content of a node."""
        return code_bytes[node.start_byte:node.end_byte].decode('utf8')

    def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Rust function attributes.
        Handles patterns like:
        - #[get("/users")]                  # Actix-web, Rocket
        - #[post("/users")]                 # Actix-web, Rocket
        - #[route("/users", method="GET")]  # Generic route

        Note: in the tree-sitter Rust AST, attributes appear as PREVIOUS SIBLINGS
        of the function_item node, not as children.
        """
        # Get the parent node to access siblings
        parent = node.parent
        if not parent:
            return None

        # Find the index of the current node in the parent's children
        node_index = None
        for i, child in enumerate(parent.children):
            if child == node:
                node_index = i
                break

        if node_index is None:
            return None

        # Look backwards through previous siblings for attribute_item nodes
        for i in range(node_index - 1, -1, -1):
            sibling = parent.children[i]

            # Stop if we hit a non-attribute node (except comments/whitespace)
            if sibling.type not in ['attribute_item', 'line_comment', 'block_comment']:
                break

            if sibling.type == 'attribute_item':
                attr_text = self._get_node_text(sibling, code_bytes)

                # Match HTTP method macros: #[get("/path")], #[post("/path")],
                # #[post("/path", data = "<var>")], etc. The pattern allows
                # optional additional parameters after the path.
                method_pattern = r'#\[(get|post|put|patch|delete|head|options)\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]'
                match = re.search(method_pattern, attr_text, re.IGNORECASE)

                if match:
                    http_method = match.group(1).upper()
                    endpoint_path = match.group(2)
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

                # Match the generic route macro: #[route("/path", method="GET")]
                route_pattern = r'#\[route\s*\(\s*"([^"]+)"(?:.*?method\s*=\s*"([^"]+)")?\s*\)\]'
                match = re.search(route_pattern, attr_text, re.IGNORECASE)

                if match:
                    endpoint_path = match.group(1)
                    http_method = match.group(2).upper() if match.group(2) else "GET"
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

        return None

    def _walk_tree(self, node, code_bytes: bytes):
        """Recursively walk the tree-sitter AST."""
        node_type = node.type

        # --- Module declarations ---
        if node_type == 'mod_item':
            # mod my_module { ... }
            name_node = node.child_by_field_name('name')
            if name_node:
                mod_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(mod_name)
                self.declared_entities.append({"name": qualified, "type": "module"})

                self.scope_stack.append(mod_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        # --- Struct declarations ---
        elif node_type == 'struct_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                struct_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(struct_name)
                self.declared_entities.append({"name": qualified, "type": "struct"})

                # Check for generic parameters
                type_params = node.child_by_field_name('type_parameters')
                if type_params:
                    self._walk_tree(type_params, code_bytes)

                self.scope_stack.append(struct_name)
                # Process fields
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'field_declaration':
                            field_name_node = child.child_by_field_name('name')
                            field_type_node = child.child_by_field_name('type')
                            if field_name_node:
                                field_name = self._get_node_text(field_name_node, code_bytes)
                                field_type = self._get_node_text(field_type_node, code_bytes) if field_type_node else "unknown"
                                self.declared_entities.append({
                                    "name": f"{qualified}.{field_name}",
                                    "type": "field",
                                    "dtype": field_type
                                })
                self.scope_stack.pop()
            return

        # --- Enum declarations ---
        elif node_type == 'enum_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                enum_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(enum_name)
                self.declared_entities.append({"name": qualified, "type": "enum"})

                self.scope_stack.append(enum_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'enum_variant':
                            variant_name_node = child.child_by_field_name('name')
                            if variant_name_node:
                                variant_name = self._get_node_text(variant_name_node, code_bytes)
                                self.declared_entities.append({
                                    "name": f"{qualified}::{variant_name}",
                                    "type": "enum_variant"
                                })
                self.scope_stack.pop()
            return

        # --- Trait declarations ---
        elif node_type == 'trait_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                trait_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(trait_name)
                self.declared_entities.append({"name": qualified, "type": "trait"})

                self.scope_stack.append(trait_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        # --- Implementation blocks ---
        elif node_type == 'impl_item':
            # impl MyStruct { ... } or impl Trait for MyStruct { ... }
            type_node = node.child_by_field_name('type')
            trait_node = node.child_by_field_name('trait')

            impl_name = None
            if type_node:
                impl_name = self._get_node_text(type_node, code_bytes)

            if trait_node:
                trait_name = self._get_node_text(trait_node, code_bytes)
                self.called_entities.append(trait_name)

            if impl_name:
                self.scope_stack.append(impl_name)

            body = node.child_by_field_name('body')
            if body:
                for child in body.children:
                    self._walk_tree(child, code_bytes)

            if impl_name:
                self.scope_stack.pop()
            return

        # --- Function declarations ---
        elif node_type == 'function_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                func_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(func_name)

                # Check for API endpoint attributes (e.g., #[get("/users")])
                api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)

                if api_info:
                    # This is an API endpoint handler
                    self.declared_entities.append({
                        "name": qualified,
                        "type": "api_endpoint",
                        "endpoint": api_info.get("endpoint"),
                        "methods": api_info.get("methods")
                    })
                    self.api_endpoints.append({**api_info, "function": qualified})
                    entity_type = "api_endpoint"
                else:
                    # Determine whether this is a method (inside an impl block) or a free function
                    entity_type = "method" if len(self.scope_stack) > 0 else "function"
                    self.declared_entities.append({"name": qualified, "type": entity_type})

                # Extract parameters
                params = node.child_by_field_name('parameters')
                if params:
                    for child in params.children:
                        if child.type == 'parameter':
                            pattern = child.child_by_field_name('pattern')
                            type_node = child.child_by_field_name('type')
                            if pattern:
                                param_name = self._get_node_text(pattern, code_bytes)
                                param_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                                # Skip 'self' parameters
                                if param_name not in ['self', '&self', '&mut self', 'mut self']:
                                    self.declared_entities.append({
                                        "name": f"{qualified}.{param_name}",
                                        "type": "variable",
                                        "dtype": param_type
                                    })

                # Walk the function body to find calls
                body = node.child_by_field_name('body')
                if body:
                    self._walk_tree(body, code_bytes)
            return

        # --- Type aliases ---
        elif node_type == 'type_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                type_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(type_name)
                self.declared_entities.append({"name": qualified, "type": "type_alias"})
            return

        # --- Constant declarations ---
        elif node_type == 'const_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                const_name = self._get_node_text(name_node, code_bytes)
                const_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(const_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "constant",
                    "dtype": const_type
                })

        # --- Static declarations ---
        elif node_type == 'static_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                static_name = self._get_node_text(name_node, code_bytes)
                static_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(static_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "static",
                    "dtype": static_type
                })

        # --- Let bindings (local variables) ---
        elif node_type == 'let_declaration':
            pattern = node.child_by_field_name('pattern')
            type_node = node.child_by_field_name('type')
            if pattern and pattern.type == 'identifier':
                var_name = self._get_node_text(pattern, code_bytes)
                var_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                # Only top-level or module-level variables are tracked;
                # for now, function-local bindings are skipped to avoid clutter.

        # --- Use declarations (imports) ---
        elif node_type == 'use_declaration':
            # Extract the imported items
            use_text = self._get_node_text(node, code_bytes)
            self.called_entities.append(use_text)

        # --- Call expressions ---
        elif node_type == 'call_expression':
            function = node.child_by_field_name('function')
            if function:
                func_text = self._get_node_text(function, code_bytes)
                # Record the call text as written; this covers method calls
                # like obj.method() and path calls like std::vec::Vec::new()
                self.called_entities.append(func_text)

        # --- Macro invocations ---
        elif node_type == 'macro_invocation':
            macro_node = node.child_by_field_name('macro')
            if macro_node:
                macro_name = self._get_node_text(macro_node, code_bytes)
                self.called_entities.append(f"{macro_name}!")

        # --- Field expressions (method calls or field access) ---
        elif node_type == 'field_expression':
            field = node.child_by_field_name('field')
            if field:
                field_name = self._get_node_text(field, code_bytes)
                # This could be a field access or a method call; without
                # fuller context, the bare field name is not recorded for now.

        # Recursively walk all children
        for child in node.children:
            self._walk_tree(child, code_bytes)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Extract entities from Rust code using tree-sitter."""
        self.reset()

        code_bytes = code.encode('utf8')
        tree = self.parser.parse(code_bytes)

        # Walk the AST
        self._walk_tree(tree.root_node, code_bytes)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))

        return unique_declared, unique_called
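
# Attribute-routing sketch for RustEntityExtractor (hypothetical snippet; it
# assumes the tree_sitter and tree_sitter_rust packages are installed):
#
#     rust_extractor = RustEntityExtractor()
#     declared, called = rust_extractor.extract_entities(
#         '#[get("/users")]\nfn list_users() -> String { String::new() }'
#     )
#     # _extract_api_endpoint_from_attributes matches the #[get("/users")]
#     # sibling, so declared includes {"name": "list_users", "type":
#     # "api_endpoint", "endpoint": "/users", "methods": ["GET"]}.
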

class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
    """
    AST-based entity extractor for Python code.
    Also detects API endpoint definitions (FastAPI, Flask, Django REST Framework).
    """

    # Common HTTP decorators/patterns for Python web frameworks
    API_DECORATORS = {
        'route',  # Flask @app.route
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',  # FastAPI/Flask methods
        'api_view',  # DRF @api_view
    }

    def __init__(self):
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_class: Optional[str] = None
        self.current_function: Optional[str] = None
        self.api_endpoints: List[Dict[str, Any]] = []  # Track API endpoint definitions

    def reset(self) -> None:
        """Clear previous extraction state, including context"""
        self.declared_entities = []
        self.called_entities = []
        self.current_class = None
        self.current_function = None
        self.api_endpoints = []

    def _get_type_annotation(self, node: ast.AST) -> str:
        """Extract a type annotation from an AST node"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.Attribute):
            return f"{self._get_type_annotation(node.value)}.{node.attr}"
        elif isinstance(node, ast.Subscript):
            # Handle generic types like List[str], Dict[str, int]
            base = self._get_type_annotation(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = [self._get_type_annotation(elt) for elt in node.slice.elts]
                return f"{base}[{', '.join(args)}]"
            else:
                arg = self._get_type_annotation(node.slice)
                return f"{base}[{arg}]"
        return "unknown"

    def _infer_type_from_value(self, node: ast.AST) -> str:
        """Infer a type from the assigned value"""
        if isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.List):
            return "list"
        elif isinstance(node, ast.Dict):
            return "dict"
        elif isinstance(node, ast.Set):
            return "set"
        elif isinstance(node, ast.Tuple):
            return "tuple"
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                return node.func.id  # Constructor call
            elif isinstance(node.func, ast.Attribute):
                return "unknown"
        elif isinstance(node, ast.Name):
            return "unknown"  # Reference to another variable
        return "unknown"

    def visit_ClassDef(self, node: ast.ClassDef):
        """Visit class definitions"""
        old_class = self.current_class
        self.current_class = node.name

        # Add the class to declared entities
        self.declared_entities.append({
            "name": node.name,
            "type": "class"
        })

        # Record base classes as called entities
        for base in node.bases:
            if isinstance(base, ast.Name):
                self.called_entities.append(base.id)
            elif isinstance(base, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(base))

        # Continue visiting child nodes
        self.generic_visit(node)
        self.current_class = old_class

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Visit function/method definitions and detect API endpoints"""
        old_function = self.current_function

        if self.current_class:
            # This is a method
            full_name = f"{self.current_class}.{node.name}"
            entity_type = "method"
        else:
            # This is a function
            full_name = node.name
            entity_type = "function"

        self.current_function = full_name

        # Check for API endpoint decorators
        api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
        if api_info:
            # Mark this as an API endpoint
            self.declared_entities.append({
                "name": full_name,
                "type": "api_endpoint",
                "endpoint": api_info.get("endpoint"),
                "methods": api_info.get("methods")
            })
            self.api_endpoints.append(api_info)
        else:
            self.declared_entities.append({
                "name": full_name,
                "type": entity_type
            })

        # Process parameters
        for arg in node.args.args:
            if arg.arg == 'self' and self.current_class:
                continue  # Skip the self parameter

            dtype = "unknown"
            if arg.annotation:
                dtype = self._get_type_annotation(arg.annotation)

            param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
            self.declared_entities.append({
                "name": param_name,
                "type": "variable",
                "dtype": dtype
            })

        # Continue visiting child nodes
        self.generic_visit(node)
        self.current_function = old_function

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Visit async function/method definitions"""
        # Treat async functions the same as regular functions
        self.visit_FunctionDef(node)

    def visit_Assign(self, node: ast.Assign):
        """Visit assignment statements"""
        # Infer the type from the assigned value
        dtype = self._infer_type_from_value(node.value)

        for target in node.targets:
            if isinstance(target, ast.Name):
                # Simple variable assignment
                var_name = target.id
                if self.current_function:
                    # Local variable inside a function or method
                    pass  # Could add local variables if needed
                else:
                    # Module-level variable
                    self.declared_entities.append({
                        "name": var_name,
                        "type": "variable",
                        "dtype": dtype
                    })

            elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                # Attribute assignment like self.name = value
                if target.value.id == 'self' and self.current_class:
                    attr_name = f"{self.current_class}.{target.attr}"
                    self.declared_entities.append({
                        "name": attr_name,
                        "type": "variable",
                        "dtype": dtype
                    })

        # Continue visiting to catch function calls in the assignment
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Visit annotated assignment statements (PEP 526)"""
        if isinstance(node.target, ast.Name):
            dtype = self._get_type_annotation(node.annotation)
            var_name = node.target.id

            self.declared_entities.append({
                "name": var_name,
                "type": "variable",
                "dtype": dtype
            })

        elif isinstance(node.target, ast.Attribute) and isinstance(node.target.value, ast.Name):
            if node.target.value.id == 'self' and self.current_class:
                dtype = self._get_type_annotation(node.annotation)
                attr_name = f"{self.current_class}.{node.target.attr}"
                self.declared_entities.append({
                    "name": attr_name,
                    "type": "variable",
                    "dtype": dtype
                })

        # Continue visiting
        if node.value:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Visit import statements"""
        for alias in node.names:
            # Record the imported module/package
            self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Visit from...import statements"""
        if node.module:
            # Record the module being imported from
            self.called_entities.append(node.module)
            # Optionally, also record specific imports as module.name
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(f"{node.module}.{alias.name}")
        else:
            # Relative imports without a module (from . import x)
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Visit function/method calls"""
        if isinstance(node.func, ast.Name):
            # Simple function call
            self.called_entities.append(node.func.id)

        elif isinstance(node.func, ast.Attribute):
            # Method call or attribute access
            if isinstance(node.func.value, ast.Name):
                # obj.method() - we would need to infer the class of obj;
                # for now, just record the method name
                method_name = node.func.attr
                # Try to find the variable's type from our declared entities
                obj_name = node.func.value.id
                obj_class = self._find_variable_type(obj_name)
                if obj_class and obj_class != "unknown":
                    self.called_entities.append(f"{obj_class}.{method_name}")
                else:
                    # Fallback: just record the method call
                    self.called_entities.append(method_name)

            elif isinstance(node.func.value, ast.Attribute):
                # Nested attribute access like module.Class.method()
                full_name = self._get_type_annotation(node.func)
                self.called_entities.append(full_name)

        # Continue visiting child nodes
        self.generic_visit(node)

    def _find_variable_type(self, var_name: str) -> str:
        """Find the type of a variable from declared entities"""
        for entity in self.declared_entities:
            if entity["name"] == var_name and entity["type"] == "variable":
                return entity.get("dtype", "unknown")
        return "unknown"

    def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from function decorators.
        Handles patterns like:
        - @app.route("/api/users", methods=["GET", "POST"])  # Flask
        - @app.get("/api/users")                             # FastAPI
        - @router.post("/api/users")                         # FastAPI with router
        - @api_view(['GET', 'POST'])                         # Django REST Framework
        """
        for decorator in decorators:
            # Handle @app.route(...) or @app.get(...)
            if isinstance(decorator, ast.Call):
                if isinstance(decorator.func, ast.Attribute):
                    # e.g., app.route, app.get, router.post
                    method_name = decorator.func.attr.lower()

                    if method_name in self.API_DECORATORS:
                        endpoint = None
                        http_methods = []

                        # Extract the endpoint from the first positional argument
                        if decorator.args and isinstance(decorator.args[0], ast.Constant):
                            endpoint = decorator.args[0].value

                        # For FastAPI-style decorators (@app.get, @app.post)
                        if method_name in {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}:
                            http_methods = [method_name.upper()]

                        # For Flask-style @app.route with a methods kwarg
                        elif method_name == 'route':
                            for keyword in decorator.keywords:
                                if keyword.arg == 'methods':
                                    if isinstance(keyword.value, ast.List):
                                        http_methods = [
                                            elt.value for elt in keyword.value.elts
                                            if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                                        ]
                            if not http_methods:
                                http_methods = ['GET']  # Flask default

                        # For DRF @api_view(['GET', 'POST'])
                        elif method_name == 'api_view':
                            if decorator.args and isinstance(decorator.args[0], ast.List):
                                http_methods = [
                                    elt.value for elt in decorator.args[0].elts
                                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                                ]

                        if endpoint:
                            return {
                                "function": function_name,
                                "endpoint": endpoint,
                                "methods": http_methods,
                                "type": "api_endpoint_definition"
                            }

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Python code using AST parsing

        Args:
            code: Python source code as a string
            file_path: Optional path to the source file (for context)

        Returns:
            Tuple of (declared_entities, called_entities)
        """
        # Ensure fresh state on each extraction
        self.reset()

        try:
            tree = ast.parse(code)
            self.visit(tree)

            # Remove duplicates while preserving order
            seen_declared = set()
            unique_declared = []
            for entity in self.declared_entities:
                key = (entity["name"], entity["type"], entity.get("dtype"))
                if key not in seen_declared:
                    unique_declared.append(entity)
                    seen_declared.add(key)

            unique_called = list(dict.fromkeys(self.called_entities))  # Remove duplicates

            return unique_declared, unique_called

        except SyntaxError as e:
            logger.error(f"Syntax error in Python code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Python code: {e}", exc_info=True)
            return [], []
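
# Decorator sketch for PythonASTEntityExtractor (uses only the standard
# library; the sample route below is illustrative):
#
#     py_extractor = PythonASTEntityExtractor()
#     declared, called = py_extractor.extract_entities(
#         "@app.route('/api/items', methods=['POST'])\n"
#         "def create_item(payload: dict):\n"
#         "    return save(payload)\n"
#     )
#     # declared includes {"name": "create_item", "type": "api_endpoint",
#     # "endpoint": "/api/items", "methods": ["POST"]} and the annotated
#     # parameter "payload" with dtype "dict"; called includes "save".
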

class HybridEntityExtractor:
    """
    Hybrid entity extractor that uses AST parsing for known languages and
    falls back to LLM extraction for unknown ones.
    """

    def __init__(self):
        self.extractors = {
            'py': PythonASTEntityExtractor(),
            'c': CEntityExtractor(),
            'h': CppEntityExtractor(),  # C/C++ headers
            'cpp': CppEntityExtractor(),
            'cc': CppEntityExtractor(),
            'cxx': CppEntityExtractor(),
            'hpp': CppEntityExtractor(),
            'hxx': CppEntityExtractor(),
            'hh': CppEntityExtractor(),
            'java': JavaEntityExtractor(),
            'js': JavaScriptEntityExtractor(),   # NEW
            'jsx': JavaScriptEntityExtractor(),  # NEW
            'ts': JavaScriptEntityExtractor(),   # TypeScript uses a similar AST
            'tsx': JavaScriptEntityExtractor(),  # TSX is similar to JSX
            'rs': RustEntityExtractor(),
            'html': HTMLEntityExtractor()
        }

    def _get_language_from_filename(self, file_name: str) -> str:
        ext = file_name.split('.')[-1].lower()
        return ext

    def extract_entities(self, code: str, file_name: str):
        lang = self._get_language_from_filename(file_name)
        extractor = self.extractors.get(lang)

        if extractor:
            # Reset the shared extractor instance so no state is carried over
            try:
                extractor.reset()
            except Exception:
                # If the extractor doesn't implement reset for some reason, ignore and proceed
                pass

            logger.info(f"Using AST extraction for {lang.upper()} file: {file_name}")
            try:
                # Try to pass file_name if the extractor supports it (the C/C++ extractors do)
                try:
                    declared_entities, called_entities = extractor.extract_entities(code, file_path=file_name)
                except TypeError:
                    # Fallback for extractors that don't accept a file_path parameter
                    declared_entities, called_entities = extractor.extract_entities(code)

                # Add aliases to each declared entity based on the file path
                for entity in declared_entities:
                    entity_name = entity.get('name', '')
                    if entity_name:
                        aliases = generate_entity_aliases(entity_name, file_name)
                        entity['aliases'] = aliases
                        logger.debug(f"Generated aliases for entity '{entity_name}': {aliases}")

                return declared_entities, called_entities
            except Exception as e:
                logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
                return [], []
        else:
            # No AST extractor for this extension; signal the caller to fall
            # back to LLM-based extraction
            raise Exception(f"No AST extractor available for unsupported language: {file_name}")
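
# Dispatch sketch for HybridEntityExtractor (hypothetical file names):
#
#     hybrid = HybridEntityExtractor()
#     declared, called = hybrid.extract_entities("def f(): pass", "utils.py")
#     # The "py" extension maps to PythonASTEntityExtractor. An unsupported
#     # extension such as "notes.txt" raises, signalling the caller to fall
#     # back to LLM-based extraction.
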
RepoKnowledgeGraphLib/KnowledgeGraphMCPServer.py
ADDED
@@ -0,0 +1,1107 @@
import os
from typing import Optional, Annotated
from fastmcp import FastMCP
from langfuse import get_client, observe

from .RepoKnowledgeGraph import RepoKnowledgeGraph


# Custom Exceptions
class MCPServerError(Exception):
    """Base exception for MCP server errors"""
    pass


class NodeNotFoundError(MCPServerError):
    """Raised when a node is not found"""
    pass


class EntityNotFoundError(MCPServerError):
    """Raised when an entity is not found"""
    pass


class InvalidInputError(MCPServerError):
    """Raised when input validation fails"""
    pass


class KnowledgeGraphMCPServer:
    """
    MCP Server for interacting with a codebase knowledge graph.

    Attributes:
        knowledge_graph (RepoKnowledgeGraph): The loaded knowledge graph object.
        app (FastMCP): The FastMCP application instance for tool registration and serving.
    """
    def __init__(self, knowledge_graph: Optional[RepoKnowledgeGraph] = None, knowledge_graph_path: Optional[str] = None, server_name: str = "knowledge-graph-mcp-server"):
        if knowledge_graph is not None:
            self.knowledge_graph = knowledge_graph
        else:
            if knowledge_graph_path is None:
                knowledge_graph_path = os.path.join(os.path.dirname(__file__), "knowledge_graph.json")
            self.knowledge_graph = RepoKnowledgeGraph.load_graph_from_file(knowledge_graph_path)
        self.langfuse = get_client()
        self.app = FastMCP(server_name)
        self.register_tools()

    def _validate_node_exists(self, node_id: str) -> bool:
        """Centralized node validation"""
        if node_id not in self.knowledge_graph.graph:
            raise NodeNotFoundError(f"Node '{node_id}' not found in knowledge graph")
        return True

    def _validate_entity_exists(self, entity_name: str) -> bool:
        """Centralized entity validation"""
        if entity_name not in self.knowledge_graph.entities:
            raise EntityNotFoundError(f"Entity '{entity_name}' not found in knowledge graph")
        return True

    def _validate_positive_int(self, value: int, param_name: str) -> bool:
        """Validate that an integer parameter is positive"""
        if value <= 0:
            raise InvalidInputError(f"{param_name} must be a positive integer, got {value}")
        return True

    def _sanitize_chunk_dict(self, chunk_dict: dict) -> dict:
        """Remove embedding data from chunk dictionary before returning to user"""
        sanitized = chunk_dict.copy()
        sanitized.pop('embedding', None)
        return sanitized

    def _sanitize_node_dict(self, node_dict: dict) -> dict:
        """Remove embedding data from node dictionary before returning to user"""
        sanitized = node_dict.copy()
        if 'data' in sanitized and isinstance(sanitized['data'], dict):
            sanitized['data'] = sanitized['data'].copy()
            sanitized['data'].pop('embedding', None)
        sanitized.pop('embedding', None)
        return sanitized

    def _handle_error(self, error: Exception, context: str = "") -> dict:
        """Centralized error handling with structured response"""
        if isinstance(error, NodeNotFoundError):
            return {
                "error": str(error),
                "error_type": "node_not_found",
                "context": context
            }
        elif isinstance(error, EntityNotFoundError):
            return {
                "error": str(error),
                "error_type": "entity_not_found",
                "context": context
            }
        elif isinstance(error, InvalidInputError):
            return {
                "error": str(error),
                "error_type": "invalid_input",
                "context": context
            }
        else:
            return {
                "error": str(error),
                "error_type": "internal_error",
                "context": context
            }

    @classmethod
    def from_path(cls, path: str, skip_dirs=None, index_nodes=True, describe_nodes=False, extract_entities=False, model_service_kwargs=None, code_index_kwargs=None, server_name: str = "knowledge-graph-mcp-server"):
        """
        Build a KnowledgeGraphMCPServer from a code repository path.
        """
        if skip_dirs is None:
            skip_dirs = []
        if model_service_kwargs is None:
            model_service_kwargs = {}
        kg = RepoKnowledgeGraph.from_path(path, skip_dirs=skip_dirs, index_nodes=index_nodes, describe_nodes=describe_nodes, extract_entities=extract_entities, model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
        return cls(knowledge_graph=kg, server_name=server_name)

    @classmethod
    def from_file(cls, filepath: str, index_nodes=True, use_embed=True, model_service_kwargs=None, code_index_kwargs=None, server_name: str = "knowledge-graph-mcp-server"):
        """
        Build a KnowledgeGraphMCPServer from a serialized knowledge graph file.
        """
        if model_service_kwargs is None:
            model_service_kwargs = {}
        kg = RepoKnowledgeGraph.load_graph_from_file(filepath, index_nodes=index_nodes, use_embed=use_embed, model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
        return cls(knowledge_graph=kg, server_name=server_name)

    @classmethod
    def from_repo(cls, repo_url: str, index_nodes=True, describe_nodes=False, model_service_kwargs=None, code_index_kwargs=None, server_name: str = "knowledge-graph-mcp-server", github_token=None, allow_unauthenticated_clone=True, skip_dirs=None, extract_entities=True):
        if model_service_kwargs is None:
            model_service_kwargs = {}
        kg = RepoKnowledgeGraph.from_repo(repo_url=repo_url, describe_nodes=describe_nodes, index_nodes=index_nodes, model_service_kwargs=model_service_kwargs, github_token=github_token, allow_unauthenticated_clone=allow_unauthenticated_clone, skip_dirs=skip_dirs, extract_entities=extract_entities, code_index_kwargs=code_index_kwargs)
        return cls(knowledge_graph=kg, server_name=server_name)

    def register_tools(self):
        @self.app.tool(
            description="Get detailed information about a node in the knowledge graph, including its type, name, description, declared and called entities, and a content preview."
        )
        @observe(as_type='tool')
        async def get_node_info(
            node_id: Annotated[str, "The ID of the node to retrieve information for."]
        ) -> dict:
            try:
                self._validate_node_exists(node_id)
                node = self.knowledge_graph.graph.nodes[node_id]['data']

                declared_entities = getattr(node, 'declared_entities', [])
                called_entities = getattr(node, 'called_entities', [])
                content = getattr(node, 'content', None)
                content_preview = content[:200] + "..." if content and len(content) > 200 else content

                return {
                    "node_id": node_id,
                    "class": node.__class__.__name__,
                    "name": getattr(node, 'name', 'Unknown'),
                    "type": getattr(node, 'node_type', 'Unknown'),
                    "description": getattr(node, 'description', None),
                    "declared_entities": declared_entities,
                    "called_entities": called_entities,
                    "content_preview": content_preview,
                    "text": f"Node {node_id} ({getattr(node, 'name', '?')}) - {getattr(node, 'node_type', '?')} with {len(declared_entities)} declared and {len(called_entities)} called entities."
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_node_info")
            except Exception as e:
                return self._handle_error(e, "get_node_info")

        @self.app.tool(
            description="List all incoming and outgoing edges for a node, showing relationships to other nodes."
        )
        @observe(as_type='tool')
        async def get_node_edges(
            node_id: Annotated[str, "The ID of the node whose edges to list."]
        ) -> dict:
            try:
                self._validate_node_exists(node_id)
                g = self.knowledge_graph.graph

                incoming = [
                    {"source": src, "target": tgt, "relation": data.get("relation", "?")}
                    for src, tgt, data in g.in_edges(node_id, data=True)
                ]
                outgoing = [
                    {"source": src, "target": tgt, "relation": data.get("relation", "?")}
                    for src, tgt, data in g.out_edges(node_id, data=True)
                ]

                return {
                    "node_id": node_id,
                    "incoming": incoming,
                    "outgoing": outgoing,
                    "incoming_count": len(incoming),
                    "outgoing_count": len(outgoing),
                    "text": f"Node '{node_id}' has {len(incoming)} incoming and {len(outgoing)} outgoing edges."
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_node_edges")
            except Exception as e:
                return self._handle_error(e, "get_node_edges")

        @self.app.tool(
            description="Search for nodes in the knowledge graph by query string, using the code index semantic and keyword search."
        )
        @observe(as_type='tool')
        async def search_nodes(
            query: Annotated[str, "The search string to match against code index."],
            limit: Annotated[int, "Maximum number of results to return."] = 10
        ) -> dict:
            try:
                self._validate_positive_int(limit, "limit")

                results = self.knowledge_graph.code_index.query(query, n_results=limit)
                metadatas = results.get("metadatas", [[]])[0]

                if not metadatas:
                    return {"query": query, "results": [], "text": f"No results found for '{query}'."}

                structured_results = [
                    {
                        "id": res.get("id"),
                        "content": res.get("content"),
                        "declared_entities": res.get("declared_entities"),
                        "called_entities": res.get("called_entities")
                    }
                    for res in metadatas
                ]

                return {
                    "query": query,
                    "count": len(structured_results),
                    "results": structured_results,
                    "text": f"Found {len(structured_results)} result(s) for query '{query}'."
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "search_nodes")
            except Exception as e:
                return self._handle_error(e, "search_nodes")

        @self.app.tool(
            description="Get overall statistics about the knowledge graph, including node and edge counts, types, and relations."
        )
        @observe(as_type='tool')
        async def get_graph_stats() -> dict:
            g = self.knowledge_graph.graph
            num_nodes = g.number_of_nodes()
            num_edges = g.number_of_edges()

            node_types = {}
            for _, node_attrs in g.nodes(data=True):
                node_type = getattr(node_attrs['data'], 'node_type', 'Unknown')
                node_types[node_type] = node_types.get(node_type, 0) + 1

            edge_relations = {}
            for _, _, attrs in g.edges(data=True):
                relation = attrs.get('relation', 'Unknown')
                edge_relations[relation] = edge_relations.get(relation, 0) + 1

            return {
                "total_nodes": num_nodes,
                "total_edges": num_edges,
                "node_types": node_types,
                "edge_relations": edge_relations,
                "text": f"Graph with {num_nodes} nodes, {num_edges} edges, {len(node_types)} node types, and {len(edge_relations)} relation types."
            }

        @self.app.tool(
            description="List nodes of a specific type in the knowledge graph."
        )
        @observe(as_type='tool')
        async def list_nodes_by_type(
            node_type: Annotated[str, "The type of nodes to list (e.g., 'function', 'class', 'file')."],
            limit: Annotated[int, "Maximum number of nodes to return."] = 20
        ) -> dict:
            g = self.knowledge_graph.graph
            matching_nodes = [
                {
                    "id": node_id,
                    "name": getattr(data['data'], 'name', 'Unknown')
                }
                for node_id, data in g.nodes(data=True)
                if getattr(data['data'], 'node_type', None) == node_type
            ][:limit]

            if not matching_nodes:
                return {"node_type": node_type, "results": [], "text": f"No nodes found of type '{node_type}'."}

            return {
                "node_type": node_type,
                "count": len(matching_nodes),
                "results": matching_nodes,
                "text": f"Found {len(matching_nodes)} node(s) of type '{node_type}'."
            }

        @self.app.tool(
            description="Get all nodes directly connected to a given node, including the relationship type."
        )
        @observe(as_type='tool')
        async def get_neighbors(
            node_id: Annotated[str, "The ID of the node whose neighbors to retrieve."]
        ) -> dict:
            """Get all nodes directly connected to this node, with their relationship types."""
            try:
                self._validate_node_exists(node_id)

                neighbors = self.knowledge_graph.get_neighbors(node_id)
                if not neighbors:
                    return {
                        "node_id": node_id,
                        "neighbors": [],
                        "text": f"No neighbors found for node '{node_id}'"
                    }

                neighbor_list = []
                for neighbor in neighbors[:20]:
                    neighbor_info = {
                        "id": neighbor.id,
                        "name": getattr(neighbor, 'name', 'Unknown'),
                        "type": neighbor.node_type,
                        "relation": None
                    }

                    if self.knowledge_graph.graph.has_edge(node_id, neighbor.id):
                        edge_data = self.knowledge_graph.graph.get_edge_data(node_id, neighbor.id)
                        neighbor_info["relation"] = edge_data.get('relation', 'Unknown')
                        neighbor_info["direction"] = "outgoing"
                    elif self.knowledge_graph.graph.has_edge(neighbor.id, node_id):
                        edge_data = self.knowledge_graph.graph.get_edge_data(neighbor.id, node_id)
                        neighbor_info["relation"] = edge_data.get('relation', 'Unknown')
                        neighbor_info["direction"] = "incoming"

                    neighbor_list.append(neighbor_info)

                text = f"Neighbors of '{node_id}' ({len(neighbors)} total):\n\n"
                for neighbor in neighbor_list:
                    text += f"- {neighbor['id']}: {neighbor['name']} ({neighbor['type']})\n"
                    if neighbor['relation']:
                        arrow = "→" if neighbor['direction'] == "outgoing" else "←"
                        text += f"  {arrow} Relation: {neighbor['relation']}\n"

                if len(neighbors) > 20:
                    text += f"\n... and {len(neighbors) - 20} more neighbors\n"

                return {
                    "node_id": node_id,
                    "total_neighbors": len(neighbors),
                    "neighbors": neighbor_list,
                    "has_more": len(neighbors) > 20,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_neighbors")
            except Exception as e:
                return self._handle_error(e, "get_neighbors")

        @self.app.tool(
            description="Find where an entity (function, class, variable, etc.) is declared or defined in the codebase."
        )
        @observe(as_type='tool')
        async def go_to_definition(
            entity_name: Annotated[str, "The name of the entity to find the definition for."]
        ) -> dict:
            """Find where an entity is declared/defined in the codebase."""
            try:
                self._validate_entity_exists(entity_name)

                entity_info = self.knowledge_graph.entities[entity_name]
                declaring_chunks = entity_info.get('declaring_chunk_ids', [])

                if not declaring_chunks:
                    return {
                        "entity_name": entity_name,
                        "declarations": [],
                        "text": f"Entity '{entity_name}' found but no declarations identified."
                    }

                declarations = []
                for chunk_id in declaring_chunks[:5]:
                    if chunk_id in self.knowledge_graph.graph:
                        chunk = self.knowledge_graph.graph.nodes[chunk_id]['data']
                        content_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
                        declarations.append({
                            "chunk_id": chunk_id,
                            "file_path": chunk.path,
                            "order_in_file": chunk.order_in_file,
                            "content_preview": content_preview
                        })

                text = f"Definition(s) for '{entity_name}':\n\n"
                text += f"Type: {', '.join(entity_info.get('type', ['Unknown']))}\n"
                if entity_info.get('dtype'):
                    text += f"Data Type: {entity_info['dtype']}\n"
                text += f"\nDeclared in {len(declaring_chunks)} location(s):\n\n"

                for decl in declarations:
                    text += f"- Chunk: {decl['chunk_id']}\n"
                    text += f"  File: {decl['file_path']}\n"
                    text += f"  Order: {decl['order_in_file']}\n"
                    text += f"  Content: {decl['content_preview']}\n\n"

                if len(declaring_chunks) > 5:
                    text += f"... and {len(declaring_chunks) - 5} more locations\n"

                return {
                    "entity_name": entity_name,
                    "type": entity_info.get('type', []),
                    "dtype": entity_info.get('dtype'),
                    "total_declarations": len(declaring_chunks),
                    "declarations": declarations,
                    "has_more": len(declaring_chunks) > 5,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "go_to_definition")
            except Exception as e:
                return self._handle_error(e, "go_to_definition")

        @self.app.tool(
            description="Find all usages or calls of an entity (function, class, variable, etc.) in the codebase."
        )
        @observe(as_type='tool')
        async def find_usages(
            entity_name: Annotated[str, "The name of the entity to find usages for."],
            limit: Annotated[int, "Maximum number of usages to return."] = 20
        ) -> dict:
            """Find where an entity is used/called in the codebase."""
            try:
                self._validate_entity_exists(entity_name)
                self._validate_positive_int(limit, "limit")

                entity_info = self.knowledge_graph.entities[entity_name]
                calling_chunks = entity_info.get('calling_chunk_ids', [])

                if not calling_chunks:
                    return {
                        "entity_name": entity_name,
                        "usages": [],
                        "text": f"Entity '{entity_name}' found but no usages identified."
                    }

                usages = []
                for chunk_id in calling_chunks[:limit]:
                    if chunk_id in self.knowledge_graph.graph:
                        chunk = self.knowledge_graph.graph.nodes[chunk_id]['data']
                        content_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
                        usages.append({
                            "chunk_id": chunk_id,
                            "file_path": chunk.path,
                            "order_in_file": chunk.order_in_file,
                            "content_preview": content_preview
                        })

                text = f"Usages of '{entity_name}' ({len(calling_chunks)} total):\n\n"
                for usage in usages:
                    text += f"- {usage['file_path']} (chunk {usage['order_in_file']})\n"
                    text += f"  Content: {usage['content_preview']}\n\n"

                if len(calling_chunks) > limit:
                    text += f"\n... and {len(calling_chunks) - limit} more usages\n"

                return {
                    "entity_name": entity_name,
                    "total_usages": len(calling_chunks),
                    "usages": usages,
                    "has_more": len(calling_chunks) > limit,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "find_usages")
            except Exception as e:
                return self._handle_error(e, "find_usages")

        @self.app.tool(
            description="Get an overview of the structure of a file, including its chunks and declared entities."
        )
        @observe(as_type='tool')
        async def get_file_structure(
            file_path: Annotated[str, "The path of the file to get the structure for."]
        ) -> dict:
            """Get an overview of chunks and entities in a specific file."""
            try:
                self._validate_node_exists(file_path)

                file_node = self.knowledge_graph.graph.nodes[file_path]['data']
                chunks = self.knowledge_graph.get_chunks_of_file(file_path)

                declared_entities = []
                if hasattr(file_node, 'declared_entities') and file_node.declared_entities:
                    for entity in file_node.declared_entities[:15]:
                        if isinstance(entity, dict):
                            declared_entities.append({
                                "name": entity.get('name', '?'),
                                "type": entity.get('type', '?')
                            })
                        else:
                            declared_entities.append({"name": str(entity), "type": "Unknown"})

                chunk_list = []
                for chunk in chunks[:10]:
                    chunk_list.append({
                        "id": chunk.id,
                        "order_in_file": chunk.order_in_file,
                        "description": chunk.description[:80] + "..." if chunk.description and len(chunk.description) > 80 else chunk.description
                    })

                text = f"File Structure: {file_node.name}\n"
                text += f"Path: {file_path}\n"
                text += f"Language: {getattr(file_node, 'language', 'Unknown')}\n"
                text += f"Total Chunks: {len(chunks)}\n\n"

                if declared_entities:
                    text += f"Declared Entities ({len(file_node.declared_entities)}):\n"
                    for entity in declared_entities:
                        text += f"  - {entity['name']} ({entity['type']})\n"
                    if len(file_node.declared_entities) > 15:
                        text += f"  ... and {len(file_node.declared_entities) - 15} more\n"

                text += f"\nChunks:\n"
                for chunk_info in chunk_list:
                    text += f"  [{chunk_info['order_in_file']}] {chunk_info['id']}\n"
                    if chunk_info['description']:
                        text += f"      {chunk_info['description']}\n"

                if len(chunks) > 10:
                    text += f"  ... and {len(chunks) - 10} more chunks\n"

                return {
                    "file_path": file_path,
                    "file_name": file_node.name,
                    "language": getattr(file_node, 'language', 'Unknown'),
                    "total_chunks": len(chunks),
                    "total_declared_entities": len(file_node.declared_entities) if hasattr(file_node, 'declared_entities') else 0,
                    "declared_entities": declared_entities,
                    "chunks": chunk_list,
                    "has_more_entities": hasattr(file_node, 'declared_entities') and len(file_node.declared_entities) > 15,
                    "has_more_chunks": len(chunks) > 10,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_file_structure")
            except Exception as e:
                return self._handle_error(e, "get_file_structure")

        @self.app.tool(
            description="Get chunks related to a given chunk by a specific relationship (e.g., 'calls', 'contains')."
        )
        @observe(as_type='tool')
        async def get_related_chunks(
            chunk_id: Annotated[str, "The ID of the chunk to find related chunks for."],
            relation_type: Annotated[str, "The type of relationship to filter by (e.g., 'calls', 'contains')."] = "calls"
        ) -> dict:
            """Get chunks related to this chunk by a specific relationship (e.g., 'calls', 'contains')."""
            try:
                self._validate_node_exists(chunk_id)

                related = []
                for _, target, attrs in self.knowledge_graph.graph.out_edges(chunk_id, data=True):
                    if attrs.get('relation') == relation_type:
                        target_node = self.knowledge_graph.graph.nodes[target]['data']
                        related.append({
                            "id": target,
                            "file_path": getattr(target_node, 'path', 'Unknown'),
                            "entity_name": attrs.get('entity_name')
                        })

                if not related:
                    return {
                        "chunk_id": chunk_id,
                        "relation_type": relation_type,
                        "related_chunks": [],
                        "text": f"No chunks found with '{relation_type}' relationship from '{chunk_id}'"
                    }

                text = f"Chunks related to '{chunk_id}' via '{relation_type}' ({len(related)} total):\n\n"
                for chunk in related[:15]:
                    text += f"- {chunk['id']}\n"
                    text += f"  File: {chunk['file_path']}\n"
                    if chunk['entity_name']:
                        text += f"  Entity: {chunk['entity_name']}\n"

                if len(related) > 15:
                    text += f"\n... and {len(related) - 15} more\n"

                return {
                    "chunk_id": chunk_id,
                    "relation_type": relation_type,
                    "total_related": len(related),
                    "related_chunks": related[:15],
                    "has_more": len(related) > 15,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_related_chunks")
            except Exception as e:
                return self._handle_error(e, "get_related_chunks")

        @self.app.tool(
            description="List all entities tracked in the knowledge graph, including their types, declaration, and usage counts."
        )
        @observe(as_type='tool')
        async def list_all_entities(
            limit: Annotated[int, "Maximum number of entities to return."] = 50
        ) -> dict:
            """List all entities tracked in the knowledge graph with their metadata."""
            if not self.knowledge_graph.entities:
                return {
                    "entities": [],
                    "text": "No entities found in the knowledge graph."
                }

            entities = []
            for entity_name, info in list(self.knowledge_graph.entities.items())[:limit]:
                entities.append({
                    "name": entity_name,
                    "types": info.get('type', ['Unknown']),
                    "declaration_count": len(info.get('declaring_chunk_ids', [])),
                    "usage_count": len(info.get('calling_chunk_ids', []))
                })

            text = f"All Entities ({len(self.knowledge_graph.entities)} total):\n\n"
            for i, entity in enumerate(entities, 1):
                text += f"{i}. {entity['name']}\n"
                text += f"   Types: {', '.join(entity['types'])}\n"
                text += f"   Declarations: {entity['declaration_count']}\n"
                text += f"   Usages: {entity['usage_count']}\n\n"

            if len(self.knowledge_graph.entities) > limit:
                text += f"... and {len(self.knowledge_graph.entities) - limit} more entities\n"

            return {
                "total_entities": len(self.knowledge_graph.entities),
                "entities": entities,
                "has_more": len(self.knowledge_graph.entities) > limit,
                "text": text
            }

        # --- New Tools ---
        @self.app.tool(
            description="Show the diff between two code chunks or nodes by their IDs."
        )
        @observe(as_type='tool')
        async def diff_chunks(
            node_id_1: Annotated[str, "The ID of the first node/chunk."],
            node_id_2: Annotated[str, "The ID of the second node/chunk."]
        ) -> dict:
            try:
                import difflib
                self._validate_node_exists(node_id_1)
                self._validate_node_exists(node_id_2)

                g = self.knowledge_graph.graph
                content1 = getattr(g.nodes[node_id_1]['data'], 'content', None)
                content2 = getattr(g.nodes[node_id_2]['data'], 'content', None)

                if not content1 or not content2:
                    raise InvalidInputError("One or both nodes have no content.")

                diff = list(difflib.unified_diff(
                    content1.splitlines(), content2.splitlines(),
                    fromfile=node_id_1, tofile=node_id_2, lineterm=""
                ))

                diff_text = "\n".join(diff) if diff else "No differences."

                return {
                    "node_id_1": node_id_1,
                    "node_id_2": node_id_2,
                    "has_differences": bool(diff),
                    "diff": diff,
                    "text": diff_text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "diff_chunks")
            except Exception as e:
                return self._handle_error(e, "diff_chunks")

        @self.app.tool(
            description="Show a tree view of the repository or a subtree starting from a given node ID."
        )
        @observe(as_type='tool')
        async def print_tree(
            root_id: Annotated[Optional[str], "The node ID to start the tree from (default: repo root)."] = 'root',
            max_depth: Annotated[int, "Maximum depth to show."] = 3
        ) -> dict:
            try:
                g = self.knowledge_graph.graph

                def build_tree(node_id, depth, tree_data):
                    if depth > max_depth:
                        return
                    node = g.nodes[node_id]['data']
                    node_info = {
                        "id": node_id,
                        "name": getattr(node, 'name', node_id),
                        "type": getattr(node, 'node_type', '?'),
                        "depth": depth,
                        "children": []
                    }
                    tree_data.append(node_info)
                    children = [t for s, t in g.out_edges(node_id)]
                    for child in children:
                        build_tree(child, depth + 1, node_info["children"])

                def format_tree(tree_data):
                    result = ""
                    for node in tree_data:
                        result += " " * node["depth"] + f"- {node['name']} ({node['type']})\n"
                        for child in node["children"]:
                            result += format_subtree(child)
                    return result

                def format_subtree(node):
                    result = " " * node["depth"] + f"- {node['name']} ({node['type']})\n"
                    for child in node["children"]:
                        result += format_subtree(child)
                    return result

                if root_id is None:
                    roots = [n for n, d in g.nodes(data=True) if getattr(d['data'], 'node_type', None) in ('repo', 'directory', 'file')]
                    root_id = roots[0] if roots else list(g.nodes)[0]

                self._validate_node_exists(root_id)

                tree_data = []
                build_tree(root_id, 0, tree_data)

                return {
                    "root_id": root_id,
                    "max_depth": max_depth,
                    "tree": tree_data,
                    "text": format_tree(tree_data)
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "print_tree")
            except Exception as e:
                return self._handle_error(e, "print_tree")

        @self.app.tool(
            description="Show all relationships (calls, contains, etc.) for a given entity or node."
        )
        @observe(as_type='tool')
        async def entity_relationships(
            node_id: Annotated[str, "The node/entity ID to explore relationships for."]
        ) -> dict:
            try:
                self._validate_node_exists(node_id)
                g = self.knowledge_graph.graph

                incoming = []
                outgoing = []

                for source, target, data in g.in_edges(node_id, data=True):
                    incoming.append({
                        "source": source,
                        "target": target,
                        "relation": data.get('relation', '?')
                    })

                for source, target, data in g.out_edges(node_id, data=True):
                    outgoing.append({
                        "source": source,
                        "target": target,
                        "relation": data.get('relation', '?')
                    })

                text = f"Relationships for '{node_id}':\n"
                for rel in incoming:
                    text += f"← {rel['source']} [{rel['relation']}]\n"
                for rel in outgoing:
                    text += f"→ {rel['target']} [{rel['relation']}]\n"

                if not incoming and not outgoing:
                    text = "No relationships found."

                return {
                    "node_id": node_id,
                    "incoming": incoming,
                    "outgoing": outgoing,
                    "incoming_count": len(incoming),
                    "outgoing_count": len(outgoing),
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "entity_relationships")
            except Exception as e:
                return self._handle_error(e, "entity_relationships")

        @self.app.tool(
            description="Search for nodes/entities by type and name substring with fuzzy matching support. For entities, searches by entity_type (e.g., 'class', 'function', 'method'). For other nodes, searches by node_type (e.g., 'file', 'chunk', 'directory')."
        )
        @observe(as_type='tool')
        async def search_by_type_and_name(
            node_type: Annotated[str, "Type of node/entity (e.g., 'function', 'class', 'file', 'chunk', 'directory')."],
            name_query: Annotated[str, "Substring to match in the name (case-insensitive, supports partial matches)."],
            limit: Annotated[int, "Maximum results to return."] = 10,
            fuzzy: Annotated[bool, "Enable fuzzy/partial matching (default: True)."] = True
        ) -> dict:
            import re
            try:
                self._validate_positive_int(limit, "limit")

                g = self.knowledge_graph.graph
                matches = []
                query_lower = name_query.lower()

                # Build regex pattern for fuzzy matching
                if fuzzy:
                    fuzzy_pattern = '.*'.join(re.escape(c) for c in query_lower)
                    fuzzy_regex = re.compile(fuzzy_pattern, re.IGNORECASE)

                for nid, n in g.nodes(data=True):
                    node = n['data']
                    node_name = getattr(node, 'name', '')

                    if not node_name:
                        continue

                    # Check if name matches the query
                    name_matches = False
                    if fuzzy:
                        if query_lower in node_name.lower() or fuzzy_regex.search(node_name):
                            name_matches = True
                    else:
                        if query_lower in node_name.lower():
                            name_matches = True

                    if not name_matches:
                        continue

                    # Check type based on node_type
                    current_node_type = getattr(node, 'node_type', None)

                    # For entity nodes, check entity_type instead of node_type
                    if current_node_type == 'entity':
                        entity_type = getattr(node, 'entity_type', '')

                        # Fallback: if entity_type is empty, check the entities dictionary
                        if not entity_type and nid in self.knowledge_graph.entities:
                            entity_types = self.knowledge_graph.entities[nid].get('type', [])
                            entity_type = entity_types[0] if entity_types else ''

                        if entity_type and entity_type.lower() == node_type.lower():
                            score = 0 if query_lower == node_name.lower() else (1 if query_lower in node_name.lower() else 2)
                            matches.append({
                                "id": nid,
                                "name": node_name,
                                "type": f"entity ({entity_type})",
                                "content": getattr(node, 'content', None),
                                "score": score
                            })
                    # For other nodes, check node_type directly
                    elif current_node_type == node_type:
                        score = 0 if query_lower == node_name.lower() else (1 if query_lower in node_name.lower() else 2)
                        matches.append({
                            "id": nid,
                            "name": node_name,
                            "type": current_node_type,
                            "content": getattr(node, 'content', None),
                            "score": score
                        })

                # Sort by match score (best matches first) and limit results
                matches.sort(key=lambda x: (x['score'], x['name'].lower()))
                matches = matches[:limit]

                if not matches:
                    return {
                        "node_type": node_type,
                        "name_query": name_query,
                        "matches": [],
                        "text": f"No matches for type '{node_type}' and name containing '{name_query}'."
                    }

                text = f"Matches for type '{node_type}' and name '{name_query}' ({len(matches)} results):\n"
                for match in matches:
                    text += f"- {match['id']}: {match['name']} [{match['type']}]\n"

                return {
                    "node_type": node_type,
                    "name_query": name_query,
                    "count": len(matches),
                    "matches": matches,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "search_by_type_and_name")
            except Exception as e:
                return self._handle_error(e, "search_by_type_and_name")

        @self.app.tool(
            description="Get the full content of a code chunk along with its surrounding chunks (previous and next)."
        )
        @observe(as_type='tool')
        async def get_chunk_context(
            node_id: Annotated[str, "The node/chunk ID to get context for."]
        ) -> dict:
            from .utils.chunk_utils import organize_chunks_by_file_name, join_organized_chunks
            try:
                self._validate_node_exists(node_id)

                g = self.knowledge_graph.graph
                current_chunk = g.nodes[node_id]['data']
                previous_chunk = self.knowledge_graph.get_previous_chunk(node_id)
                next_chunk = self.knowledge_graph.get_next_chunk(node_id)

                # Collect all chunks (previous, current, next)
                chunks = []
                prev_info = None
                next_info = None
                current_info = {
                    "id": node_id,
                    "content": getattr(current_chunk, 'content', '')
                }

                if previous_chunk:
                    prev_info = {
                        "id": previous_chunk.id,
                        "content": previous_chunk.content
                    }
                    chunks.append(previous_chunk)

                chunks.append(current_chunk)

                if next_chunk:
                    next_info = {
                        "id": next_chunk.id,
                        "content": next_chunk.content
                    }
                    chunks.append(next_chunk)

                # Organize and join chunks
                organized = organize_chunks_by_file_name(chunks)
                full_content = join_organized_chunks(organized)

                return {
                    "node_id": node_id,
                    "current_chunk": current_info,
                    "previous_chunk": prev_info,
                    "next_chunk": next_info,
                    "text": full_content
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_chunk_context")
            except Exception as e:
                return self._handle_error(e, "get_chunk_context")

        @self.app.tool(
            description="Get statistics for a file or directory: number of entities, lines, chunks, etc."
        )
        @observe(as_type='tool')
        async def get_file_stats(
            path: Annotated[str, "The file or directory path to get statistics for."]
        ) -> dict:
            try:
                g = self.knowledge_graph.graph
                nodes = [n for n, d in g.nodes(data=True) if getattr(d['data'], 'path', None) == path]

                if not nodes:
                    raise NodeNotFoundError(f"No nodes found for path '{path}'.")

                stats = []
                text = f"Statistics for '{path}':\n"

                for node_id in nodes:
                    node = g.nodes[node_id]['data']
                    content = getattr(node, 'content', '')
                    declared = getattr(node, 'declared_entities', [])
                    called = getattr(node, 'called_entities', [])
                    chunks = [t for s, t in g.out_edges(node_id) if getattr(g.nodes[t]['data'], 'node_type', None) == 'chunk']

                    declared_list = []
                    for entity in declared[:10]:
                        if isinstance(entity, dict):
                            declared_list.append({
                                "name": entity.get('name', '?'),
                                "type": entity.get('type', '?')
                            })
                        else:
                            declared_list.append({"name": str(entity), "type": "Unknown"})

                    called_list = [str(entity) for entity in called[:10]]

                    node_stats = {
                        "node_id": node_id,
                        "node_type": getattr(node, 'node_type', '?'),
                        "lines": len(content.splitlines()) if content else 0,
                        "declared_entities_count": len(declared),
                        "declared_entities": declared_list,
                        "called_entities_count": len(called),
                        "called_entities": called_list,
                        "chunks_count": len(chunks),
                        "has_more_declared": len(declared) > 10,
                        "has_more_called": len(called) > 10
                    }
                    stats.append(node_stats)

                    text += f"- Node: {node_id} ({node_stats['node_type']})\n"
                    text += f"  Lines: {node_stats['lines']}\n"

                    if declared_list:
                        text += f"  Declared entities ({len(declared)}):\n"
                        for entity in declared_list:
                            text += f"    - {entity['name']} ({entity['type']})\n"
                        if len(declared) > 10:
                            text += f"    ... and {len(declared) - 10} more\n"
                    else:
                        text += f"  Declared entities: 0\n"

                    if called_list:
                        text += f"  Called entities ({len(called)}):\n"
                        for entity in called_list:
                            text += f"    - {entity}\n"
                        if len(called) > 10:
                            text += f"    ... and {len(called) - 10} more\n"
                    else:
                        text += f"  Called entities: 0\n"

                    text += f"  Chunks: {len(chunks)}\n"

                return {
                    "path": path,
                    "nodes": stats,
                    "text": text
                }
            except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
                return self._handle_error(e, "get_file_stats")
            except Exception as e:
                return self._handle_error(e, "get_file_stats")
        # --- End New Tools ---

        @self.app.tool(
            description="Search for file names in the repository using a regular expression pattern."
        )
        @observe(as_type='tool')
        async def search_file_names_by_regex(
            pattern: Annotated[str, "The regular expression pattern to match file names."]
        ) -> dict:
            """Search for file names matching a regex pattern."""
            import re
            g = self.knowledge_graph.graph

            try:
                regex = re.compile(pattern)
            except re.error as e:
                return {"error": f"Invalid regex pattern: {str(e)}"}

            matches = []
            for node_id, node_attrs in g.nodes(data=True):
                node = node_attrs['data']
                if getattr(node, 'node_type', None) == 'file':
                    file_name = getattr(node, 'name', '') or getattr(node, 'path', '')
                    if regex.search(file_name):
                        matches.append({
                            "node_id": node_id,
                            "file_name": file_name
                        })

            if not matches:
                return {
                    "pattern": pattern,
                    "matches": [],
                    "text": f"No file names matched the pattern: '{pattern}'"
                }

            text = f"Files matching pattern '{pattern}':\n"
            for match in matches[:20]:
                text += f"- {match['file_name']} (node ID: {match['node_id']})\n"

            if len(matches) > 20:
                text += f"... and {len(matches) - 20} more\n"

            return {
                "pattern": pattern,
                "count": len(matches),
                "matches": matches[:20],
                "has_more": len(matches) > 20,
                "text": text
            }

        @self.app.tool(
            description="Find the shortest path between two nodes in the knowledge graph."
        )
        @observe(as_type='tool')
        async def find_path(
            source_id: Annotated[str, "The ID of the source node."],
            target_id: Annotated[str, "The ID of the target node."],
            max_depth: Annotated[int, "Maximum depth to search for a path."] = 5
        ) -> dict:
            """Find shortest path between two nodes."""
            return self.knowledge_graph.find_path(source_id, target_id, max_depth)

        @self.app.tool(
            description="Extract a subgraph around a node up to a specified depth, optionally filtering by edge types."
        )
        @observe(as_type='tool')
        async def get_subgraph(
            node_id: Annotated[str, "The ID of the central node."],
            depth: Annotated[int, "The depth/radius of the subgraph to extract."] = 2,
            edge_types: Annotated[Optional[list], "Optional list of edge types to include (e.g., ['calls', 'contains'])."] = None
        ) -> dict:
            """Extract a subgraph around a node."""
            return self.knowledge_graph.get_subgraph(node_id, depth, edge_types)

    def run(self, **kwargs):
        self.app.run(**kwargs)
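For orientation, a minimal usage sketch of the class above (not part of the uploaded files): it assumes a knowledge graph was previously serialized to a local `knowledge_graph.json`, and that FastMCP's `run()` accepts a `transport` keyword (the `from_file` constructor and the `run(**kwargs)` pass-through are taken from the listing; the transport value is an assumption).

# Hypothetical usage sketch, assuming ./knowledge_graph.json already exists.
from RepoKnowledgeGraphLib.KnowledgeGraphMCPServer import KnowledgeGraphMCPServer

server = KnowledgeGraphMCPServer.from_file(
    "knowledge_graph.json",   # serialized graph; loaded via RepoKnowledgeGraph.load_graph_from_file
    index_nodes=True,         # rebuild the code index so search_nodes has something to query
    server_name="repo-kg-server",
)
server.run(transport="stdio")  # run() forwards kwargs to FastMCP; "stdio" is an assumed transport name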
RepoKnowledgeGraphLib/ModelService.py
ADDED
@@ -0,0 +1,424 @@
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from openai import OpenAI, AsyncOpenAI
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from tenacity import retry, stop_after_attempt, wait_fixed
|
| 7 |
+
import httpx
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
# Optional torch import for CUDA detection
|
| 11 |
+
try:
|
| 12 |
+
import torch
|
| 13 |
+
_TORCH_AVAILABLE = True
|
| 14 |
+
except Exception:
|
| 15 |
+
torch = None
|
| 16 |
+
_TORCH_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
from .utils.logger_utils import setup_logger
|
| 19 |
+
|
| 20 |
+
LOGGER_NAME = "MODEL_SERVICE_LOGGER"
|
| 21 |
+
# GENERATION ENV VARIABLES (defaults)
|
| 22 |
+
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", 'http://0.0.0.0:8000/v1')
|
| 23 |
+
OPENAI_TOKEN = os.getenv("OPENAI_TOKEN", 'no-need')
|
| 24 |
+
MODEL_NAME = os.getenv('MODEL_NAME', "meta-llama/Llama-3.2-3B-Instruct")
|
| 25 |
+
# EMBED ENV VARIABLES (defaults)
|
| 26 |
+
OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", 'http://0.0.0.0:8001/v1')
|
| 27 |
+
OPENAI_EMBED_TOKEN = os.getenv("OPENAI_EMBED_TOKEN", 'no-need')
|
| 28 |
+
EMBED_MODEL_NAME = os.getenv('EMBED_MODEL_NAME', "Alibaba-NLP/gte-Qwen2-1.5B-instruct")
|
| 29 |
+
|
| 30 |
+
# Additional ENV defaults requested
|
| 31 |
+
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 2048))
|
| 32 |
+
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2))
|
| 33 |
+
TOP_P = float(os.getenv("TOP_P", 0.95))
|
| 34 |
+
FREQUENCY_PENALTY = float(os.getenv("FREQUENCY_PENALTY", 0))
|
| 35 |
+
PRESENCE_PENALTY = float(os.getenv("PRESENCE_PENALTY", 0))
|
| 36 |
+
EMBEDDING_MODEL_URL = os.getenv("EMBEDDING_MODEL_URL", "")
|
| 37 |
+
EMBEDDING_MODEL_API_KEY = os.getenv("EMBEDDING_MODEL_API_KEY", "no_need")
|
| 38 |
+
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv("EMBEDDING_NUMBER_DIMENSIONS", 1024))
|
| 39 |
+
|
| 40 |
+
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
|
| 41 |
+
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
|
| 42 |
+
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", 240))
|
| 43 |
+
|
| 44 |
+
# Note: module-level clients remain for backward compatibility but instances will create their own if timeout is overridden.
|
| 45 |
+
long_timeout_client = httpx.Client(timeout=REQUEST_TIMEOUT)
|
| 46 |
+
long_timeout_async_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ModelServiceInterface(ABC):
|
| 50 |
+
"""
|
| 51 |
+
Abstract base class defining the interface for model services.
|
| 52 |
+
All model services should implement these methods.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
# accept model_kwargs so variables can be overridden at runtime
|
| 56 |
+
def __init__(self, model_name: str = None, model_kwargs: dict = None):
|
| 57 |
+
setup_logger(LOGGER_NAME)
|
| 58 |
+
self.logger = logging.getLogger(LOGGER_NAME)
|
| 59 |
+
|
| 60 |
+
model_kwargs = model_kwargs or {}
|
| 61 |
+
|
| 62 |
+
# allow overriding via model_kwargs; fall back to module-level defaults
|
| 63 |
+
self.openai_base_url = model_kwargs.get("OPENAI_BASE_URL", OPENAI_BASE_URL)
|
| 64 |
+
self.openai_token = model_kwargs.get("OPENAI_TOKEN", OPENAI_TOKEN)
|
| 65 |
+
# model_name param takes precedence, then model_kwargs then default env
|
| 66 |
+
self.model_name = model_name or model_kwargs.get("MODEL_NAME", MODEL_NAME)
|
| 67 |
+
|
| 68 |
+
# embed defaults (may be overridden by subclasses or model_kwargs)
|
| 69 |
+
self.openai_embed_base_url = model_kwargs.get("OPENAI_EMBED_BASE_URL", OPENAI_EMBED_BASE_URL)
|
| 70 |
+
self.openai_embed_token = model_kwargs.get("OPENAI_EMBED_TOKEN", OPENAI_EMBED_TOKEN)
|
| 71 |
+
self.embed_model_name = model_kwargs.get("EMBED_MODEL_NAME", EMBED_MODEL_NAME)
|
| 72 |
+
|
| 73 |
+
# other configurable parameters
|
| 74 |
+
self.max_tokens = int(model_kwargs.get("MAX_TOKENS", MAX_TOKENS))
|
| 75 |
+
self.temperature = float(model_kwargs.get("TEMPERATURE", TEMPERATURE))
|
| 76 |
+
self.top_p = float(model_kwargs.get("TOP_P", TOP_P))
|
| 77 |
+
self.frequency_penalty = float(model_kwargs.get("FREQUENCY_PENALTY", FREQUENCY_PENALTY))
|
| 78 |
+
self.presence_penalty = float(model_kwargs.get("PRESENCE_PENALTY", PRESENCE_PENALTY))
|
| 79 |
+
self.embedding_model_url = model_kwargs.get("EMBEDDING_MODEL_URL", EMBEDDING_MODEL_URL)
|
| 80 |
+
self.embedding_model_api_key = model_kwargs.get("EMBEDDING_MODEL_API_KEY", EMBEDDING_MODEL_API_KEY)
|
| 81 |
+
self.embedding_number_dimensions = int(model_kwargs.get("EMBEDDING_NUMBER_DIMENSIONS", EMBEDDING_NUMBER_DIMENSIONS))
|
| 82 |
+
|
| 83 |
+
self.stop_after_attempt = int(model_kwargs.get("STOP_AFTER_ATTEMPT", STOP_AFTER_ATTEMPT))
|
| 84 |
+
self.wait_between_retries = int(model_kwargs.get("WAIT_BETWEEN_RETRIES", WAIT_BETWEEN_RETRIES))
|
| 85 |
+
request_timeout = int(model_kwargs.get("REQUEST_TIMEOUT", REQUEST_TIMEOUT))
|
| 86 |
+
|
| 87 |
+
# create per-instance httpx clients in case REQUEST_TIMEOUT was overridden
|
| 88 |
+
self.long_timeout_client = httpx.Client(timeout=request_timeout)
|
| 89 |
+
self.long_timeout_async_client = httpx.AsyncClient(timeout=request_timeout)
|
| 90 |
+
|
| 91 |
+
# Initialize query client (shared by all implementations)
|
| 92 |
+
self.client = OpenAI(
|
| 93 |
+
base_url=self.openai_base_url,
|
| 94 |
+
api_key=self.openai_token,
|
| 95 |
+
http_client=self.long_timeout_client,
|
| 96 |
+
)
|
| 97 |
+
self.async_client = AsyncOpenAI(
|
| 98 |
+
base_url=self.openai_base_url,
|
| 99 |
+
api_key=self.openai_token,
|
| 100 |
+
http_client=self.long_timeout_async_client,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def query(self, prompt: str, model_name: str = None) -> str:
        """Query the model with a prompt."""
        if model_name is None:
            model_name = self.model_name
        completion = self.client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def query_with_instructions(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """Query the model with additional system instructions."""
        if model_name is None:
            model_name = self.model_name
        completion = self.client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def query_async(self, prompt: str, model_name: str = None) -> str:
        """Async version of query."""
        if model_name is None:
            model_name = self.model_name
        completion = await self.async_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def query_with_instructions_async(self, prompt: str, instructions: str, model_name: str = None) -> str:
        """Async version of query with instructions."""
        if model_name is None:
            model_name = self.model_name
        completion = await self.async_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": instructions},
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    @abstractmethod
    def embed(self, text_to_embed: str) -> list:
        """Embed text using the configured embedding model."""
        pass

    @abstractmethod
    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed."""
        pass

    @abstractmethod
    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk for storage/indexing."""
        pass

    @abstractmethod
    def embed_query(self, query_to_embed: str) -> list:
        """Embed query for retrieval."""
        pass

    @abstractmethod
    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch for better performance."""
        pass

    @abstractmethod
    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch for storage/indexing."""
        pass


class OpenAIModelService(ModelServiceInterface):
    """
    Model service that uses OpenAI client for both queries and embeddings.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None):
        # forward model_kwargs to base so it can set instance-wide config
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)

        # allow override of embed model name via param or model_kwargs
        model_kwargs = model_kwargs or {}
        self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)

        # embed client should use the instance-level embed base/token
        self.embed_client = OpenAI(
            base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
            api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
            http_client=self.long_timeout_client,
        )
        self.async_embed_client = AsyncOpenAI(
            base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
            api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
            http_client=self.long_timeout_async_client,
        )

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def embed(self, text_to_embed: str) -> list:
        """Embed text using OpenAI embeddings API."""
        response = self.embed_client.embeddings.create(
            input=text_to_embed,
            model=self.embed_model_name,
        )
        return response.data[0].embedding

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    async def embed_async(self, text_to_embed: str) -> list:
        """Async version of embed using OpenAI embeddings API."""
        response = await self.async_embed_client.embeddings.create(
            input=text_to_embed,
            model=self.embed_model_name,
        )
        return response.data[0].embedding

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk using OpenAI embeddings API (same as embed)."""
        return self.embed(code_to_embed)

    def embed_query(self, query_to_embed: str) -> list:
        """Embed query using OpenAI embeddings API (same as embed)."""
        return self.embed(query_to_embed)

    @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch using OpenAI embeddings API."""
        if not texts_to_embed:
            return []
        response = self.embed_client.embeddings.create(
            input=texts_to_embed,
            model=self.embed_model_name,
        )
        return [item.embedding for item in response.data]

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch using OpenAI embeddings API."""
        return self.embed_batch(codes_to_embed)


class SentenceTransformersModelService(ModelServiceInterface):
    """
    Model service that uses OpenAI client for queries and SentenceTransformers for embeddings.
    Optimized for high-throughput batch embedding with GPU support.
    """

    def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None, skip_embedder: bool = False):
        super().__init__(model_name=model_name, model_kwargs=model_kwargs)
        model_kwargs = model_kwargs or {}
        # embed_model_name may be overridden by model_kwargs
        self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)
        self.skip_embedder = skip_embedder
        self.embedding_model = None

        if skip_embedder:
            self.logger.info('Skipping embedder initialization (keyword-only mode)')
            self.device = "cpu"
            self.encode_batch_size = 32
            return

        # Debug GPU detection
        self.logger.info(f'PyTorch available: {_TORCH_AVAILABLE}')
        if _TORCH_AVAILABLE:
            self.logger.info(f'CUDA available: {torch.cuda.is_available()}')
            self.logger.info(f'CUDA device count: {torch.cuda.device_count()}')
            if torch.cuda.is_available():
                self.logger.info(f'CUDA device name: {torch.cuda.get_device_name(0)}')

        # Select device: prefer CUDA if available
        self.device = "cuda" if (_TORCH_AVAILABLE and torch.cuda.is_available()) else "cpu"
        self.logger.info(f'Initializing SentenceTransformer on device: {self.device}')

        # Set batch size based on device and available memory
        # Larger batch sizes significantly improve GPU throughput
        self.encode_batch_size = int(model_kwargs.get("ENCODE_BATCH_SIZE", 64 if self.device == "cuda" else 32))

        # Show CUDA memory info if available
        if self.device == "cuda" and _TORCH_AVAILABLE:
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                self.logger.info(f'GPU memory available: {gpu_memory:.2f} GB')
                # Adjust batch size based on available GPU memory
                if gpu_memory > 16:
                    self.encode_batch_size = max(self.encode_batch_size, 128)
                elif gpu_memory > 8:
                    self.encode_batch_size = max(self.encode_batch_size, 64)
            except Exception as e:
                self.logger.warning(f'Could not get GPU memory info: {e}')

        self.logger.info(f'Using encode batch size: {self.encode_batch_size}')

        # Initialize embedding model on the chosen device with performance optimizations
        self.embedding_model = SentenceTransformer(
            self.embed_model_name,
            trust_remote_code=True,
            device=self.device
        )

        # Enable half precision for faster inference on CUDA
        if self.device == "cuda" and _TORCH_AVAILABLE:
            try:
                # Check if model supports half precision
                self.embedding_model.half()
                self.logger.info('Enabled half precision (FP16) for faster GPU inference')
            except Exception as e:
                self.logger.warning(f'Could not enable half precision: {e}')

    def _check_embedder(self):
        """Check if embedder is available, raise error if not."""
        if self.skip_embedder or self.embedding_model is None:
            raise RuntimeError(
                "Embedding model not initialized. This model service was created with skip_embedder=True "
                "(keyword-only mode). To use embeddings, set index_type to 'hybrid' or 'embedding-only'."
            )

    def embed(self, text_to_embed: str) -> list:
        """Embed text using SentenceTransformers."""
        self._check_embedder()
        embeddings = self.embedding_model.encode(
            [text_to_embed],
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    async def embed_async(self, text_to_embed: str) -> list:
        """
        Async version of embed using SentenceTransformers.
        Note: SentenceTransformers doesn't have native async support,
        so this runs synchronously but maintains the async interface.
        """
        return self.embed(text_to_embed)

    def embed_chunk_code(self, code_to_embed: str) -> list:
        """Embed code chunk using SentenceTransformers (no special prompt)."""
        self._check_embedder()
        self.logger.debug(f'Embedding code using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            [code_to_embed],
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    def embed_query(self, query_to_embed: str) -> list:
        """Embed query using SentenceTransformers with retrieval prompt."""
        self._check_embedder()
        self.logger.debug(f'Embedding query using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            [query_to_embed],
            prompt='Given this prompt, retrieve relevant content\n Query:',
            convert_to_numpy=True,
            show_progress_bar=False
        )
        return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])

    def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
        """Embed multiple texts in a batch using SentenceTransformers with optimized settings."""
        if not texts_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(texts_to_embed)} texts using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            texts_to_embed,
            batch_size=self.encode_batch_size,
            convert_to_numpy=True,
            show_progress_bar=len(texts_to_embed) > 100,  # Only show progress for large batches
            normalize_embeddings=True  # Normalize for better similarity computation
        )
        return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]

    def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
        """Embed multiple code chunks in a batch using SentenceTransformers with optimized settings."""
        if not codes_to_embed:
            return []
        self._check_embedder()
        self.logger.info(f'Batch embedding {len(codes_to_embed)} code chunks using {self.embed_model_name}')
        embeddings = self.embedding_model.encode(
            codes_to_embed,
            batch_size=self.encode_batch_size,
            convert_to_numpy=True,
            show_progress_bar=len(codes_to_embed) > 100,  # Only show progress for large batches
            normalize_embeddings=True  # Normalize for better similarity computation
        )
        return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]


def create_model_service(skip_embedder: bool = False, **kwargs) -> ModelServiceInterface:
    """
    Factory function to create the appropriate ModelService based on embedder_type.

    Args:
        skip_embedder (bool): If True, skip loading the embedding model (for keyword-only search).
        **kwargs: Additional arguments including 'embedder_type' ('openai' or 'sentence-transformers')
            and an optional 'model_kwargs' dict which can override any env var defaults.
    Returns:
        ModelServiceInterface: An instance of the appropriate ModelService.
    """
    model_kwargs = kwargs.pop('model_kwargs', None)
    embedder_type = kwargs.pop('embedder_type', 'openai')

    if embedder_type == 'openai':
        return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
    elif embedder_type == 'sentence-transformers':
        return SentenceTransformersModelService(model_kwargs=model_kwargs, skip_embedder=skip_embedder, **kwargs)
    else:
        logging.getLogger(LOGGER_NAME).warning(
            f'Unknown embedder type: {embedder_type}, defaulting to OpenAI'
        )
        return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
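
A minimal usage sketch of the factory above, assuming the OPENAI_* and EMBED_* environment defaults it reads are configured; the override values and prompt are illustrative:

from RepoKnowledgeGraphLib.ModelService import create_model_service

# model_kwargs overrides any env-var default at construction time.
service = create_model_service(
    embedder_type='sentence-transformers',
    model_kwargs={'REQUEST_TIMEOUT': 120, 'ENCODE_BATCH_SIZE': 64},
)
answer = service.query('Summarize what this repository does in one sentence.')
vectors = service.embed_batch(['def add(a, b): return a + b'])

Note that an unknown embedder_type does not raise: the factory logs a warning and falls back to OpenAIModelService, so a misspelled config value silently changes the embedding backend.
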
RepoKnowledgeGraphLib/Node.py
ADDED
@@ -0,0 +1,63 @@
from typing import Optional, Dict, List
from dataclasses import dataclass, field, asdict

from .Entity import Entity

@dataclass
class Node:
    id: str = ''
    name: str = ''
    node_type: str = ''
    description: Optional[str] = None
    declared_entities: List[dict] = field(default_factory=list)  # Classes, functions, variables
    called_entities: List[str] = field(default_factory=list)  # Classes, functions, variables, but also external libraries

    def dict(self):
        return {k: str(v) for k, v in asdict(self).items()}

@dataclass
class DirectoryNode(Node):
    path: str = ''
    node_type: str = 'directory'


@dataclass
class FileNode(Node):
    path: str = ''
    content: str = ''
    node_type: str = 'file'
    language: str = ''


@dataclass
class ChunkNode(FileNode):
    node_type: str = 'chunk'
    order_in_file: int = field(default_factory=int)
    embedding: list = None

    def get_field_to_embed(self) -> Optional[str]:
        # Use description if available, otherwise fall back to content
        # This ensures we always have something meaningful to embed
        if self.description and self.description.strip():
            return self.description
        return self.content


@dataclass
class EntityNode(Node):
    entity_type: str = ''
    declaring_chunk_ids: List[str] = field(default_factory=list)
    calling_chunk_ids: List[str] = field(default_factory=list)
    aliases: List[str] = field(default_factory=list)  # All possible aliases for this entity
    node_type: str = 'entity'

    def __post_init__(self):
        # Use entity_name (stored in the name field) as the id if id is not set
        if not self.id and self.name:
            self.id = self.name

    def dict(self):
        return {k: str(v) for k, v in asdict(self).items()}

    def get_field_to_embed(self) -> Optional[str]:
        return self.name
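
A short sketch of how these dataclasses behave; the ids, paths, and content strings below are made-up values for illustration:

from RepoKnowledgeGraphLib.Node import ChunkNode, EntityNode

chunk = ChunkNode(id='src/app.py#0', name='app.py', path='src/app.py',
                  content="def main():\n    print('hi')", order_in_file=0)
print(chunk.get_field_to_embed())  # no description set, so this returns the raw content
chunk.description = 'Entry point that prints a greeting.'
print(chunk.get_field_to_embed())  # now returns the description instead

entity = EntityNode(name='main', entity_type='function')
print(entity.id)  # __post_init__ copied the name into the empty id: 'main'
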
RepoKnowledgeGraphLib/QuestionMaker.py
ADDED
@@ -0,0 +1,538 @@
import logging
import asyncio
from tqdm import tqdm

from .RepoKnowledgeGraph import RepoKnowledgeGraph
from .ModelService import create_model_service
from .utils.logger_utils import setup_logger
from .utils.chunk_utils import organize_chunks_by_file_name, join_organized_chunks, extract_filename_from_chunk
from .Node import ChunkNode

LOGGER_NAME = "QUESTION_MAKER_LOGGER"

class QuestionMaker:
    """
    The QuestionMaker class is responsible for generating code comprehension questions and answers
    based on code chunks and knowledge graphs. It leverages a language model service to formulate
    questions and answers that test deep understanding of code, focusing on mechanisms, design decisions,
    and subtle behaviors. It supports generating questions for neighboring code chunks as well as for
    specific entities (e.g., functions, classes) that are both declared and called in the codebase.
    """
    def __init__(self):
        """
        Initializes the QuestionMaker, sets up logging, and instantiates the model service.
        """
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)
        # ModelService exposes a factory rather than a concrete ModelService class
        self.model_service = create_model_service()

    def generate_questions_answers(self, candidate_chunks: dict) -> list:
        """
        Placeholder for generating questions and answers from candidate chunks.
        Args:
            candidate_chunks (dict): Dictionary mapping chunk groups to process.
        Returns:
            list: List of question-answer pairs.
        """
        pass

    def test_chunk_sensibility(self, knowledge_graph: RepoKnowledgeGraph) -> list:
        """
        Placeholder for testing the sensibility of code chunks in the knowledge graph.
        Args:
            knowledge_graph (RepoKnowledgeGraph): The knowledge graph to test.
        Returns:
            list: List of results or metrics.
        """
        pass


    async def make_n_neighbouring_chunk_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
        """
        Generates questions and answers for all possible groups of n directly neighboring code chunks
        in each file of the knowledge graph. This helps assess understanding of code that spans multiple
        adjacent chunks, such as related functions or code blocks.

        Args:
            knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
        Returns:
            list: A list of dictionaries, each containing a question, answer, the involved chunks, and category.
        """
        file_nodes = knowledge_graph.get_all_files()
        # collect candidate chunk groups as a list
        candidate_chunks = []
        for file_node in file_nodes:
            self.logger.info(f"Processing file node: {file_node}")
            chunks = knowledge_graph.get_chunks_of_file(file_node.id)
            num_chunks = len(chunks)
            # For each n, collect all n-sized tuples of directly neighbouring chunks
            for n in range(2, num_chunks + 1):
                for i in range(num_chunks - n + 1):
                    # Only directly neighbouring chunks
                    candidate_chunks.append(list(chunks[i:i+n]))
        # generate questions and answers from candidate chunks in parallel, in batches of 15

        async def process_chunk_group(chunks):
            """
            Helper coroutine to generate a question and answer for a specific group of neighboring chunks.
            Args:
                chunks (list): The list of code chunks to generate the question and answer from.
            Returns:
                dict: Contains question, answer, chunks, and category.
            """
            question = await self._generate_neighboring_question_from_chunks_async(chunks)
            answer = await self.answer_question_about_chunks_async(chunks, question)
            return {
                'question': question,
                'clean_question': question,
                'answer': answer,
                'chunks': [chunk.dict() for chunk in chunks],
                'category': 'neighbouring_chunks'
            }
        # Batch processing in groups of 15 with tqdm
        batch_size = 15
        results = []
        total = len(candidate_chunks)
        for i in tqdm(range(0, total, batch_size), desc="Generating neighbouring chunk questions", unit="batch"):
            batch = candidate_chunks[i:i+batch_size]
            tasks = [process_chunk_group(chunks) for chunks in batch]
            batch_results = []
            for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
                batch_results.append(await coro)
            results.extend(batch_results)
        return results

    async def make_entity_declaration_call_specific_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
        """
        Generates questions and answers about specific entities (e.g., functions, classes) that have both
        a declaration and at least one call site in the knowledge graph. Focuses on cross-file references
        by default.

        Args:
            knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
        Returns:
            list: A list of dictionaries, each containing a question, answer, entity, involved chunks, and category.
        """
        self.logger.info("Generating entity-specific questions.")
        candidate_pairs = self.get_entities_with_declaration_and_calling(knowledge_graph)

        async def process_entity_pair(pair):
            """
            Helper coroutine to generate a question and answer for a specific entity's declaration and call site.
            Args:
                pair (dict): Contains entity name, declaring_chunk_id, and calling_chunk_id.
            Returns:
                dict: Contains question, answer, entity, chunks, and category.
            """
            entity_name = pair['entity']
            chunks = [knowledge_graph[pair['declaring_chunk_id']], knowledge_graph[pair['calling_chunk_id']]]
            question = await self.make_entity_specific_question_async(chunks, entity_name)
            answer = await self.answer_question_about_chunks_async(chunks, question)
            return {
                'question': question,
                'clean_question': question,
                'answer': answer,
                'entity': entity_name,
                'chunks': [chunk.dict() for chunk in chunks],
                'category': 'entity_declaration_call_specific'
            }

        # Batch processing with tqdm
        batch_size = 15
        results = []
        total = len(candidate_pairs)
        for i in tqdm(range(0, total, batch_size), desc="Generating entity-specific questions", unit="batch"):
            batch = candidate_pairs[i:i+batch_size]
            tasks = [process_entity_pair(pair) for pair in batch]
            batch_results = []
            for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
                batch_results.append(await coro)
            results.extend(batch_results)
        return results

    async def make_interacting_entities_specific_questions_async(self, entity_A: str, entity_B: str,
                                                                 decl_chunk_A: ChunkNode, decl_chunk_B: ChunkNode,
                                                                 call_chunk: ChunkNode) -> str:
        """
        Generates a question about two entities that interact in the same chunk.
        Each entity has a declaration and at least one call site, and the question focuses on their interaction.

        Args:
            entity_A (str): Name of the first entity.
            entity_B (str): Name of the second entity.
            decl_chunk_A (ChunkNode): Chunk containing the declaration of entity A.
            decl_chunk_B (ChunkNode): Chunk containing the declaration of entity B.
            call_chunk (ChunkNode): Chunk where both entities interact.
        Returns:
            str: The generated question as plain text.
        """
        entity_A_definition_code = decl_chunk_A.content
        entity_B_definition_code = decl_chunk_B.content
        entity_interaction_code = call_chunk.content

        prompt = f"""You are given two code entities, {entity_A} and {entity_B}, along with a snippet where they interact.
Your task is to write **one clear and concise question** about their relationship.

### Input:
* {entity_A} Definition Code:
{entity_A_definition_code}

* {entity_B} Definition Code:
{entity_B_definition_code}

* Interaction Code (where they interact):
{entity_interaction_code}

### Guidelines:
* Ask about design, abstraction, dependencies, or side effects.
* The question should highlight something a developer might consider when reviewing or improving the code.
* Keep the question short and direct so it can be answered briefly.
* Do not explain the code or provide answers.

### Output:
**Question**: <your question here>
"""

        initial_question = await self.model_service.query_async(prompt=prompt)
        return await self.extract_question_from_generated_text_async(generated_text=initial_question)

    def get_all_candidate_pairs_and_triplets(self, knowledge_graph: RepoKnowledgeGraph) -> tuple:
        """
        Collects all interacting-entity triplets and declaration/call pairs from the knowledge graph,
        packaged with their chunks and category labels.

        Returns:
            tuple: (candidate_pairs, candidate_triplets).
        """
        candidate_triplets = []
        candidate_pairs = []

        interacting_entity_triplets = self.get_interacting_entity_triplets(knowledge_graph)
        for triplet in interacting_entity_triplets:
            chunks = [
                knowledge_graph[triplet['decl_chunk_A']],
                knowledge_graph[triplet['decl_chunk_B']],
                knowledge_graph[triplet['call_chunk']]
            ]
            candidate_triplets.append({
                'entities': (triplet['entity_A'], triplet['entity_B']),
                'chunks': [chunk.dict() for chunk in chunks],
                'category': 'interacting_entities'
            })

        declaration_calling_pairs = self.get_entities_with_declaration_and_calling(knowledge_graph)
        for pair in declaration_calling_pairs:
            chunks = [knowledge_graph[pair['declaring_chunk_id']], knowledge_graph[pair['calling_chunk_id']]]
            candidate_pairs.append({
                'entity': pair['entity'],
                'chunks': [chunk.dict() for chunk in chunks],
                'category': 'entity_declaration_call_specific'
            })

        return candidate_pairs, candidate_triplets

    async def make_interacting_entity_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
        """
        Generates questions and answers about pairs of entities that interact in the same chunk.
        Each entity has a declaration and at least one call site, and the question focuses on their interaction.

        Args:
            knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
        Returns:
            list: A list of dictionaries, each containing a question, answer, entities, involved chunks, and category.
        """
        self.logger.info("Generating interacting entity questions.")
        triplets = self.get_interacting_entity_triplets(knowledge_graph)

        async def process_triplet(triplet):
            """
            Helper coroutine to generate a question and answer for a specific interacting entity triplet.
            Args:
                triplet (dict): Contains entity_A, entity_B, decl_chunk_A, decl_chunk_B, and call_chunk.
            Returns:
                dict: Contains question, answer, entities, chunks, and category.
            """
            chunks = [
                knowledge_graph[triplet['decl_chunk_A']],
                knowledge_graph[triplet['decl_chunk_B']],
                knowledge_graph[triplet['call_chunk']]
            ]
            question = await self.make_interacting_entities_specific_questions_async(
                entity_A=triplet['entity_A'],
                entity_B=triplet['entity_B'],
                decl_chunk_A=knowledge_graph[triplet['decl_chunk_A']],
                decl_chunk_B=knowledge_graph[triplet['decl_chunk_B']],
                call_chunk=knowledge_graph[triplet['call_chunk']])
            answer = await self.answer_question_about_chunks_async(chunks, question)
            return {
                'question': question,
                'clean_question': question,
                'answer': answer,
                'entities': (triplet['entity_A'], triplet['entity_B']),
                'chunks': [chunk.dict() for chunk in chunks],
                'category': 'interacting_entities'
            }

        # Batch processing with tqdm
        batch_size = 15
        results = []
        total = len(triplets)
        for i in tqdm(range(0, total, batch_size), desc="Generating interacting entity questions", unit="batch"):
            batch = triplets[i:i+batch_size]
            tasks = [process_triplet(triplet) for triplet in batch]
            batch_results = []
            for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
                batch_results.append(await coro)
            results.extend(batch_results)
        return results

    async def _generate_neighboring_question_from_chunks_async(self, chunks: list) -> str:
        """
        Generates a single code comprehension question for a group of code chunks using the model service.
        The question is designed to probe deep understanding of the code's mechanisms, design, or pitfalls.

        Args:
            chunks (list): The list of code chunks to generate the question from.
        Returns:
            str: The generated question as plain text.
        """
        organized_chunks = organize_chunks_by_file_name(chunks)
        joined_chunks = join_organized_chunks(organized_chunks)

        system_prompt = """You are an expert in evaluating code comprehension. The user will provide, in the next message, the content of a code submission (in any programming language). Your goal is to analyze this code, identify its critical, subtle, or obscure aspects, and generate **one relevant question in English** to ask someone in order to assess their understanding of the code.

This question should focus on:

* essential mechanisms of how the code works,
* important design decisions,
* potential pitfalls or unexpected behaviors,
* or any aspect that requires deep comprehension.

The goal is to test whether the person has **truly understood** the code, not just skimmed through it.

Respond with **only one question**, in plain text. Do not include any explanation, comment, or wrapper (e.g., no dictionaries, no lists).
"""
        initial_question = await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks)
        return await self.extract_question_from_generated_text_async(generated_text=initial_question)

    async def answer_question_about_chunks_async(self, chunks: list, question: str) -> str:
        """
        Generates an answer to a code comprehension question about a group of code chunks using the model service.
        The answer should demonstrate deep understanding and cover mechanisms, design, and pitfalls.

        Args:
            chunks (list): The list of code chunks to answer the question about.
            question (str): The question to answer.
        Returns:
            str: The generated answer as plain text.
        """
        organized_chunks = organize_chunks_by_file_name(chunks)
        joined_chunks = join_organized_chunks(organized_chunks)

        system_prompt = """You are an expert in evaluating code comprehension. The user will provide, in the next message, the content of a code submission (in any programming language) and a question about it. Your goal is to analyze this code, identify its critical, subtle, or obscure aspects, and generate **one relevant answer in English** to the question.

This answer should focus on:

* essential mechanisms of how the code works,
* important design decisions,
* potential pitfalls or unexpected behaviors,
* or any aspect that requires deep comprehension.
The goal is to provide a clear and thorough answer that demonstrates a deep understanding of the code.
"""

        return await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks + "\n\n" + question)

    async def make_entity_specific_question_async(self, chunks: list, entity_name: str):
        """
        Generates a question about a specific entity (e.g., function, class) in the context of the provided code chunks.
        The question is designed to probe understanding of the entity's purpose, behavior, and interactions.

        Args:
            chunks (list): The list of code chunks to generate the question from.
            entity_name (str): The name of the entity to focus on.
        Returns:
            str: The generated question as plain text.
        """
        organized_chunks = organize_chunks_by_file_name(chunks)
        joined_chunks = join_organized_chunks(organized_chunks)

        system_prompt = f"""You will be given one or more code snippets, possibly from multiple files.

A specific entity (such as a class, function, or variable) will be identified.

---

## Entity of Focus: {entity_name}

### Task:
* Write **one clear and concise question** about this entity.
* The question should highlight something a developer might consider, such as its purpose, behavior, interactions, or potential improvements.

### Guidelines:
* Keep the question short and direct.
* Do not explain the code or give an answer.

### Output:
**Question**: <your question here>
"""

        initial_question = await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks)
        return await self.extract_question_from_generated_text_async(generated_text=initial_question)

    def get_entities_with_declaration_and_calling(self, knowledge_graph: RepoKnowledgeGraph, cross_file_only: bool = True) -> list:
        """
        Finds all entities in the knowledge graph that have both a declaration and at least one call site.
        Optionally restricts to cases where the declaration and call are in different files (cross-file).

        Args:
            knowledge_graph (RepoKnowledgeGraph): The knowledge graph to search in.
            cross_file_only (bool): If True, only consider cross-file declaration/call pairs.
        Returns:
            list: List of dictionaries with 'entity', 'declaring_chunk_id', and 'calling_chunk_id'.
        """
        candidate_pairs = []
        entities = knowledge_graph.entities
        for entity_name in entities:
            entity = entities[entity_name]
            if len(entity['declaring_chunk_ids']) and len(entity['calling_chunk_ids']):
                found = False
                for declaring_chunk_id in entity['declaring_chunk_ids']:
                    for calling_chunk_id in entity['calling_chunk_ids']:
                        if declaring_chunk_id != calling_chunk_id:
                            if cross_file_only and extract_filename_from_chunk(knowledge_graph[declaring_chunk_id]) == extract_filename_from_chunk(knowledge_graph[calling_chunk_id]):
                                continue
                            else:
                                candidate_pairs.append({'entity': entity_name, 'declaring_chunk_id': declaring_chunk_id, 'calling_chunk_id': calling_chunk_id})
                                found = True
                                break
                    if found:
                        break
        return candidate_pairs

    def get_interacting_entity_triplets(self, knowledge_graph: RepoKnowledgeGraph) -> list:
        """
        Finds triplets of chunk ids such that:
        - Two entities (A, B) are interacting in the same chunk (call_chunk)
        - Each entity has a declaring chunk (decl_chunk_A, decl_chunk_B)
        - Both entities have non-empty declaring_chunk_ids and calling_chunk_ids

        Returns:
            list of dicts with keys:
                'entity_A', 'entity_B', 'decl_chunk_A', 'decl_chunk_B', 'call_chunk'
        """
        triplets = []
        seen_pairs = set()
        entities = knowledge_graph.entities
        for entity_A_name, entity_A in entities.items():
            if not entity_A['declaring_chunk_ids'] or not entity_A['calling_chunk_ids']:
                continue
            for entity_B_name, entity_B in entities.items():
                if entity_A_name == entity_B_name:
                    continue
                if not entity_B['declaring_chunk_ids'] or not entity_B['calling_chunk_ids']:
                    continue
                pair_key = (entity_A_name, entity_B_name)
                if pair_key in seen_pairs:
                    continue
                # Find intersection of calling_chunk_ids
                call_chunks = set(entity_A['calling_chunk_ids']) & set(entity_B['calling_chunk_ids'])
                found = False
                for call_chunk in call_chunks:
                    for decl_chunk_A in entity_A['declaring_chunk_ids']:
                        for decl_chunk_B in entity_B['declaring_chunk_ids']:
                            triplets.append({
                                'entity_A': entity_A_name,
                                'entity_B': entity_B_name,
                                'decl_chunk_A': decl_chunk_A,
                                'decl_chunk_B': decl_chunk_B,
                                'call_chunk': call_chunk
                            })
                            seen_pairs.add(pair_key)
                            found = True
                            break
                        if found:
                            break
                    if found:
                        break
        return triplets

    async def extract_question_from_generated_text_async(self, generated_text: str) -> str:
        """
        Extracts the question from the generated text. The question is expected to be the last line of the text.

        Args:
            generated_text (str): The text generated by the model.
        Returns:
            str: The extracted question.
        """

        prompt = f"Extract only the question from the following text. Return the question exactly, with no extra words or labels:\n\n{generated_text}\n\n"
        return await self.model_service.query_async(prompt=prompt)

    def select_diverse_candidates(self, candidate_pairs, candidate_triplets, max_pairs=20, max_triplets=20):
        """
        Selects a limited number of pairs and triplets with maximum diversity in entity representation.
        Args:
            candidate_pairs (list): List of candidate pairs (dicts with 'entity', ...).
            candidate_triplets (list): List of candidate triplets (dicts with 'entities', ...).
            max_pairs (int): Maximum number of pairs to select.
            max_triplets (int): Maximum number of triplets to select.
        Returns:
            (list, list): Selected pairs and triplets.
        """
        # Select pairs
        selected_pairs = []
        used_entities = set()
        for pair in candidate_pairs:
            entity = pair['entity']
            if entity not in used_entities:
                selected_pairs.append(pair)
                used_entities.add(entity)
            if len(selected_pairs) >= max_pairs:
                break
        # Select triplets
        selected_triplets = []
        used_entities_triplets = set()
        for triplet in candidate_triplets:
            entities = set(triplet['entities'])
            if not entities & used_entities_triplets:
                selected_triplets.append(triplet)
                used_entities_triplets.update(entities)
            if len(selected_triplets) >= max_triplets:
                break
        return selected_pairs, selected_triplets

    async def transform_answer_into_mcq_answer_async(self, question, answer, chunks):
        """
        Transforms the question and answer into a format suitable for MCQ generation.
        """
        code = join_organized_chunks(organize_chunks_by_file_name(chunks))

        prompt = f"""
You are an expert Python developer and technical writer. I will give you:

1. A Python code snippet
2. A question about that code
3. A detailed answer to the question

Your task is to **sanitize** the answer. That means:

- Strip away all fluff, filler, and redundant explanation
- Focus only on what directly answers the question
- Make it **short, clear, and direct**, as if it were a correct MCQ answer
- Prefer concise phrases or a single clear sentence over paragraph explanations
- Keep any necessary technical detail, but no more than needed

Do **not** repeat the question. Do **not** rephrase the code. Just give the concise, final answer.

- **Input Code**:
{code}

- **Question**:
{question}

- **Original Answer**:
{answer}

- **Sanitized Answer**:
"""
        return await self.model_service.query_async(prompt)
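
A short driver sketch tying QuestionMaker to a built graph. RepoKnowledgeGraph's docstring (below) names from_path() as the constructor, but its full signature is not shown in this listing, so that call is an assumption; the repository path is illustrative:

import asyncio

from RepoKnowledgeGraphLib.RepoKnowledgeGraph import RepoKnowledgeGraph
from RepoKnowledgeGraphLib.QuestionMaker import QuestionMaker

async def main():
    kg = RepoKnowledgeGraph.from_path('./my_repo')  # assumed signature; see note above
    qm = QuestionMaker()
    # Collect candidates, thin them for entity diversity, then generate Q/A items.
    pairs, triplets = qm.get_all_candidate_pairs_and_triplets(kg)
    pairs, triplets = qm.select_diverse_candidates(pairs, triplets, max_pairs=10, max_triplets=10)
    qa_items = await qm.make_entity_declaration_call_specific_questions_async(kg)
    for item in qa_items[:3]:
        print(item['category'], '-', item['clean_question'])

asyncio.run(main())
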
RepoKnowledgeGraphLib/RepoKnowledgeGraph.py
ADDED
@@ -0,0 +1,1608 @@
import networkx as nx
import json
import os
import asyncio
import nest_asyncio
import tqdm
# from pathlib import Path
import os.path
import tempfile
import subprocess
from typing import List, Optional, Dict
import logging
import urllib.parse

from .ModelService import create_model_service
from .Node import Node, DirectoryNode, FileNode, ChunkNode, EntityNode
from .CodeParser import CodeParser
from .EntityExtractor import HybridEntityExtractor
from .CodeIndex import CodeIndex
from .utils.logger_utils import setup_logger
from .utils.parsing_utils import read_directory_files_recursively, get_language_from_filename
from .utils.path_utils import prepare_input_path, build_entity_alias_map, resolve_entity_call
from .EntityChunkMapper import EntityChunkMapper

LOGGER_NAME = 'REPO_KNOWLEDGE_GRAPH_LOGGER'

MODEL_SERVICE_TYPES = ['openai', 'sentence-transformers']

# A RepoKnowledgeGraph is a weighted DAG based on a tree structure with added edges
class RepoKnowledgeGraph:
    """
    RepoKnowledgeGraph builds a knowledge graph of a code repository.
    It parses source files, extracts code entities and relationships, and organizes them
    into a directed acyclic graph (DAG) with additional semantic edges.

    Use `from_path()` or `load_graph_from_file()` to create instances.
    """

    def __init__(self):
        """
        Private constructor. Use from_path() or load_graph_from_file() instead.
        """
        raise RuntimeError(
            "Cannot instantiate RepoKnowledgeGraph directly. "
            "Use RepoKnowledgeGraph.from_path() or RepoKnowledgeGraph.load_graph_from_file() instead."
        )

    def _initialize(self, model_service_kwargs: dict, code_index_kwargs: Optional[dict] = None):
        """Internal initialization method."""
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)
        self.logger.info('Initializing RepoKnowledgeGraph instance.')
        self.code_parser = CodeParser()

        # Determine whether to skip loading the embedder based on index_type
        index_type = (code_index_kwargs or {}).get('index_type', 'hybrid')
        skip_embedder = index_type == 'keyword-only'
        if skip_embedder:
            self.logger.info('Using keyword-only index, skipping embedder initialization')

        self.model_service = create_model_service(skip_embedder=skip_embedder, **model_service_kwargs)
        self.entities = {}
        self.graph = nx.DiGraph()
        self.knowledge_graph = nx.DiGraph()
        self.code_index = None
        self.entity_extractor = HybridEntityExtractor()

    def __iter__(self):
        # Yield only the 'data' attribute from each node
        return (node_data['data'] for _, node_data in self.graph.nodes(data=True))

    def __getitem__(self, node_id):
        return self.graph.nodes[node_id]['data']

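    # Iteration and indexing sketch (hedged): `__iter__` yields the Node objects
    # stored under each graph node's 'data' attribute, and `__getitem__` looks one
    # up by its graph id. A minimal usage example, assuming `kg` is an already
    # built instance:
    #
    #     for node in kg:                     # Node / FileNode / ChunkNode / ...
    #         print(node.node_type, node.name)
    #     root = kg['root']                   # direct lookup by node id
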
    @classmethod
    def from_path(cls, path: str, skip_dirs: Optional[list] = None, index_nodes: bool = True, describe_nodes=False,
                  extract_entities: bool = False, model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
        """
        Alternative constructor to build a RepoKnowledgeGraph from a path, with options to skip directories
        and control entity extraction and node description.

        Args:
            path (str): Path to the root of the code repository.
            skip_dirs (list): List of directory names to skip.
            index_nodes (bool): Whether to build a code index.
            describe_nodes (bool): Whether to generate descriptions for code chunks.
            extract_entities (bool): Whether to extract entities from code.

        Returns:
            RepoKnowledgeGraph: The constructed knowledge graph.
        """
        if skip_dirs is None:
            skip_dirs = []
        if model_service_kwargs is None:
            model_service_kwargs = {}
        instance = cls.__new__(cls)  # Create instance without calling __init__
        instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)

        instance.logger.info(f"Preparing to build knowledge graph from path: {path}")

        prepared_path = prepare_input_path(path)
        instance.logger.debug(f"Prepared input path: {prepared_path}")

        # Handle a running event loop (e.g., in Jupyter)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            instance.logger.debug("Detected running event loop, applying nest_asyncio.")
            nest_asyncio.apply()
            task = instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
                                                      describe_nodes=describe_nodes, extract_entities=extract_entities)
            loop.run_until_complete(task)
        else:
            instance.logger.debug("No running event loop, using asyncio.run.")
            asyncio.run(instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
                                                           describe_nodes=describe_nodes,
                                                           extract_entities=extract_entities))

        instance.logger.info("Initial parse and node creation complete. Building relationships between nodes...")
        instance._build_relationships()

        if index_nodes:
            instance.logger.info("Building code index for all nodes in the graph...")
            instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **(code_index_kwargs or {}))

        instance.logger.info("Knowledge graph construction from path completed successfully.")
        return instance

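    # Usage sketch (hedged): building a graph from a local checkout. The path and
    # keyword values below are hypothetical; only the parameters documented above
    # are assumed.
    #
    #     kg = RepoKnowledgeGraph.from_path(
    #         'path/to/repo',
    #         skip_dirs=['tests'],
    #         extract_entities=True,
    #         describe_nodes=False,
    #     )
    #     kg.print_tree(max_depth=2)
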
    @classmethod
    def from_repo(
            cls,
            repo_url: str,
            skip_dirs: Optional[list] = None,
            index_nodes: bool = True,
            describe_nodes: bool = False,
            extract_entities: bool = False,
            model_service_kwargs: Optional[dict] = None,
            code_index_kwargs: Optional[dict] = None,
            github_token: Optional[str] = None,
            allow_unauthenticated_clone: bool = True,
    ):
        """
        Alternative constructor to build a RepoKnowledgeGraph from a remote git repository URL.

        Args:
            repo_url (str): Git repository URL (SSH or HTTPS).
            skip_dirs (list): List of directory names to skip.
            index_nodes (bool): Whether to build a code index.
            describe_nodes (bool): Whether to generate descriptions for code chunks.
            extract_entities (bool): Whether to extract entities from code.
            github_token (str, optional): Personal access token to access private GitHub repos.
                If not provided, the method will look for the `GITHUB_OAUTH_TOKEN` environment variable.
            allow_unauthenticated_clone (bool): If True, attempt to clone without a token when none is provided.
                If False, raise an error when no token is available.

        Returns:
            RepoKnowledgeGraph: The constructed knowledge graph.
        """
        if skip_dirs is None:
            skip_dirs = []
        if model_service_kwargs is None:
            model_service_kwargs = {}

        instance = cls.__new__(cls)
        instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)

        instance.logger.info(f"Starting knowledge graph build from remote repository: {repo_url}")

        # Determine token
        token = github_token or os.environ.get('GITHUB_OAUTH_TOKEN')

        with tempfile.TemporaryDirectory() as tmpdirname:
            clone_url = repo_url
            try:
                if repo_url.startswith('git@'):
                    # Convert git@github.com:owner/repo.git -> https://github.com/owner/repo.git
                    clone_url = repo_url.replace(':', '/').split('git@')[-1]
                    clone_url = f'https://{clone_url}'

                if token and clone_url.startswith('https://'):
                    encoded_token = urllib.parse.quote(token, safe='')
                    clone_url = clone_url.replace('https://', f'https://{encoded_token}@')
                elif not token and not allow_unauthenticated_clone:
                    raise ValueError(
                        "GitHub token not provided and unauthenticated clone is disabled. "
                        "Set allow_unauthenticated_clone=True or provide a token."
                    )

                instance.logger.debug(f"Running git clone: {clone_url} -> {tmpdirname}")
                subprocess.run(['git', 'clone', clone_url, tmpdirname], check=True)

            except Exception as e:
                instance.logger.error(f"Failed to clone repository {repo_url} using URL {clone_url}: {e}")
                raise

            instance.logger.info(f"Repository successfully cloned to: {tmpdirname}")

            return cls.from_path(
                tmpdirname,
                skip_dirs=skip_dirs,
                index_nodes=index_nodes,
                describe_nodes=describe_nodes,
                extract_entities=extract_entities,
                model_service_kwargs=model_service_kwargs,
                code_index_kwargs=code_index_kwargs
            )

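    # URL-rewriting sketch (hedged): what the SSH-to-HTTPS conversion above does
    # to a sample remote, with a hypothetical owner/repo:
    #
    #     'git@github.com:owner/repo.git'
    #         .replace(':', '/')   -> 'git@github.com/owner/repo.git'
    #         .split('git@')[-1]   -> 'github.com/owner/repo.git'
    #         f'https://{...}'     -> 'https://github.com/owner/repo.git'
    #
    # With a token, 'https://' becomes 'https://<url-encoded-token>@'.
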
    async def _initial_parse_path_async(self, path: str, skip_dirs: list, index_nodes=True, describe_nodes=True,
                                        extract_entities: bool = True):
        """
        Orchestrates the parsing and graph construction process:
        1. Reads files and splits them into chunks.
        2. Extracts entities and relationships.
        3. Builds chunk, file, directory, and root nodes.
        4. Aggregates entity information.

        Args:
            path (str): Root path to parse.
            skip_dirs (list): Directories to skip.
            index_nodes (bool): Whether to build a code index.
            describe_nodes (bool): Whether to generate descriptions.
            extract_entities (bool): Whether to extract entities.
        """
        self.logger.info(f"Beginning async parsing of repository at path: {path}")

        # --- Pass 1: Create ChunkNodes ---
        level1_node_contents = read_directory_files_recursively(
            path, skip_dirs=skip_dirs,
            skip_pattern=r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)"
        )
        self.logger.debug(f"Found {len(level1_node_contents)} files to process.")
        self.logger.info("Chunk nodes creation step started.")
        chunk_info = await self._create_chunk_nodes(
            level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=path
        )
        self.logger.info("Chunk nodes creation step finished.")
        self.logger.info("File nodes creation step started.")
        file_info = self._create_file_nodes(
            chunk_info, level1_node_contents
        )
        self.logger.info("File nodes creation step finished.")
        self.logger.info("Directory nodes creation step started.")
        dir_agg = self._create_directory_nodes(
            file_info
        )
        self.logger.info("Directory nodes creation step finished.")
        self.logger.info("Aggregating all nodes to root node.")
        self._aggregate_to_root(dir_agg)
        self.logger.info("Async parse and node aggregation fully complete.")

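    # Skip-pattern sketch (hedged): the regex above is applied per relative file
    # path (presumably via re.search inside read_directory_files_recursively;
    # that helper lives in utils/parsing_utils). A few illustrative matches:
    #
    #     import re
    #     pat = re.compile(r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)"
    #                      r"|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)")
    #     bool(pat.search('app.log'))                # True  (log files skipped)
    #     bool(pat.search('src/__pycache__/m.pyc'))  # True  (cache dirs skipped)
    #     bool(pat.search('src/main.py'))            # False (kept)
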
    async def _create_chunk_nodes(self, level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=None):
        self.logger.info(f"Starting chunk node creation for {len(level1_node_contents)} files.")
        accepted_extensions = {'.py', '.c', '.cpp', '.h', '.hpp', '.java', '.js', '.ts', '.jsx', '.tsx', '.rs', '.html'}
        chunk_info = {}
        entity_mapper = EntityChunkMapper()
        total_chunks = 0

        # Use tqdm for a progress bar over files
        for file_path in tqdm.tqdm(level1_node_contents, desc="Processing files for chunk nodes"):
            self.logger.debug(f"Processing file for chunk nodes: {file_path}")
            full_path = os.path.normpath(file_path)
            parts = full_path.split(os.sep)
            _, ext = os.path.splitext(file_path)
            is_code_file = ext.lower() in accepted_extensions

            self.logger.debug(f"Parsing file: {file_path}")

            # Parse file into chunks
            parsed_content = self.code_parser.parse(file_name=file_path, file_content=level1_node_contents[file_path])
            self.logger.debug(f"Parsed {len(parsed_content)} chunks from file: {file_path}")
            total_chunks += len(parsed_content)

            # Entity extraction
            if extract_entities and is_code_file:
                self.logger.debug(f"Extracting entities from code file: {file_path}")
                try:
                    # Construct full path for entity extraction (needed for C/C++ include resolution)
                    extraction_file_path = os.path.join(root_path, file_path) if root_path else file_path

                    file_declared_entities, file_called_entities = self.entity_extractor.extract_entities(
                        code=level1_node_contents[file_path], file_name=extraction_file_path)
                    self.logger.debug(f"Extracted {len(file_declared_entities)} declared and {len(file_called_entities)} called entities from file: {file_path}")

                    chunk_declared_map, chunk_called_map = entity_mapper.map_entities_to_chunks(
                        file_declared_entities, file_called_entities, parsed_content, file_name=file_path)
                    self.logger.debug(f"Mapped entities to {len(parsed_content)} chunks for file: {file_path}")
                except Exception as e:
                    self.logger.error(f"Error extracting entities from {file_path}: {e}")
                    file_declared_entities, file_called_entities = [], []
                    chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                    chunk_called_map = {i: [] for i in range(len(parsed_content))}
            else:
                self.logger.debug(f"Skipping entity extraction for non-code file: {file_path}")
                file_declared_entities, file_called_entities = [], []
                chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                chunk_called_map = {i: [] for i in range(len(parsed_content))}

            chunk_tasks = []
            for i, chunk in enumerate(parsed_content):
                chunk_id = f'{file_path}_{i}'
                self.logger.debug(f"Scheduling processing for chunk {chunk_id} of file {file_path}")

                async def process_chunk(i=i, chunk=chunk, chunk_id=chunk_id):
                    self.logger.debug(f"Creating chunk node: {chunk_id}")
                    declared_entities = chunk_declared_map.get(i, [])
                    called_entities = chunk_called_map.get(i, [])

                    # FIRST PASS: Register all declared entities with aliases.
                    # Build a temporary alias map for checking existing entities.
                    temp_alias_map = build_entity_alias_map(self.entities)

                    for entity in declared_entities:
                        name = entity.get("name")
                        if not name:
                            continue

                        # Check if this entity already exists under any of its aliases
                        entity_aliases = entity.get("aliases", [])
                        canonical_name = None

                        # First check if the name itself already exists or is an alias
                        if name in temp_alias_map:
                            canonical_name = temp_alias_map[name]
                            self.logger.debug(f"Entity '{name}' already exists as '{canonical_name}'")
                        else:
                            # Check if any of the entity's aliases match existing entities
                            for alias in entity_aliases:
                                if alias in temp_alias_map:
                                    canonical_name = temp_alias_map[alias]
                                    self.logger.debug(f"Entity '{name}' matches existing entity '{canonical_name}' via alias '{alias}'")
                                    break

                        # If a match was found, use the canonical name; otherwise register a new entity
                        if canonical_name:
                            entity_key = canonical_name
                        else:
                            entity_key = name
                            self.logger.debug(f"Registering new declared entity '{name}' in chunk {chunk_id}")
                            self.entities[entity_key] = {
                                "declaring_chunk_ids": [],
                                "calling_chunk_ids": [],
                                "type": [],
                                "dtype": None,
                                "aliases": []
                            }
                            # Update the temporary alias map with the new entity
                            temp_alias_map[entity_key] = entity_key

                        if chunk_id not in self.entities[entity_key]["declaring_chunk_ids"]:
                            self.entities[entity_key]["declaring_chunk_ids"].append(chunk_id)
                        entity_type = entity.get("type")
                        if entity_type and entity_type not in self.entities[entity_key]["type"]:
                            self.entities[entity_key]["type"].append(entity_type)
                        dtype = entity.get("dtype")
                        if dtype:
                            self.entities[entity_key]["dtype"] = dtype
                        # Store aliases (add new ones, avoiding duplicates)
                        for alias in [name] + entity_aliases:
                            if alias and alias not in self.entities[entity_key]["aliases"]:
                                self.entities[entity_key]["aliases"].append(alias)
                                temp_alias_map[alias] = entity_key  # Update temp map
                        self.logger.debug(f"Declared entity '{name}' registered as '{entity_key}' in chunk {chunk_id} with aliases: {self.entities[entity_key]['aliases']}")

                    # Optionally generate a description for the chunk
                    if describe_nodes:
                        self.logger.info(f"Generating description for chunk {chunk_id}")
                        try:
                            description = await self.model_service.query_async(
                                f'Summarize this {get_language_from_filename(file_path)} code chunk in a few sentences: {chunk}')
                        except Exception as e:
                            self.logger.error(f"Error generating description for chunk {chunk_id}: {e}")
                            description = ''
                    else:
                        self.logger.debug(f"No description requested for chunk {chunk_id}")
                        description = ''

                    chunk_node = ChunkNode(
                        id=chunk_id,
                        name=chunk_id,
                        path=file_path,
                        content=chunk,
                        order_in_file=i,
                        called_entities=called_entities,
                        declared_entities=declared_entities,
                        language=get_language_from_filename(file_path),
                        description=description,
                    )
                    self.logger.debug(f"Chunk node created: {chunk_id}")

                    # NOTE: Embeddings are deferred to CodeIndex for efficient batch processing.
                    # This avoids slow one-at-a-time embedding during chunk creation.
                    chunk_node.embedding = None
                    return (chunk_id, chunk_node, declared_entities, called_entities)

                chunk_tasks.append(process_chunk())

            chunk_results = await asyncio.gather(*chunk_tasks)
            self.logger.debug(f"Finished processing {len(chunk_results)} chunks for file {file_path}.")
            chunk_info[file_path] = {
                'chunk_results': chunk_results,
                'file_declared_entities': file_declared_entities,
                'file_called_entities': file_called_entities
            }

        # Log summary
        self.logger.info(f"Created {total_chunks} chunk nodes from {len(level1_node_contents)} files")

        # SECOND PASS: Now that all declared entities are registered, resolve called entities
        self.logger.info("Starting second pass: resolving called entities using alias map...")
        alias_map = build_entity_alias_map(self.entities)
        self.logger.info(f"Built alias map with {len(alias_map)} entries for resolution")

        resolved_count = 0
        for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Resolving called entities"):
            chunk_results = file_data['chunk_results']
            for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
                for called_name in called_entities:
                    # Skip empty or whitespace-only names
                    if not called_name or not called_name.strip():
                        continue

                    # Try to resolve this called entity to an existing declared entity using aliases
                    resolved_name = resolve_entity_call(called_name, alias_map)

                    # Use the resolved name if found, otherwise check if called_name is already an alias
                    if resolved_name:
                        entity_key = resolved_name
                    elif called_name in alias_map:
                        # The called_name itself is an alias of an existing entity
                        entity_key = alias_map[called_name]
                    else:
                        # No match found, use the original called name
                        entity_key = called_name

                    if entity_key not in self.entities:
                        self.logger.debug(f"Registering new called entity '{entity_key}' (called as '{called_name}') in chunk {chunk_id}")
                        self.entities[entity_key] = {
                            "declaring_chunk_ids": [],
                            "calling_chunk_ids": [],
                            "type": [],
                            "dtype": None,
                            "aliases": []
                        }
                        # Add called_name as an alias if it's different from entity_key
                        if called_name != entity_key:
                            self.entities[entity_key]["aliases"].append(called_name)
                            alias_map[called_name] = entity_key  # Update alias map

                    if chunk_id not in self.entities[entity_key]["calling_chunk_ids"]:
                        self.entities[entity_key]["calling_chunk_ids"].append(chunk_id)

                    if resolved_name and resolved_name != called_name:
                        resolved_count += 1
                        self.logger.debug(f"Called entity '{called_name}' resolved to '{entity_key}' in chunk {chunk_id}")

        self.logger.info(f"Resolved {resolved_count} entity calls to existing declarations via aliases")
        self.logger.info("All chunk nodes have been created for all files.")
        return chunk_info

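    # Two-pass resolution sketch (hedged): a toy illustration of the alias map the
    # second pass relies on. The real structures come from build_entity_alias_map
    # and resolve_entity_call in utils/path_utils; this only shows the shape being
    # assumed, with hypothetical entity names.
    #
    #     entities = {'pkg.mod.MyClass': {'aliases': ['MyClass', 'mod.MyClass'], ...}}
    #     alias_map = {'pkg.mod.MyClass': 'pkg.mod.MyClass',
    #                  'MyClass': 'pkg.mod.MyClass',
    #                  'mod.MyClass': 'pkg.mod.MyClass'}
    #     alias_map.get('MyClass')   # -> 'pkg.mod.MyClass' (call resolved)
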
    def _create_file_nodes(self, chunk_info, level1_node_contents):
        """
        For each file, aggregate chunk information and create FileNode objects.
        """
        self.logger.info("Starting file node creation.")

        def merge_entities(target, source):
            # Merge entity lists, avoiding duplicates by (name, type)
            existing = set((e.get('name'), e.get('type')) for e in target)
            for e in source:
                k = (e.get('name'), e.get('type'))
                if k not in existing:
                    target.append(e)
                    existing.add(k)

        def merge_called_entities(target, source):
            # Merge called entity lists, avoiding duplicates
            existing = set(target)
            for e in source:
                if e not in existing:
                    target.append(e)
                    existing.add(e)

        file_info = {}
        for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Creating file nodes"):
            self.logger.info(f"Creating file node for: {file_path}")
            parts = os.path.normpath(file_path).split(os.sep)

            # Extract file-level entities and chunk results from the stored data
            chunk_results = file_data['chunk_results']
            file_declared_entities = list(file_data['file_declared_entities'])  # Use file-level entities directly
            file_called_entities = list(file_data['file_called_entities'])  # Use file-level entities directly
            chunk_ids = []

            for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
                self.logger.info(f"Adding chunk node {chunk_id} to graph for file {file_path}")
                self.graph.add_node(chunk_id, data=chunk_node, level=2)
                chunk_ids.append(chunk_id)
                # Note: the FileNode uses file-level entities, so there is no need to
                # merge entities up from the chunks; each chunk already carries its own.

            file_node = FileNode(
                id=file_path,
                name=parts[-1],
                path=file_path,
                node_type='file',
                content=level1_node_contents[file_path],
                declared_entities=file_declared_entities,
                called_entities=file_called_entities,
                language=get_language_from_filename(file_path),
            )

            self.logger.debug(f"Adding file node {file_path} to graph.")
            self.graph.add_node(file_path, data=file_node, level=1)
            for chunk_id in chunk_ids:
                self.graph.add_edge(file_path, chunk_id, relation='contains')

            file_info[file_path] = {
                'declared_entities': file_declared_entities,
                'called_entities': file_called_entities,
                'chunk_ids': chunk_ids,
                'parts': parts,
            }
            self.logger.info(f"File node {file_path} added to graph with {len(chunk_ids)} chunks.")

        self.logger.info("All file nodes have been created.")
        return file_info

    def _create_directory_nodes(self, file_info):
        """
        For each directory, aggregate file information and create DirectoryNode objects.

        Args:
            file_info (dict): Mapping file_path -> file info dict.

        Returns:
            dict: Mapping dir_path -> aggregated entity info.
        """
        self.logger.info("Starting directory node creation.")

        def merge_entities(target, source):
            # Merge entity lists, avoiding duplicates by (name, type)
            existing = set((e.get('name'), e.get('type')) for e in target)
            for e in source:
                k = (e.get('name'), e.get('type'))
                if k not in existing:
                    target.append(e)
                    existing.add(k)

        def merge_called_entities(target, source):
            # Merge called entity lists, avoiding duplicates
            existing = set(target)
            for e in source:
                if e not in existing:
                    target.append(e)
                    existing.add(e)

        dir_agg = {}
        for file_path, info in tqdm.tqdm(file_info.items(), desc="Creating directory nodes"):
            self.logger.info(f"Processing directory nodes for file: {file_path}")
            parts = os.path.normpath(file_path).split(os.sep)
            file_declared_entities = info['declared_entities']
            file_called_entities = info['called_entities']
            current_parent = 'root'
            path_accum = ''
            for part in parts[:-1]:  # Skip the file itself
                path_accum = os.path.join(path_accum, part) if path_accum else part
                if path_accum not in self.graph:
                    self.logger.info(f"Adding new directory node: {path_accum}")
                    dir_node = DirectoryNode(id=path_accum, name=part, path=path_accum)
                    self.graph.add_node(path_accum, data=dir_node, level=1)
                    self.graph.add_edge(current_parent, path_accum, relation='contains')
                if path_accum not in dir_agg:
                    dir_agg[path_accum] = {'declared_entities': [], 'called_entities': []}
                merge_entities(dir_agg[path_accum]['declared_entities'], file_declared_entities)
                merge_called_entities(dir_agg[path_accum]['called_entities'], file_called_entities)
                current_parent = path_accum
            # Connect the file to its parent directory
            self.graph.add_edge(current_parent, file_path, relation='contains')
        self.logger.info("All directory nodes created.")
        return dir_agg

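    # Path-walk sketch (hedged): how the accumulation loop above turns one file
    # path into nested directory nodes (hypothetical path, POSIX separators):
    #
    #     'src/utils/io.py' -> parts = ['src', 'utils', 'io.py']
    #     iteration 1: path_accum = 'src'        edge root -> src
    #     iteration 2: path_accum = 'src/utils'  edge src -> src/utils
    #     finally:     edge src/utils -> 'src/utils/io.py' (the FileNode)
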
    def _aggregate_to_root(self, dir_agg):
        """
        Aggregate all directory entity information to the root node.

        Args:
            dir_agg (dict): Mapping dir_path -> aggregated entity info.
        """
        self.logger.info("Aggregating directory information to root node.")

        def merge_entities(target, source):
            # Merge entity lists, avoiding duplicates by (name, type)
            existing = set((e.get('name'), e.get('type')) for e in target)
            for e in source:
                k = (e.get('name'), e.get('type'))
                if k not in existing:
                    target.append(e)
                    existing.add(k)

        def merge_called_entities(target, source):
            # Merge called entity lists, avoiding duplicates
            existing = set(target)
            for e in source:
                if e not in existing:
                    target.append(e)
                    existing.add(e)

        root_node = Node(id='root', name='root', node_type='root')
        self.graph.add_node('root', data=root_node, level=0)
        root_declared_entities = []
        root_called_entities = []
        for dir_path, agg in tqdm.tqdm(dir_agg.items(), desc="Aggregating to root"):
            node = self.graph.nodes[dir_path]['data']
            if not hasattr(node, 'declared_entities'):
                node.declared_entities = []
            if not hasattr(node, 'called_entities'):
                node.called_entities = []
            merge_entities(node.declared_entities, agg['declared_entities'])
            merge_called_entities(node.called_entities, agg['called_entities'])
            merge_entities(root_declared_entities, agg['declared_entities'])
            merge_called_entities(root_called_entities, agg['called_entities'])
        if not hasattr(root_node, 'declared_entities'):
            root_node.declared_entities = []
        if not hasattr(root_node, 'called_entities'):
            root_node.called_entities = []
        merge_entities(root_node.declared_entities, root_declared_entities)
        merge_called_entities(root_node.called_entities, root_called_entities)
        self.logger.info("Aggregation to root node complete.")

    def _build_relationships(self):
        """
        Build relationships between chunk nodes and entity nodes based on self.entities.
        For each entity in self.entities:
        1. Create an EntityNode with entity_name as the id.
        2. Create edges from declaring chunks to the entity node ('declares').
        3. Create edges from the entity node to calling chunks ('called_by').
        4. Resolve called entity names using aliases for better matching.
        """
        self.logger.info("Building relationships between chunk nodes based on entities.")
        edges_created = 0
        entity_nodes_created = 0

        # Build alias map for quick lookups
        self.logger.info("Building entity alias map for call resolution...")
        alias_map = build_entity_alias_map(self.entities)
        self.logger.info(f"Built alias map with {len(alias_map)} entries")

        # First pass: create all entity nodes
        for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Creating entity nodes"):
            # Entity type is stored as a list under 'type'; take the first type or an empty string
            entity_types = info.get('type', [])
            entity_type = entity_types[0] if entity_types else ''
            declaring_chunks = info.get('declaring_chunk_ids', [])
            calling_chunks = info.get('calling_chunk_ids', [])
            aliases = info.get('aliases', [])

            # Create EntityNode with entity_name as id
            entity_node = EntityNode(
                id=entity_name,
                name=entity_name,
                entity_type=entity_type,
                declaring_chunk_ids=declaring_chunks,
                calling_chunk_ids=calling_chunks,
                aliases=aliases
            )

            # Add entity node to graph
            self.graph.add_node(entity_name, data=entity_node, level=3)
            entity_nodes_created += 1

            # Log aliases for debugging
            if aliases:
                self.logger.debug(f"Created EntityNode '{entity_name}' with aliases: {aliases}")

            # Create edges from declaring chunks to the entity node
            for declarer_id in declaring_chunks:
                if declarer_id in self.graph:
                    self.graph.add_edge(declarer_id, entity_name, relation='declares')
                    edges_created += 1

            # Create edges from the entity node to calling chunks
            for caller_id in calling_chunks:
                if caller_id in self.graph and caller_id not in declaring_chunks:
                    self.graph.add_edge(entity_name, caller_id, relation='called_by')
                    edges_created += 1

        # Second pass: resolve unmatched entity calls using alias matching
        self.logger.info("Resolving entity calls using alias matching...")
        resolved_calls = 0

        for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Resolving entity calls"):
            # Skip entities that already have declarations (they were matched directly)
            if info.get('declaring_chunk_ids'):
                continue

            # Try to resolve this called entity to a declared entity using aliases
            resolved_name = resolve_entity_call(entity_name, alias_map)

            if resolved_name and resolved_name != entity_name:
                # Found a match: update the calling chunks of the resolved entity
                calling_chunks = info.get('calling_chunk_ids', [])

                if resolved_name in self.entities:
                    for caller_id in calling_chunks:
                        if caller_id in self.graph:
                            # Add edge from the resolved entity to the calling chunk
                            if not self.graph.has_edge(resolved_name, caller_id):
                                self.graph.add_edge(resolved_name, caller_id, relation='called_by')
                                edges_created += 1
                                resolved_calls += 1
                                self.logger.debug(f"Resolved call: '{entity_name}' -> '{resolved_name}' in chunk {caller_id}")

        self.logger.info(f"_build_relationships: Created {entity_nodes_created} entity nodes, "
                         f"{edges_created} edges, and resolved {resolved_calls} entity calls using aliases.")

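    # Edge-pattern sketch (hedged): after _build_relationships the graph contains
    # chains like chunk --declares--> entity --called_by--> chunk, queryable with
    # plain networkx (node ids below are hypothetical):
    #
    #     callers = [t for _, t, d in kg.graph.out_edges('pkg.mod.MyClass', data=True)
    #                if d.get('relation') == 'called_by']
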
    def get_entity_by_alias(self, alias: str) -> Optional[str]:
        """
        Get the canonical entity name for a given alias.

        Args:
            alias: An alias of an entity (e.g., 'MyClass' or 'module.MyClass')

        Returns:
            Canonical entity name if found, None otherwise
        """
        alias_map = build_entity_alias_map(self.entities)
        return alias_map.get(alias)

    def resolve_entity_references(self) -> Dict[str, str]:
        """
        Resolve all entity references in the knowledge graph using aliases.

        Returns:
            Dictionary mapping each called-but-undeclared entity name to its
            resolved canonical entity name.
        """
        alias_map = build_entity_alias_map(self.entities)
        resolutions = {}

        for entity_name, info in self.entities.items():
            # Only look at entities that are called but not declared
            if not info.get('declaring_chunk_ids') and info.get('calling_chunk_ids'):
                resolved = resolve_entity_call(entity_name, alias_map)
                if resolved:
                    resolutions[entity_name] = resolved

        return resolutions

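    # Usage sketch (hedged; entity names are hypothetical):
    #
    #     kg.get_entity_by_alias('MyClass')   # -> 'pkg.mod.MyClass' or None
    #     kg.resolve_entity_references()      # -> {'helper': 'pkg.mod.helper', ...}
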
    def print_tree(self, max_depth=None, start_node_id='root', level=0, prefix=""):
        """
        Print the repository tree structure using the graph with 'contains' edges.

        Args:
            max_depth (int, optional): Maximum depth to print. None = unlimited.
            start_node_id (str): ID of the node to start from. Default is 'root'.
            level (int): Internal use only (used for recursion).
            prefix (str): Internal use only (used for formatting output).
        """
        if max_depth is not None and level > max_depth:
            self.logger.debug(f"Max depth {max_depth} reached at node {start_node_id}.")
            return

        if start_node_id not in self.graph:
            self.logger.warning(f"Start node '{start_node_id}' not found in graph.")
            return

        try:
            node_data = self[start_node_id]
        except KeyError as e:
            self.logger.error(f"KeyError when accessing node {start_node_id}: {e}")
            self.logger.error(f"Available node attributes: {list(self.graph.nodes[start_node_id].keys())}")
            # Use a fallback approach if 'data' is missing
            if 'data' not in self.graph.nodes[start_node_id]:
                self.logger.warning(f"Node {start_node_id} has no 'data' attribute, using node itself")
                # Create a fallback node if 'data' is missing
                if start_node_id == 'root':
                    # Create a default root node
                    node_data = Node(id='root', name='root', node_type='root')
                    # Update the graph node with the fallback data
                    self.graph.nodes[start_node_id]['data'] = node_data
                else:
                    # Try to infer the node type from the ID or structure
                    name = start_node_id.split('/')[-1] if '/' in start_node_id else start_node_id
                    if '_' in start_node_id and start_node_id.split('_')[-1].isdigit():
                        # Looks like a chunk ID
                        node_data = ChunkNode(id=start_node_id, name=name, node_type='chunk')
                    elif '.' in name:
                        # Looks like a file
                        node_data = FileNode(id=start_node_id, name=name, node_type='file', path=start_node_id)
                    else:
                        # Fall back to a directory node
                        node_data = DirectoryNode(id=start_node_id, name=name, node_type='directory',
                                                  path=start_node_id)
                    # Update the graph node with the fallback data
                    self.graph.nodes[start_node_id]['data'] = node_data
            return

        # Choose an icon based on node type (original glyphs were mis-encoded;
        # these are plausible replacements)
        if node_data.node_type == 'file':
            node_symbol = "📄"
        elif node_data.node_type == 'chunk':
            node_symbol = "🧩"
        elif node_data.node_type == 'root':
            node_symbol = "🌳"
        elif node_data.node_type == 'directory':
            node_symbol = "📁"
        else:
            node_symbol = "📦"

        if level == 0:
            print(f"{node_symbol} {node_data.name} ({node_data.node_type})")
        else:
            print(f"{prefix}├── {node_symbol} {node_data.name} ({node_data.node_type})")

        # Get children via 'contains' edges
        children = [
            child for child in self.graph.successors(start_node_id)
            if self.graph.edges[start_node_id, child].get('relation') == 'contains'
        ]

        child_count = len(children)
        for i, child_id in enumerate(children):
            is_last = i == child_count - 1
            new_prefix = prefix + ("    " if is_last else "│   ")
            self.print_tree(max_depth, start_node_id=child_id, level=level + 1, prefix=new_prefix)

    def to_dict(self):
        self.logger.info("Serializing graph to dictionary.")
        graph_data = {
            'nodes': [],
            'edges': []
        }

        for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes"):
            if 'data' not in node_attrs:
                self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping in serialization")
                continue

            node = node_attrs['data']
            node_dict = {
                'id': node.id or node_id,
                'class': node.__class__.__name__,
                'data': {
                    'id': node.id or node_id,
                    'name': node.name,
                    'node_type': node.node_type,
                    'description': getattr(node, 'description', ''),
                    'declared_entities': list(getattr(node, 'declared_entities', [])),
                    'called_entities': list(getattr(node, 'called_entities', [])),
                }
            }

            # FileNode-specific
            if isinstance(node, FileNode):
                node_dict['data']['path'] = node.path
                node_dict['data']['content'] = node.content
                node_dict['data']['language'] = getattr(node, 'language', '')

            # ChunkNode-specific
            if isinstance(node, ChunkNode):
                node_dict['data']['order_in_file'] = getattr(node, 'order_in_file', 0)
                node_dict['data']['embedding'] = getattr(node, 'embedding', None)

            # EntityNode-specific
            if isinstance(node, EntityNode):
                node_dict['data']['entity_type'] = getattr(node, 'entity_type', '')
                node_dict['data']['declaring_chunk_ids'] = list(getattr(node, 'declaring_chunk_ids', []))
                node_dict['data']['calling_chunk_ids'] = list(getattr(node, 'calling_chunk_ids', []))
                node_dict['data']['aliases'] = list(getattr(node, 'aliases', []))

            graph_data['nodes'].append(node_dict)

        for u, v, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges"):
            edge_data = {
                'source': u,
                'target': v,
                'relation': attrs.get('relation', '')
            }
            if 'entities' in attrs:
                edge_data['entities'] = list(attrs['entities'])
            graph_data['edges'].append(edge_data)

        self.logger.info("Serialization complete.")
        return graph_data

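    # Serialized-shape sketch (hedged): an abbreviated example of what to_dict
    # produces (ids and values are hypothetical):
    #
    #     {'nodes': [{'id': 'src/app.py', 'class': 'FileNode',
    #                 'data': {'id': 'src/app.py', 'name': 'app.py',
    #                          'node_type': 'file', 'description': '', ...}}],
    #      'edges': [{'source': 'src/app.py', 'target': 'src/app.py_0',
    #                 'relation': 'contains'}]}
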
    @classmethod
    def from_dict(cls, data_dict, index_nodes: bool = True, use_embed: bool = True,
                  model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
        instance = cls.__new__(cls)  # bypass __init__
        instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)

        instance.logger.info("Deserializing graph from dictionary.")

        node_classes = {
            'Node': Node,
            'FileNode': FileNode,
            'ChunkNode': ChunkNode,
            'DirectoryNode': DirectoryNode,
            'EntityNode': EntityNode,
        }

        # Create a root node if not present in the data
        root_found = any(node_data['id'] == 'root' for node_data in data_dict['nodes'])
        if not root_found:
            instance.logger.warning("Root node not found in the data, creating one")
            root_node = Node(id='root', name='root', node_type='root')
            instance.graph.add_node('root', data=root_node, level=0)

        # --- Rebuild nodes ---
        for node_data in tqdm.tqdm(data_dict['nodes'], desc="Rebuilding nodes"):
            cls_name = node_data['class']
            node_cls = node_classes.get(cls_name, Node)
            kwargs = node_data['data']

            # Ensure the ID is properly set
            if not kwargs.get('id'):
                kwargs['id'] = node_data['id']

            # Always use lists for declared_entities and called_entities
            kwargs['declared_entities'] = list(kwargs.get('declared_entities', []))
            kwargs['called_entities'] = list(kwargs.get('called_entities', []))

            # FileNode/ChunkNode-specific defaults
            if node_cls in (FileNode, ChunkNode):
                kwargs.setdefault('path', '')
                kwargs.setdefault('content', '')
                kwargs.setdefault('language', '')
            if node_cls == ChunkNode:
                kwargs.setdefault('order_in_file', 0)
                kwargs.setdefault('embedding', [])
            # EntityNode-specific defaults
            if node_cls == EntityNode:
                kwargs.setdefault('entity_type', '')
                kwargs.setdefault('declaring_chunk_ids', [])
                kwargs.setdefault('calling_chunk_ids', [])
                kwargs.setdefault('aliases', [])

            node_instance = node_cls(**kwargs)
            instance.graph.add_node(node_data['id'], data=node_instance, level=instance._infer_level(node_instance))

        # --- Rebuild edges ---
        for edge in tqdm.tqdm(data_dict['edges'], desc="Rebuilding edges"):
            source = edge['source']
            target = edge['target']
            if source in instance.graph and target in instance.graph:
                edge_kwargs = {'relation': edge.get('relation', '')}
                if 'entities' in edge:
                    edge_kwargs['entities'] = list(edge['entities'])
                instance.graph.add_edge(source, target, **edge_kwargs)
            else:
                instance.logger.warning(f"Cannot add edge {source} -> {target}, nodes don't exist")

        # --- Rebuild instance.entities ---
        instance.entities = {}
        for node_id, node_attrs in tqdm.tqdm(instance.graph.nodes(data=True), desc="Rebuilding entities"):
            node = node_attrs['data']
            declared_entities = getattr(node, 'declared_entities', [])
            called_entities = getattr(node, 'called_entities', [])
            for entity in declared_entities:
                if isinstance(entity, dict):
                    name = entity.get('name')
                else:
                    name = entity
                if not name:
                    continue
                if name not in instance.entities:
                    instance.entities[name] = {
                        "declaring_chunk_ids": [],
                        "calling_chunk_ids": [],
                        "type": [],
                        "dtype": None
                    }
                # Only add node_id if it is a ChunkNode
                if node_id not in instance.entities[name]["declaring_chunk_ids"]:
                    if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
                        instance.entities[name]["declaring_chunk_ids"].append(node_id)
                if isinstance(entity, dict):
                    entity_type = entity.get("type")
                    if entity_type and entity_type not in instance.entities[name]["type"]:
                        instance.entities[name]["type"].append(entity_type)
                    dtype = entity.get("dtype")
                    if dtype:
                        instance.entities[name]["dtype"] = dtype
            for called_name in called_entities:
                if not called_name:
                    continue
                if called_name not in instance.entities:
                    instance.entities[called_name] = {
                        "declaring_chunk_ids": [],
                        "calling_chunk_ids": [],
                        "type": [],
                        "dtype": None
                    }
                if node_id not in instance.entities[called_name]["calling_chunk_ids"]:
                    if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
                        instance.entities[called_name]["calling_chunk_ids"].append(node_id)

        if index_nodes:
            instance.logger.info("Building code index after deserialization.")
            # Merge use_embed with code_index_kwargs, avoiding duplicate keyword arguments
            code_idx_kwargs = code_index_kwargs or {}
            if 'use_embed' not in code_idx_kwargs:
                code_idx_kwargs['use_embed'] = use_embed
            instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **code_idx_kwargs)

        instance.logger.info("Deserialization complete.")
        return instance

    def _infer_level(self, node):
        """Infer the level of a node based on its type."""
        if node.node_type == 'root':
            return 0
        elif node.node_type in ('file', 'directory'):
            return 1
        elif node.node_type == 'chunk':
            return 2
        elif isinstance(node, EntityNode):
            # Entity nodes are added at level 3 in _build_relationships
            return 3
        return 1  # Default level

    def save_graph_to_file(self, filepath: str):
        self.logger.info(f"Saving graph to file: {filepath}")
        with open(filepath, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)
        self.logger.info("Graph saved successfully.")

    @classmethod
    def load_graph_from_file(cls, filepath: str, index_nodes=True, use_embed: bool = True,
                             model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
        if model_service_kwargs is None:
            model_service_kwargs = {}
        with open(filepath, 'r') as f:
            data = json.load(f)
        logging.getLogger(LOGGER_NAME).info(f"Loaded graph data from file: {filepath}")
        return cls.from_dict(data, use_embed=use_embed, index_nodes=index_nodes,
                             model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)

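    # Round-trip sketch (hedged; 'graph.json' is a hypothetical filename):
    #
    #     kg.save_graph_to_file('graph.json')
    #     kg2 = RepoKnowledgeGraph.load_graph_from_file('graph.json', use_embed=False)
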
    def to_hf_dataset(
            self,
            repo_id: str,
            save_embeddings: bool = True,
            private: bool = False,
            token: Optional[str] = None,
            commit_message: Optional[str] = None,
    ):
        """
        Save the knowledge graph to a HuggingFace dataset on the Hub.

        The graph is serialized into two splits:
        - 'nodes': Contains all node data
        - 'edges': Contains all edge relationships

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            save_embeddings (bool): If True, saves embedding vectors for chunk nodes.
                If False, embeddings are excluded to reduce dataset size.
            private (bool): Whether the dataset should be private. Defaults to False.
            token (str, optional): HuggingFace API token. If not provided, uses the token
                from huggingface_hub login or the HF_TOKEN environment variable.
            commit_message (str, optional): Custom commit message for the upload.

        Returns:
            str: URL of the uploaded dataset
        """
        try:
            from datasets import Dataset, DatasetDict
            from huggingface_hub import HfApi
        except ImportError:
            raise ImportError(
                "huggingface_hub and datasets are required for HuggingFace integration. "
                "Install them with: pip install huggingface_hub datasets"
            )

        self.logger.info(f"Preparing to save knowledge graph to HuggingFace dataset: {repo_id}")
        self.logger.info(f"save_embeddings={save_embeddings}")

        # Serialize nodes
        nodes_data = []
        for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes for HF dataset"):
            if 'data' not in node_attrs:
                self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping")
                continue

            node = node_attrs['data']
            node_record = {
                'node_id': node.id or node_id,
                'node_class': node.__class__.__name__,
                'name': node.name,
                'node_type': node.node_type,
                'description': getattr(node, 'description', '') or '',
                'declared_entities': json.dumps(list(getattr(node, 'declared_entities', []))),
                'called_entities': json.dumps(list(getattr(node, 'called_entities', []))),
            }

            # FileNode-specific fields
            if isinstance(node, FileNode):
                node_record['path'] = node.path
                node_record['content'] = node.content
                node_record['language'] = getattr(node, 'language', '')
            else:
                node_record['path'] = ''
                node_record['content'] = ''
                node_record['language'] = ''

            # ChunkNode-specific fields
            if isinstance(node, ChunkNode):
                node_record['order_in_file'] = getattr(node, 'order_in_file', 0)
                if save_embeddings:
                    embedding = getattr(node, 'embedding', None)
                    node_record['embedding'] = json.dumps(embedding if embedding is not None else [])
                else:
                    node_record['embedding'] = json.dumps([])
            else:
                node_record['order_in_file'] = -1
                node_record['embedding'] = json.dumps([])

            # EntityNode-specific fields
            if isinstance(node, EntityNode):
                node_record['entity_type'] = getattr(node, 'entity_type', '')
                node_record['declaring_chunk_ids'] = json.dumps(list(getattr(node, 'declaring_chunk_ids', [])))
                node_record['calling_chunk_ids'] = json.dumps(list(getattr(node, 'calling_chunk_ids', [])))
                node_record['aliases'] = json.dumps(list(getattr(node, 'aliases', [])))
            else:
                node_record['entity_type'] = ''
                node_record['declaring_chunk_ids'] = json.dumps([])
                node_record['calling_chunk_ids'] = json.dumps([])
                node_record['aliases'] = json.dumps([])

            nodes_data.append(node_record)

        # Serialize edges
        edges_data = []
        for source, target, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges for HF dataset"):
            edge_record = {
                'source': source,
                'target': target,
                'relation': attrs.get('relation', ''),
                'entities': json.dumps(list(attrs.get('entities', []))) if 'entities' in attrs else json.dumps([])
            }
            edges_data.append(edge_record)

        # Create datasets
        nodes_dataset = Dataset.from_list(nodes_data)
        edges_dataset = Dataset.from_list(edges_data)

        self.logger.info(f"Created dataset with {len(nodes_data)} nodes and {len(edges_data)} edges")

        # Push to Hub - nodes and edges are pushed as separate configs
        # because they have different schemas
        if commit_message is None:
            base_commit_message = f"Upload knowledge graph ({len(nodes_data)} nodes, {len(edges_data)} edges)"
            if not save_embeddings:
                base_commit_message += " [embeddings excluded]"
        else:
            base_commit_message = commit_message

        self.logger.info(f"Pushing nodes dataset to HuggingFace Hub: {repo_id}")
        nodes_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="nodes",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - nodes"
        )

        self.logger.info(f"Pushing edges dataset to HuggingFace Hub: {repo_id}")
        edges_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="edges",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - edges"
        )

        url = f"https://huggingface.co/datasets/{repo_id}"
        self.logger.info(f"Dataset successfully uploaded to: {url}")
        return url

| 1189 |
+
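A minimal usage sketch for the upload path above, assuming the enclosing method is named save_to_hf_dataset (its signature sits above this excerpt), an already-built graph instance kg, and a hypothetical repo id:

# Sketch: push a built graph to the Hub as two dataset configs ("nodes", "edges").
# `kg` and the method name `save_to_hf_dataset` are assumptions for illustration.
url = kg.save_to_hf_dataset(
    repo_id="username/my-repo-kg",   # hypothetical repo id
    save_embeddings=False,           # drop vectors to keep the dataset small
    private=True,
)
print(url)  # e.g. https://huggingface.co/datasets/username/my-repo-kg
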
    @classmethod
    def from_hf_dataset(
        cls,
        repo_id: str,
        index_nodes: bool = True,
        use_embed: bool = True,
        model_service_kwargs: Optional[dict] = None,
        code_index_kwargs: Optional[dict] = None,
        token: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        """
        Load a knowledge graph from a HuggingFace dataset on the Hub.

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            index_nodes (bool): Whether to build a code index after loading. Defaults to True.
            use_embed (bool): Whether to use existing embeddings from the dataset. Defaults to True.
            model_service_kwargs (dict, optional): Arguments for the model service.
            code_index_kwargs (dict, optional): Arguments for the code index.
            token (str, optional): HuggingFace API token for private datasets.
            revision (str, optional): Git revision (branch, tag, or commit) to load from.

        Returns:
            RepoKnowledgeGraph: The loaded knowledge graph instance.
        """
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError(
                "datasets library is required for HuggingFace integration. "
                "Install it with: pip install datasets"
            )

        if model_service_kwargs is None:
            model_service_kwargs = {}

        logger = logging.getLogger(LOGGER_NAME)
        logger.info(f"Loading knowledge graph from HuggingFace dataset: {repo_id}")

        # Load dataset from Hub - nodes and edges are stored as separate configs
        logger.info("Loading nodes config...")
        nodes_dataset = load_dataset(repo_id, name="nodes", token=token, revision=revision)
        logger.info("Loading edges config...")
        edges_dataset = load_dataset(repo_id, name="edges", token=token, revision=revision)

        # Get the train split (default split when pushing with config_name)
        nodes_data = nodes_dataset['train']
        edges_data = edges_dataset['train']

        logger.info(f"Loaded {len(nodes_data)} nodes and {len(edges_data)} edges from dataset")

        # Convert to the dict format expected by from_dict
        graph_data = {
            'nodes': [],
            'edges': []
        }

        # Reconstruct nodes
        for record in tqdm.tqdm(nodes_data, desc="Reconstructing nodes from HF dataset"):
            node_dict = {
                'id': record['node_id'],
                'class': record['node_class'],
                'data': {
                    'id': record['node_id'],
                    'name': record['name'],
                    'node_type': record['node_type'],
                    'description': record['description'],
                    'declared_entities': json.loads(record['declared_entities']),
                    'called_entities': json.loads(record['called_entities']),
                }
            }

            # FileNode-specific fields (ChunkNode inherits them)
            if record['node_class'] in ('FileNode', 'ChunkNode'):
                node_dict['data']['path'] = record['path']
                node_dict['data']['content'] = record['content']
                node_dict['data']['language'] = record['language']

            # ChunkNode-specific fields
            if record['node_class'] == 'ChunkNode':
                node_dict['data']['order_in_file'] = record['order_in_file']
                embedding = json.loads(record['embedding'])
                # Only use embedding if use_embed is True and embedding is non-empty
                if use_embed and embedding:
                    node_dict['data']['embedding'] = embedding
                else:
                    node_dict['data']['embedding'] = []

            # EntityNode-specific fields
            if record['node_class'] == 'EntityNode':
                node_dict['data']['entity_type'] = record['entity_type']
                node_dict['data']['declaring_chunk_ids'] = json.loads(record['declaring_chunk_ids'])
                node_dict['data']['calling_chunk_ids'] = json.loads(record['calling_chunk_ids'])
                node_dict['data']['aliases'] = json.loads(record['aliases'])

            graph_data['nodes'].append(node_dict)

        # Reconstruct edges
        for record in tqdm.tqdm(edges_data, desc="Reconstructing edges from HF dataset"):
            edge_dict = {
                'source': record['source'],
                'target': record['target'],
                'relation': record['relation'],
            }
            entities = json.loads(record['entities'])
            if entities:
                edge_dict['entities'] = entities

            graph_data['edges'].append(edge_dict)

        logger.info("Dataset reconstruction complete, building graph...")

        # Use from_dict to build the graph
        return cls.from_dict(
            graph_data,
            index_nodes=index_nodes,
            use_embed=use_embed,
            model_service_kwargs=model_service_kwargs,
            code_index_kwargs=code_index_kwargs
        )

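The matching load path, as a sketch; the import location of the class is an assumption:

# Sketch: reload the same graph from the Hub without re-embedding.
from RepoKnowledgeGraphLib.RepoKnowledgeGraph import RepoKnowledgeGraph  # assumed import path

kg = RepoKnowledgeGraph.from_hf_dataset(
    "username/my-repo-kg",   # hypothetical repo id
    index_nodes=True,        # rebuild the code index after loading
    use_embed=True,          # reuse embeddings stored in the dataset
)
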
    def get_neighbors(self, node_id):
        self.logger.debug(f"Getting neighbors for node: {node_id}")
        # Return all nodes directly connected to node_id, regardless of edge
        # direction or relation type. Successors and predecessors together
        # already cover every incident edge in a directed graph.
        neighbors = set(self.graph.successors(node_id)) | set(self.graph.predecessors(node_id))
        return [self.graph.nodes[n]['data'] for n in neighbors if 'data' in self.graph.nodes[n]]

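A quick sketch of the call, with a hypothetical node id; because both edge directions are collected, one call surfaces everything a node touches:

for neighbor in kg.get_neighbors('src/app.py'):   # 'kg' and the id are hypothetical
    print(neighbor.node_type, neighbor.name)
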
    def get_previous_chunk(self, node_id: str) -> Optional[ChunkNode]:
        self.logger.debug(f"Getting previous chunk for node: {node_id}")
        node = self[node_id]
        # Check if node is of type ChunkNode
        if not isinstance(node, ChunkNode):
            raise Exception(f'Cannot get previous chunk on node of type {type(node)}')

        if node.order_in_file == 0:
            self.logger.warning(f'Cannot get previous chunk for first chunk: {node_id}')
            return None

        file_path = node.path
        previous_chunk_id = f'{file_path}_{node.order_in_file - 1}'

        if previous_chunk_id not in self.graph:
            raise Exception(f'Previous chunk {previous_chunk_id} not found in graph')

        previous_chunk = self[previous_chunk_id]
        return previous_chunk

    def get_next_chunk(self, node_id: str) -> Optional[ChunkNode]:
        self.logger.debug(f"Getting next chunk for node: {node_id}")
        node = self[node_id]
        # Check if node is of type ChunkNode
        if not isinstance(node, ChunkNode):
            raise Exception(f'Cannot get next chunk on node of type {type(node)}')

        file_path = node.path
        next_chunk_id = f'{file_path}_{node.order_in_file + 1}'

        if next_chunk_id not in self.graph:
            self.logger.warning(f'Next chunk {next_chunk_id} not found in graph, it might be the last chunk')
            return None
        next_chunk = self[next_chunk_id]
        return next_chunk

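A sketch of forward navigation, relying on the '<path>_<order_in_file>' chunk-id scheme used above (the concrete id is hypothetical):

# Walk a file chunk by chunk; get_next_chunk returns None after the last chunk.
chunk = kg['src/app.py_0']
while chunk is not None:
    print(chunk.order_in_file, len(chunk.content))
    chunk = kg.get_next_chunk(chunk.id)
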
    def get_all_chunks(self) -> List[ChunkNode]:
        self.logger.debug("Getting all chunk nodes.")
        chunk_nodes = []
        for node in self:
            if isinstance(node, ChunkNode):
                chunk_nodes.append(node)
        return chunk_nodes

    def get_all_files(self) -> List[FileNode]:
        """
        Get all FileNodes in the knowledge graph.

        Returns:
            List[FileNode]: A list of FileNodes in the graph.
        """
        self.logger.debug("Getting all file nodes.")
        file_nodes = []
        for node in self.graph.nodes(data=True):
            node_data = node[1]['data']
            # Check for exact FileNode type, not ChunkNode (which inherits from FileNode)
            if isinstance(node_data, FileNode) and node_data.node_type == 'file':
                file_nodes.append(node_data)
        return file_nodes

    def get_chunks_of_file(self, file_node_id: str) -> List[ChunkNode]:
        """
        Get all ChunkNodes associated with a specific FileNode.

        Args:
            file_node_id (str): The ID of the file node to get chunks for.

        Returns:
            List[ChunkNode]: A list of ChunkNodes associated with the file.
        """
        self.logger.debug(f"Getting chunks for file node: {file_node_id}")
        chunk_nodes = []
        for node in self.graph.neighbors(file_node_id):
            # Only include ChunkNodes that are connected by a 'contains' edge
            edge_data = self.graph.get_edge_data(file_node_id, node)
            node_data = self.graph.nodes[node]['data']
            if (
                isinstance(node_data, ChunkNode)
                and node_data.node_type == 'chunk'
                and edge_data is not None
                and edge_data.get('relation') == 'contains'
            ):
                chunk_nodes.append(node_data)
        return chunk_nodes

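A sketch combining the two accessors above to report chunk counts per file:

for file_node in kg.get_all_files():           # 'kg' is a hypothetical instance
    chunks = kg.get_chunks_of_file(file_node.id)
    print(f"{file_node.path}: {len(chunks)} chunks")
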
    def find_path(self, source_id: str, target_id: str, max_depth: int = 5) -> dict:
        """
        Find the shortest path between two nodes in the knowledge graph.

        Args:
            source_id (str): The ID of the source node.
            target_id (str): The ID of the target node.
            max_depth (int): Maximum depth to search for a path. Defaults to 5.

        Returns:
            dict: A dictionary containing path information or error message.
        """
        self.logger.debug(f"Finding path from {source_id} to {target_id} with max_depth={max_depth}")
        g = self.graph

        if source_id not in g:
            return {"error": f"Source node '{source_id}' not found."}
        if target_id not in g:
            return {"error": f"Target node '{target_id}' not found."}

        try:
            path = nx.shortest_path(g, source=source_id, target=target_id)

            if len(path) - 1 > max_depth:
                return {
                    "source_id": source_id,
                    "target_id": target_id,
                    "path": [],
                    "length": len(path) - 1,
                    "text": f"Path exists but exceeds max_depth of {max_depth} (actual length: {len(path) - 1})"
                }

            # Build detailed path information
            path_details = []
            for i, node_id in enumerate(path):
                node = g.nodes[node_id]['data']
                node_info = {
                    "node_id": node_id,
                    "name": getattr(node, 'name', 'Unknown'),
                    "type": getattr(node, 'node_type', 'Unknown'),
                    "step": i
                }

                # Add edge information for all but the last node
                if i < len(path) - 1:
                    next_node_id = path[i + 1]
                    edge_data = g.get_edge_data(node_id, next_node_id)
                    node_info["edge_to_next"] = edge_data.get('relation', 'Unknown') if edge_data else 'Unknown'

                path_details.append(node_info)

            # Format text output
            text = f"Path from '{source_id}' to '{target_id}' (length: {len(path) - 1}):\n\n"
            for i, node_info in enumerate(path_details):
                text += f"{i}. {node_info['name']} ({node_info['type']})\n"
                text += f"   Node ID: {node_info['node_id']}\n"
                if 'edge_to_next' in node_info:
                    text += f"   --[{node_info['edge_to_next']}]--> \n"

            return {
                "source_id": source_id,
                "target_id": target_id,
                "path": path_details,
                "length": len(path) - 1,
                "text": text
            }

        except nx.NetworkXNoPath:
            return {
                "source_id": source_id,
                "target_id": target_id,
                "path": [],
                "length": -1,
                "text": f"No path found between '{source_id}' and '{target_id}'"
            }
        except Exception as e:
            self.logger.error(f"Error finding path: {str(e)}")
            return {"error": f"Error finding path: {str(e)}"}

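A sketch of tracing how one node reaches another; both ids are hypothetical:

result = kg.find_path('src/app.py', 'Calculator', max_depth=5)
if 'error' in result:
    print(result['error'])
else:
    print(result['text'])   # step-by-step listing with edge relations
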
    def get_subgraph(self, node_id: str, depth: int = 2, edge_types: Optional[List[str]] = None) -> dict:
        """
        Extract a subgraph around a node up to a specified depth.

        Args:
            node_id (str): The ID of the central node.
            depth (int): The depth/radius of the subgraph to extract. Defaults to 2.
            edge_types (Optional[List[str]]): Optional list of edge types to include (e.g., ['calls', 'contains']).

        Returns:
            dict: A dictionary containing subgraph information or error message.
        """
        self.logger.debug(f"Getting subgraph for node {node_id} with depth={depth}, edge_types={edge_types}")
        g = self.graph

        if node_id not in g:
            return {"error": f"Node '{node_id}' not found."}

        # Collect nodes within specified depth
        nodes_at_depth = {node_id}
        all_nodes = {node_id}

        for _ in range(depth):
            next_level = set()
            for n in nodes_at_depth:
                # Get all neighbors (both incoming and outgoing)
                for neighbor in g.successors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(n, neighbor)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

                for neighbor in g.predecessors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(neighbor, n)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

            nodes_at_depth = next_level - all_nodes
            all_nodes.update(next_level)

        # Extract subgraph
        subgraph = g.subgraph(all_nodes).copy()

        # Build node details
        nodes = []
        for n in subgraph.nodes():
            node = subgraph.nodes[n]['data']
            nodes.append({
                "node_id": n,
                "name": getattr(node, 'name', 'Unknown'),
                "type": getattr(node, 'node_type', 'Unknown')
            })

        # Build edge details
        edges = []
        for source, target, data in subgraph.edges(data=True):
            edges.append({
                "source": source,
                "target": target,
                "relation": data.get('relation', 'Unknown')
            })

        # Format text output
        text = f"Subgraph around '{node_id}' (depth: {depth}):\n"
        if edge_types:
            text += f"Edge types filter: {', '.join(edge_types)}\n"
        text += f"\nNodes: {len(nodes)}\n"
        text += f"Edges: {len(edges)}\n\n"

        # Group nodes by type
        nodes_by_type = {}
        for node in nodes:
            node_type = node['type']
            if node_type not in nodes_by_type:
                nodes_by_type[node_type] = []
            nodes_by_type[node_type].append(node)

        for node_type, type_nodes in nodes_by_type.items():
            text += f"{node_type} ({len(type_nodes)}):\n"
            for node in type_nodes[:5]:
                text += f"  - {node['name']} ({node['node_id']})\n"
            if len(type_nodes) > 5:
                text += f"  ... and {len(type_nodes) - 5} more\n"
            text += "\n"

        # Show edge statistics
        edge_by_relation = {}
        for edge in edges:
            relation = edge['relation']
            edge_by_relation[relation] = edge_by_relation.get(relation, 0) + 1

        if edge_by_relation:
            text += "Edge types:\n"
            for relation, count in edge_by_relation.items():
                text += f"  - {relation}: {count}\n"

        return {
            "center_node_id": node_id,
            "depth": depth,
            "edge_types_filter": edge_types,
            "node_count": len(nodes),
            "edge_count": len(edges),
            "nodes": nodes,
            "edges": edges,
            "nodes_by_type": nodes_by_type,
            "edge_by_relation": edge_by_relation,
            "text": text
        }
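A sketch pulling the 2-hop neighborhood around a hypothetical node, filtered to the edge types named in the docstring:

sub = kg.get_subgraph('src/app.py', depth=2, edge_types=['contains', 'calls'])
print(sub['node_count'], sub['edge_count'])
print(sub['text'])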
RepoKnowledgeGraphLib/__init__.py
ADDED
@@ -0,0 +1,5 @@
"""
RepoKnowledgeGraphLib - Knowledge Graph Library for Code Repositories

This library provides tools for creating and querying knowledge graphs from code repositories.
"""
RepoKnowledgeGraphLib/utils/__init__.py
ADDED
File without changes
RepoKnowledgeGraphLib/utils/chunk_utils.py
ADDED
@@ -0,0 +1,88 @@
from ..Node import ChunkNode
from typing import List, Dict

def dict_to_chunknode(d: dict) -> ChunkNode:
    """
    Converts a dictionary to a ChunkNode instance.
    """
    return ChunkNode(**d)

def extract_filename_from_chunk(chunk: ChunkNode) -> str:
    """
    Extracts the file name from a chunk.

    Args:
        chunk (ChunkNode | dict): The chunk from which to extract the file name.

    Returns:
        str: The extracted file name.
    """
    if isinstance(chunk, dict):
        chunk = dict_to_chunknode(chunk)
    return '_'.join(chunk.id.split('_')[:-1])


def order_chunks_by_order_in_file(chunks: List[ChunkNode]) -> list:
    """
    Orders a list of chunks by their order in the file.

    Args:
        chunks (list): The list of chunks to order.

    Returns:
        list: The ordered list of chunks.
    """
    # Convert dicts to ChunkNode if needed
    chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in chunks]
    return sorted(chunks, key=lambda x: int(x.order_in_file))

def organize_chunks_by_file_name(chunks: List[ChunkNode]) -> Dict[str, List[ChunkNode]]:
    """
    Organizes a list of chunks by their file names.

    Args:
        chunks (list): The list of chunks to organize.

    Returns:
        dict: A dictionary mapping file names to lists of chunks.
    """
    # Convert dicts to ChunkNode if needed
    chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in chunks]
    organized_chunks = {}
    for chunk in chunks:
        file_name = extract_filename_from_chunk(chunk)
        if file_name not in organized_chunks:
            organized_chunks[file_name] = []
        organized_chunks[file_name].append(chunk)
    for file_name in organized_chunks:
        organized_chunks[file_name] = order_chunks_by_order_in_file(organized_chunks[file_name])
    return organized_chunks

def join_organized_chunks(organized_chunks: Dict[str, List[ChunkNode]]) -> str:
    """
    Joins organized chunks into a single string.

    Args:
        organized_chunks (dict): The dictionary of organized chunks.

    Returns:
        str: The joined string of organized chunks.
    """
    joined_chunks_list = []
    separator = "=" * 48
    for filename in organized_chunks:
        joined_chunks_list.append(separator)
        joined_chunks_list.append(f"File: {filename}")
        joined_chunks_list.append(separator)
        # Convert dicts to ChunkNode if needed
        chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in organized_chunks[filename]]
        if len(chunks) == 0:
            continue
        if int(chunks[0].order_in_file) > 0:
            joined_chunks_list.append("\n[...]")
        for i, chunk in enumerate(chunks):
            joined_chunks_list.append(chunk.content)
            if i < len(chunks) - 1:
                if int(chunks[i+1].order_in_file) - int(chunk.order_in_file) > 1:
                    joined_chunks_list.append("\n[...]")
    return "\n".join(joined_chunks_list)
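A sketch of reassembling retrieved chunks into readable per-file excerpts, with "[...]" marking gaps between non-adjacent chunks; the source of the chunks is hypothetical:

from RepoKnowledgeGraphLib.utils.chunk_utils import (
    organize_chunks_by_file_name, join_organized_chunks,
)

retrieved = kg.get_all_chunks()          # or any subset, e.g. search results
by_file = organize_chunks_by_file_name(retrieved)
print(join_organized_chunks(by_file))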
RepoKnowledgeGraphLib/utils/data_utils.py
ADDED
@@ -0,0 +1,18 @@

def flatten_list(my_list: list) -> list:
    """
    Args:
        my_list: list composed of lists (of lists of lists...)

    Returns: flattened list

    """
    flattened_list = []
    for item in my_list:
        if isinstance(item, list) and len(item) > 0:
            flattened_list += flatten_list(item)
        elif not isinstance(item, list):
            flattened_list.append(item)

    return flattened_list
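A one-line sketch of the behavior: arbitrarily nested lists collapse to one flat list, and empty sublists disappear.

from RepoKnowledgeGraphLib.utils.data_utils import flatten_list

assert flatten_list([1, [2, [3, 4]], [], [5]]) == [1, 2, 3, 4, 5]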
RepoKnowledgeGraphLib/utils/logger_utils.py
ADDED
@@ -0,0 +1,74 @@
import logging
import os
import sys
import atexit

# Global registry to track initialized loggers
_initialized_loggers = set()

# Get log level from environment variable (default to INFO for visibility in docker logs)
DEFAULT_LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'false').lower() == 'true'

def setup_logger(logger_name: str, log_file: str = '',
                 level: int = None) -> None:
    """
    :param logger_name: name to give to logger
    :param log_file: file to save log to
    :param level: which base level of importance to set logger to (defaults to LOG_LEVEL env var)
    :return: *None*
    """
    # Check if logger has already been set up
    if logger_name in _initialized_loggers:
        return

    log = logging.getLogger(logger_name)

    # Determine log level from parameter, env var, or default
    if level is None:
        level = getattr(logging, DEFAULT_LOG_LEVEL, logging.INFO)

    formatter = logging.Formatter(
        fmt="%(name)s - %(levelname)s: %(asctime)-15s %(message)s")

    # Always add stream handler for stdout (docker logs visibility)
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)

    log.setLevel(level)
    if not log.hasHandlers():
        log.addHandler(stream_handler)

    # Optionally add file handler if LOG_TO_FILE is enabled
    if LOG_TO_FILE:
        os.makedirs('logs', exist_ok=True)
        if log_file == '':
            log_file = f"{logger_name}.log"
        log_file_path = os.path.join('logs', log_file)
        file_handler = logging.FileHandler(log_file_path, mode='w')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)
        log.addHandler(file_handler)

    # Prevent log propagation to avoid duplicate logs
    log.propagate = False

    # Mark this logger as initialized
    _initialized_loggers.add(logger_name)

    # Register cleanup function to close handlers on exit
    atexit.register(_cleanup_logger, logger_name)

def _cleanup_logger(logger_name: str) -> None:
    """
    Clean up logger handlers on program exit.

    :param logger_name: name of the logger to clean up
    """
    log = logging.getLogger(logger_name)
    handlers = log.handlers[:]
    for handler in handlers:
        handler.close()
        log.removeHandler(handler)
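A sketch of the intended setup pattern: configure once per process, letting level and file output come from the LOG_LEVEL / LOG_TO_FILE environment variables; the logger name here is hypothetical.

import logging
from RepoKnowledgeGraphLib.utils.logger_utils import setup_logger

setup_logger('RepoKnowledgeGraph')             # idempotent per logger name
log = logging.getLogger('RepoKnowledgeGraph')
log.info('graph build started')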
RepoKnowledgeGraphLib/utils/parsing_utils.py
ADDED
@@ -0,0 +1,65 @@
import os
import re

def read_directory_files_recursively(directory_path: str, skip_dirs: list, skip_pattern: str = None) -> dict:
    """
    Recursively reads all files in a directory and its subdirectories.
    Skips files and directories that match the given regex pattern or are in skip_dirs.

    Args:
        directory_path (str): The path to start reading files from.
        skip_dirs (list): List of directory names to skip.
        skip_pattern (str, optional): Regex pattern to skip files/directories.

    Returns:
        dict: A dictionary where keys are relative file paths and values are file contents.
    """
    file_contents = {}
    compiled_pattern = re.compile(skip_pattern) if skip_pattern else None

    for root, dirs, files in os.walk(directory_path):
        # Skip directories listed in skip_dirs
        dirs[:] = [d for d in dirs if d not in skip_dirs and not (compiled_pattern and compiled_pattern.search(os.path.join(root, d)))]

        for file in files:
            full_path = os.path.join(root, file)
            relative_path = os.path.relpath(full_path, directory_path)

            # Skip matching files
            if compiled_pattern and compiled_pattern.search(relative_path):
                continue

            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    file_contents[relative_path] = f.read()
            except (UnicodeDecodeError, OSError) as e:
                print(f'Failed to read {relative_path}: {e}')
                continue

    return file_contents


def get_language_from_filename(file_name: str) -> str:
    file_extension = file_name.split('.')[-1]
    extension_mapping = {
        'c': 'c',
        'h': 'c',
        'cpp': 'c++',
        'cc': 'c++',
        'cxx': 'c++',
        'hpp': 'c++',
        'hh': 'c++',
        'hxx': 'c++',
        'go': 'go',
        'java': 'java',
        'py': 'python',
        'pyc': 'python',
        'pyw': 'python',
        'js': 'javascript',
        'mjs': 'javascript',
        'cjs': 'javascript',
    }
    # Falls back to returning the raw extension if no language mapping exists
    return extension_mapping.get(file_extension, file_extension)
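A sketch of reading a source tree while skipping vendored and cached directories plus anything matching a regex; the path and patterns are hypothetical:

from RepoKnowledgeGraphLib.utils.parsing_utils import read_directory_files_recursively

files = read_directory_files_recursively(
    'path/to/repo',
    skip_dirs=['.git', 'node_modules', '__pycache__'],
    skip_pattern=r'\.(png|jpg|lock)$',
)
print(len(files), 'files read')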
RepoKnowledgeGraphLib/utils/path_utils.py
ADDED
@@ -0,0 +1,308 @@
import os
import tempfile
import shutil
import zipfile
import tarfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def _extract_zip(path: Path) -> str:
    temp_dir = tempfile.mkdtemp()
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(temp_dir)
    return temp_dir


def _extract_tgz(path: Path) -> str:
    temp_dir = tempfile.mkdtemp()
    with tarfile.open(path, 'r:gz') as tar_ref:
        tar_ref.extractall(temp_dir)
    return temp_dir


def prepare_input_path(path: str) -> str:
    """Handles different input types: directories, files, zip or tgz archives."""
    path_obj = Path(path)
    if path_obj.is_dir():
        return str(path_obj)

    if path_obj.suffix == '.zip':
        return _extract_zip(path_obj)
    elif path_obj.suffix == '.tgz' or path_obj.name.endswith('.tar.gz'):
        # Path.suffix only yields the final extension ('.gz'), so '.tar.gz'
        # archives are detected via the full file name.
        return _extract_tgz(path_obj)
    elif path_obj.is_file():
        # Copy single file to a temporary directory
        temp_dir = tempfile.mkdtemp()
        shutil.copy(path_obj, temp_dir)
        return temp_dir
    else:
        raise ValueError(f"Unsupported path type or extension: {path}")


def file_path_to_module_path(file_path: str) -> str:
    """
    Convert a file path to a module path by replacing path separators with dots
    and removing the file extension.

    Examples:
        path/to/repo/python_script.py -> path.to.repo.python_script
        src/utils/helper.py -> src.utils.helper
        module.py -> module

    Args:
        file_path: File path string

    Returns:
        Module path with dots instead of slashes
    """
    # Normalize path separators
    normalized = file_path.replace('\\', '/').replace(os.sep, '/')

    # Remove file extension
    without_ext = os.path.splitext(normalized)[0]

    # Replace / with .
    module_path = without_ext.replace('/', '.')

    return module_path


def generate_entity_aliases(entity_name: str, file_path: str) -> list:
    """
    Generate all possible aliases for an entity based on its name and file path.

    For example, if a file 'path/to/repo/python_script.py' defines 'Class_1',
    the aliases would be:
    - Class_1 (simple name)
    - path.to.repo.python_script.Class_1 (fully qualified from file path)

    For C++ namespaced entities like 'math::Calculator':
    - math::Calculator (fully qualified name)
    - Calculator (unqualified name, for use with 'using namespace')
    - math.calculator.math::Calculator (module-based fully qualified)

    For temporary paths like '.tmp.tmptqky4yk4..pyinstaller.run_astropy_tests.pos':
    - pos (simple name)
    - .run_astropy_tests.pos (progressive path removal)
    - pyinstaller.run_astropy_tests.pos (further removal)
    - .tmp.tmptqky4yk4..pyinstaller.run_astropy_tests.pos (full path)

    Args:
        entity_name: The name of the entity (e.g., 'Class_1', 'my_function', 'math::Calculator')
        file_path: The file path where the entity is defined

    Returns:
        List of alias strings
    """
    aliases = []

    # Always include the simple entity name
    aliases.append(entity_name)

    # For C++/C-style namespaced entities (using ::), add the unqualified name
    if '::' in entity_name:
        # Extract the unqualified name (last part after ::)
        unqualified_name = entity_name.split('::')[-1]
        if unqualified_name != entity_name:
            aliases.append(unqualified_name)

    # Generate module-based alias
    module_path = file_path_to_module_path(file_path)

    # If entity_name already contains scope separators (., ::),
    # it might be a nested entity (e.g., 'MyClass.my_method')
    # In this case, add the module path before the entire qualified name
    fully_qualified = f"{module_path}.{entity_name}"

    # Generate progressive path aliases by removing temporary/noise components
    # Split the module path into components
    components = module_path.split('.')

    # Filter out components that look like temporary directories or UUIDs
    def is_temp_component(component: str) -> bool:
        """Check if a path component looks like a temporary directory."""
        if not component:
            return True
        # Check for common temp directory patterns
        if component.startswith('tmp') and len(component) > 3:
            return True
        if component.startswith('.tmp'):
            return True
        # Check for UUID-like patterns (long alphanumeric strings)
        if len(component) > 8 and component.replace('_', '').replace('-', '').isalnum():
            # If it's mostly lowercase and has mix of letters and numbers, likely a temp ID
            if sum(c.islower() for c in component) > len(component) / 2:
                if sum(c.isdigit() for c in component) > 2:
                    return True
        return False

    # Generate aliases by progressively including more path components
    # Start from the rightmost meaningful components and work backwards
    clean_components = []
    for component in components:
        if not is_temp_component(component):
            clean_components.append(component)

    # Generate aliases with increasing path depth from meaningful components
    if clean_components:
        for i in range(1, len(clean_components) + 1):
            # Take the last i components
            partial_path = '.'.join(clean_components[-i:])
            partial_alias = f".{partial_path}.{entity_name}"
            if partial_alias != entity_name and partial_alias not in aliases:
                aliases.append(partial_alias)

            # Also add without leading dot for the full clean path
            if i == len(clean_components):
                no_dot_alias = f"{partial_path}.{entity_name}"
                if no_dot_alias != entity_name and no_dot_alias not in aliases:
                    aliases.append(no_dot_alias)

    # Always add the fully qualified path at the end (even if it contains temp components)
    if fully_qualified != entity_name and fully_qualified not in aliases:
        aliases.append(fully_qualified)

    return aliases


def normalize_include_path(include_path: str) -> str:
    """
    Normalize an include path from #include directive to a module-like path.

    Examples:
        <vector> -> vector
        <iostream> -> iostream
        "myheader.h" -> myheader
        "utils/helper.h" -> utils.helper
        <boost/algorithm/string.hpp> -> boost.algorithm.string

    Args:
        include_path: The include path from #include directive

    Returns:
        Normalized module-like path
    """
    # Remove angle brackets and quotes
    path = include_path.strip('<>"')

    # Convert to module path
    module_path = file_path_to_module_path(path)

    return module_path


def build_entity_alias_map(entities: Dict[str, Dict]) -> Dict[str, str]:
    """
    Build a mapping from all entity aliases to their canonical entity names.
    This allows quick lookup when matching called entities to their definitions.

    Args:
        entities: Dictionary of entity info keyed by canonical entity name

    Returns:
        Dictionary mapping alias -> canonical entity name
    """
    alias_map = {}

    for entity_name, info in entities.items():
        # Map the canonical name to itself
        alias_map[entity_name] = entity_name

        # Map all aliases to the canonical name
        aliases = info.get('aliases', [])
        for alias in aliases:
            if alias and alias not in alias_map:
                alias_map[alias] = entity_name

    return alias_map


def resolve_entity_call(called_name: str, alias_map: Dict[str, str],
                        imports: List[str] = None) -> Optional[str]:
    """
    Resolve a called entity name to its canonical definition using aliases.

    This handles cases like:
    - Direct call: 'MyClass' -> 'MyClass'
    - Qualified call: 'module.MyClass' -> 'MyClass' (if alias exists)
    - Imported call: 'helper' -> 'utils.helper' (if imported)
    - Simple name to qualified: 'Calculator' -> 'utils::Calculator'

    Args:
        called_name: The name of the called entity
        alias_map: Mapping from aliases to canonical entity names
        imports: List of import paths (optional, for context)

    Returns:
        Canonical entity name if found, None otherwise
    """
    # Don't try to resolve empty strings
    if not called_name or not called_name.strip():
        return None

    # Direct match
    if called_name in alias_map:
        return alias_map[called_name]

    # Try partial matches if imports are provided
    if imports:
        for import_path in imports:
            # Try combining import path with called name
            qualified = f"{import_path}.{called_name}"
            if qualified in alias_map:
                return alias_map[qualified]

            # Try with :: separator (C++/Rust style)
            qualified_cpp = f"{import_path}::{called_name}"
            if qualified_cpp in alias_map:
                return alias_map[qualified_cpp]

    # Try fuzzy matching - look for canonical names that end with the called name
    # This helps match 'Calculator' to 'utils::Calculator' or 'MyClass' to 'module.MyClass'
    simple_name = extract_simple_name(called_name)
    candidates = []

    for alias, canonical in alias_map.items():
        alias_simple = extract_simple_name(alias)
        # If the simple names match, this could be a match
        if alias_simple == simple_name:
            candidates.append(canonical)

    # If we found exactly one candidate, return it
    if len(candidates) == 1:
        return candidates[0]

    # If we have multiple candidates, prefer the shortest qualified name
    # (most likely to be the direct definition rather than an alias)
    if len(candidates) > 1:
        return min(candidates, key=lambda x: len(x))

    return None


def extract_simple_name(qualified_name: str) -> str:
    """
    Extract the simple name from a qualified name.

    Examples:
        'namespace::MyClass' -> 'MyClass'
        'module.MyClass' -> 'MyClass'
        'MyClass' -> 'MyClass'

    Args:
        qualified_name: Fully or partially qualified name

    Returns:
        Simple name without namespace/module prefix
    """
    # Handle C++ style namespace separator
    if '::' in qualified_name:
        return qualified_name.split('::')[-1]

    # Handle Python/JS style module separator
    if '.' in qualified_name:
        return qualified_name.split('.')[-1]

    return qualified_name
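A sketch of how the alias pipeline fits together: aliases generated at definition time feed the alias map, which then resolves a qualified call site back to its canonical entity. The entity name, path, and toy registry are hypothetical.

from RepoKnowledgeGraphLib.utils.path_utils import (
    generate_entity_aliases, build_entity_alias_map, resolve_entity_call,
)

aliases = generate_entity_aliases('Calculator', 'src/utils/calculator.py')
# e.g. ['Calculator', '.calculator.Calculator', '.utils.calculator.Calculator', ...]

entities = {'Calculator': {'aliases': aliases}}          # toy registry
alias_map = build_entity_alias_map(entities)
print(resolve_entity_call('utils.calculator.Calculator', alias_map))  # -> 'Calculator'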