lailaelkoussy committed on
Commit
a100cc5
·
verified ·
1 Parent(s): b1ddffc

upload repoknowledgegraphlib

RepoKnowledgeGraphLib/CodeIndex.py ADDED
@@ -0,0 +1,571 @@
import logging
from tqdm import tqdm
import uuid
from typing import Literal
from abc import ABC, abstractmethod
import os
import numpy as np
import weaviate
from weaviate.classes.config import Configure, Property, DataType
from weaviate.classes.query import MetadataQuery

# LanceDB is optional: the import must live inside the try block, otherwise
# LANCEDB_AVAILABLE is always True and a missing package raises at import time.
try:
    import lancedb
    LANCEDB_AVAILABLE = True
except ImportError:
    LANCEDB_AVAILABLE = False

from .utils.logger_utils import setup_logger

LOGGER_NAME = 'CODE_INDEX_LOGGER'
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
MODEL_ID = os.getenv("MODEL_ID")
MAX_TOKENS = int(os.getenv('MAX_TOKENS', 2048))
TEMPERATURE = float(os.getenv('TEMPERATURE', 0.2))
TOP_P = float(os.getenv('TOP_P', 0.95))
FREQUENCY_PENALTY = 0
PRESENCE_PENALTY = 0
STOP = None
EMBEDDING_MODEL_URL = os.getenv('EMBEDDING_MODEL_URL')
EMBEDDING_MODEL_API_KEY = os.getenv('EMBEDDING_MODEL_API_KEY', "no_need")
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv('EMBEDDING_NUMBER_DIMENSIONS', 1024))

WEAVIATE_HOST = os.getenv('WEAVIATE_HOST', "localhost")
WEAVIATE_PORT = int(os.getenv('WEAVIATE_PORT', 8080))
WEAVIATE_GRPC_PORT = int(os.getenv('WEAVIATE_GRPC_PORT', 50051))
ALPHA_SEARCH_VALUE = float(os.getenv('ALPHA_SEARCH_VALUE', 0.8))
LANCEDB_PATH = os.getenv('LANCEDB_PATH', './local_code_index_db')


class BaseCodeIndex(ABC):
    """Abstract base class for code indexing implementations"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 64, use_embed: bool = True):
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)
        self.model_service = model_service
        self.index_type = index_type
        # Use a larger batch size by default for better throughput
        self.embedding_batch_size = int(os.getenv('EMBEDDING_BATCH_SIZE', embedding_batch_size))
        self.use_embed = use_embed
        self.logger.info(f"CodeIndex initialized with batch_size={self.embedding_batch_size}, index_type={index_type}")

    @abstractmethod
    def query(self, query: str, n_results: int = 10) -> dict:
        """Query the index and return results"""
        pass

    @abstractmethod
    def __del__(self):
        """Clean up resources"""
        pass


class WeaviateCodeIndex(BaseCodeIndex):
    """Weaviate-based code index implementation"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 20, use_embed: bool = True,
                 host: str = None, port: int = None, grpc_port: int = None):
        super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed)

        # Use provided parameters or fall back to environment variables
        weaviate_host = host or WEAVIATE_HOST
        weaviate_port = port or WEAVIATE_PORT
        weaviate_grpc_port = grpc_port or WEAVIATE_GRPC_PORT

        # Connect to Weaviate
        self.weaviate_client = weaviate.connect_to_local(
            host=weaviate_host,
            port=weaviate_port,
            grpc_port=weaviate_grpc_port
        )

        # Create a unique collection name
        self.collection_name = f"CodeChunks_{str(uuid.uuid4()).replace('-', '_')}"

        # Create the collection schema using the v4 API. Vectors are
        # self-provided via the modern vector_config option.
        self.collection = self.weaviate_client.collections.create(
            name=self.collection_name,
            properties=[
                Property(name="node_id", data_type=DataType.TEXT),
                Property(name="name", data_type=DataType.TEXT),
                Property(name="content", data_type=DataType.TEXT),
                Property(name="description", data_type=DataType.TEXT),
                Property(name="path", data_type=DataType.TEXT),
                Property(name="language", data_type=DataType.TEXT),
                Property(name="node_type", data_type=DataType.TEXT),
                Property(name="order_in_file", data_type=DataType.INT),
                Property(name="declared_entities", data_type=DataType.TEXT),
                Property(name="called_entities", data_type=DataType.TEXT),
            ],
            vector_config=Configure.Vectors.self_provided(),
        )

        chunk_nodes = [node for node in nodes if node.node_type == 'chunk']
        self.logger.info(f"Weaviate indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}")

        # Pre-generate embeddings in batches for better performance
        if self.index_type != 'keyword-only':
            # Identify nodes that need embeddings (also handle ndarray embeddings,
            # matching the LanceDB path below)
            nodes_needing_embeddings = [
                node for node in chunk_nodes
                if node.embedding is None
                or (isinstance(node.embedding, (list, np.ndarray)) and len(node.embedding) == 0)
                or not use_embed
            ]

            if nodes_needing_embeddings:
                total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size
                self.logger.info(f'Batch embedding {len(nodes_needing_embeddings)} nodes in {total_batches} batches')

                # Process in batches
                for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size),
                              desc="Batch embedding nodes"):
                    batch_nodes = nodes_needing_embeddings[i:i + self.embedding_batch_size]
                    texts_to_embed = [node.get_field_to_embed() for node in batch_nodes]

                    # Batch embed all texts
                    embeddings = self.model_service.embed_chunk_code_batch(texts_to_embed)

                    # Assign embeddings back to nodes
                    for node, embedding in zip(batch_nodes, embeddings):
                        node.embedding = embedding

                    # Log progress every 10 batches
                    batch_num = i // self.embedding_batch_size + 1
                    if batch_num % 10 == 0:
                        self.logger.info(f"Completed batch {batch_num}/{total_batches}")

                self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} nodes")
            else:
                self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} nodes")

        # Batch insert data into Weaviate
        with self.collection.batch.dynamic() as batch:
            for node in tqdm(chunk_nodes, desc="Indexing nodes"):
                self.logger.debug(f'Indexing node: {node.id}')

                # Use the pre-computed embedding
                embedding = None
                if self.index_type != 'keyword-only':
                    embedding = node.embedding

                # Prepare properties
                properties = {
                    "node_id": node.id,
                    "name": node.name,
                    "content": node.content,
                    "description": node.description or "",
                    "path": node.path,
                    "language": node.language,
                    "node_type": node.node_type,
                    "order_in_file": node.order_in_file,
                    "declared_entities": str(node.declared_entities),
                    "called_entities": str(node.called_entities),
                }

                # Add the object with or without a vector based on index_type
                if self.index_type == 'keyword-only':
                    # No vector needed for keyword-only search
                    batch.add_object(properties=properties)
                else:
                    # Add with a vector for embedding-only and hybrid modes
                    batch.add_object(
                        properties=properties,
                        vector=embedding
                    )

    def query(self, query: str, n_results: int = 10) -> dict:
        """
        Perform search based on index_type:
        - 'embedding-only': pure vector search
        - 'keyword-only': pure keyword search (BM25)
        - 'hybrid': hybrid search combining both (alpha controls weighting)

        Weaviate's hybrid search uses:
        - alpha=0: pure keyword search (BM25)
        - alpha=1: pure vector search
        - alpha=0.5-0.8: balanced hybrid search (recommended)
        """
        try:
            # Execute the search based on index_type
            if self.index_type == 'keyword-only':
                # Pure BM25 keyword search
                response = self.collection.query.bm25(
                    query=query,
                    limit=n_results,
                    return_metadata=MetadataQuery(score=True)
                )
            elif self.index_type == 'embedding-only':
                # Pure vector search
                embedding = self.model_service.embed_query(query)
                response = self.collection.query.near_vector(
                    near_vector=embedding,
                    limit=n_results,
                    return_metadata=MetadataQuery(distance=True)
                )
            else:  # 'hybrid'
                # Hybrid search combining keyword and vector
                embedding = self.model_service.embed_query(query)
                response = self.collection.query.hybrid(
                    query=query,
                    vector=embedding,
                    limit=n_results,
                    alpha=ALPHA_SEARCH_VALUE,
                    return_metadata=MetadataQuery(distance=True, score=True)
                )

            # Convert to a ChromaDB-like format for compatibility
            results = {
                'ids': [[]],
                'distances': [[]],
                'metadatas': [[]],
                'documents': [[]]
            }

            for obj in response.objects:
                results['ids'][0].append(obj.properties['node_id'])
                results['distances'][0].append(obj.metadata.distance if obj.metadata.distance is not None else 0.0)
                results['metadatas'][0].append({
                    'id': obj.properties['node_id'],
                    'name': obj.properties['name'],
                    'content': obj.properties['content'],
                    'description': obj.properties['description'],
                    'path': obj.properties['path'],
                    'language': obj.properties['language'],
                    'node_type': obj.properties['node_type'],
                    'order_in_file': str(obj.properties['order_in_file']),
                    'declared_entities': obj.properties['declared_entities'],
                    'called_entities': obj.properties['called_entities'],
                })
                results['documents'][0].append(obj.properties['content'])

            return results

        except Exception as e:
            self.logger.error(f'Failed to query: {e}', exc_info=True)
            raise e

    def __del__(self):
        """Clean up the Weaviate connection"""
        if hasattr(self, 'weaviate_client'):
            try:
                self.weaviate_client.close()
            except Exception:
                pass


class LanceDBCodeIndex(BaseCodeIndex):
    """LanceDB-based code index implementation"""

    def __init__(self, nodes: list, model_service,
                 index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
                 embedding_batch_size: int = 20, use_embed: bool = True, db_path: str = None):
        super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed)

        if not LANCEDB_AVAILABLE:
            raise ImportError("LanceDB is not available. Please install it with: pip install lancedb")

        # Embedded DB
        self.db_path = db_path or LANCEDB_PATH
        self.db = lancedb.connect(self.db_path)
        self.table_name = f"code_chunks_{uuid.uuid4().hex}"
        self.table = None

        chunk_nodes = [node for node in nodes if node.node_type == "chunk"]
        self.logger.info(f"LanceDB indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}")

        # -----------------------------------------------------------
        # Create embeddings IF using vector search
        # -----------------------------------------------------------
        if self.index_type != "keyword-only":
            # Find nodes that need embeddings.
            # use_embed=True means we should USE existing embeddings if available;
            # use_embed=False means we should regenerate all embeddings.
            nodes_needing_embeddings = []
            for node in chunk_nodes:
                needs_embedding = False
                if not use_embed:
                    # Regenerate all embeddings
                    needs_embedding = True
                elif node.embedding is None:
                    needs_embedding = True
                elif isinstance(node.embedding, (list, np.ndarray)) and len(node.embedding) == 0:
                    needs_embedding = True

                if needs_embedding:
                    nodes_needing_embeddings.append(node)

            if nodes_needing_embeddings:
                total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size
                self.logger.info(f"Embedding {len(nodes_needing_embeddings)} chunks in {total_batches} batches (batch_size={self.embedding_batch_size})...")

                for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size),
                              desc="Batch embedding nodes"):
                    batch = nodes_needing_embeddings[i:i + self.embedding_batch_size]
                    texts = [n.get_field_to_embed() for n in batch]
                    batch_embeds = self.model_service.embed_chunk_code_batch(texts)

                    for n, emb in zip(batch, batch_embeds):
                        n.embedding = np.array(emb, dtype=np.float32)

                    # Log progress every 10 batches
                    batch_num = i // self.embedding_batch_size + 1
                    if batch_num % 10 == 0:
                        self.logger.info(f"Completed batch {batch_num}/{total_batches}")

                self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} chunks")
            else:
                self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} chunks")

        # -----------------------------------------------------------
        # Prepare rows (only include the vector column when allowed)
        # -----------------------------------------------------------
        rows = []
        for node in chunk_nodes:
            row = {
                "node_id": node.id,
                "name": node.name,
                "content": node.content,
                "description": node.description or "",
                "path": node.path,
                "language": node.language,
                "node_type": node.node_type,
                "order_in_file": node.order_in_file,
                "declared_entities": str(node.declared_entities),
                "called_entities": str(node.called_entities),
            }

            # Add embeddings only for hybrid/embedding-only
            if self.index_type != "keyword-only":
                row["vector"] = node.embedding

            rows.append(row)

        # Create the table
        self.table = self.db.create_table(self.table_name, data=rows)
        self.logger.info(f"Created LanceDB table: {self.table_name}")

        # Create a full-text search index for keyword and hybrid search.
        # LanceDB requires an INVERTED index for full-text search.
        self._create_fts_indexes()

    def _create_fts_indexes(self):
        """
        Create full-text search indexes on text columns.

        LanceDB 0.25.x uses create_fts_index() with use_tantivy=True to support
        multiple columns. Requires the tantivy package: pip install tantivy
        """
        fts_columns = ["content", "name", "description"]

        try:
            # use_tantivy=True is required to pass multiple field names as a list
            self.table.create_fts_index(fts_columns, replace=True, use_tantivy=True)
            self.logger.info(f"Created FTS index (Tantivy) on columns: {fts_columns}")
        except Exception as e:
            self.logger.warning(f"Failed to create FTS index: {e}")
            self.logger.warning(
                "Full-text search will fall back to scanning. "
                "Ensure tantivy is installed: pip install tantivy"
            )

    def query(self, query: str, n_results: int = 10) -> dict:
        """
        Perform search based on index_type:
        - 'embedding-only': pure vector search
        - 'keyword-only': full-text search using LanceDB's native FTS
        - 'hybrid': combines vector similarity and full-text search with reranking
        """
        try:
            # ---------------------- KEYWORD ONLY ----------------------
            if self.index_type == "keyword-only":
                # Use LanceDB full-text search (requires an FTS index on the table)
                try:
                    # Try full-text search first
                    df = self.table.search(query, query_type="fts").limit(n_results).to_pandas()
                except Exception as fts_error:
                    self.logger.warning(f"FTS search failed, falling back to scan: {fts_error}")
                    # Fallback: scan all rows and filter in Python
                    all_df = self.table.to_pandas()
                    query_lower = query.lower()
                    # Split the query into words for more flexible matching
                    query_words = query_lower.split()

                    def matches_query(row):
                        text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower()
                        # Match if any query word is found
                        return any(word in text for word in query_words)

                    mask = all_df.apply(matches_query, axis=1)
                    df = all_df[mask].head(n_results)
                    # Add a dummy distance column
                    df = df.copy()
                    df['_distance'] = 0.0

            # ---------------------- VECTOR ONLY -----------------------
            elif self.index_type == "embedding-only":
                emb = np.array(self.model_service.embed_query(query), dtype=np.float32)
                df = self.table.search(
                    emb,
                    vector_column_name="vector"
                ).limit(n_results).to_pandas()

            # ---------------------- HYBRID ----------------------------
            else:
                # For hybrid search, run a vector search and optionally boost results
                # that also match keywords. This is more flexible than requiring both.
                emb = np.array(self.model_service.embed_query(query), dtype=np.float32)

                # Fetch more results from the vector search to allow for reranking
                vector_limit = min(n_results * 3, 100)  # 3x results for reranking
                df = self.table.search(
                    emb,
                    vector_column_name="vector"
                ).limit(vector_limit).to_pandas()

                if not df.empty:
                    # Rerank results based on keyword matches
                    query_lower = query.lower()
                    query_words = query_lower.split()

                    def compute_keyword_score(row):
                        """Compute a keyword match score (higher is better)"""
                        text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower()
                        score = 0
                        # An exact phrase match gets the highest score
                        if query_lower in text:
                            score += 10
                        # Word matches
                        for word in query_words:
                            if word in text:
                                score += 1
                                # Bonus for a word in the name (more relevant)
                                if word in str(row.get('name', '')).lower():
                                    score += 2
                        return score

                    # Add keyword scores
                    df = df.copy()
                    df['_keyword_score'] = df.apply(compute_keyword_score, axis=1)

                    # Normalize distance to a similarity score (lower distance = higher similarity)
                    max_dist = df['_distance'].max() if df['_distance'].max() > 0 else 1.0
                    df['_vector_score'] = 1.0 - (df['_distance'] / max_dist)

                    # Combined score: weighted sum of vector similarity and keyword score.
                    # Alpha controls the balance (higher alpha = more weight on vector search).
                    alpha = 0.7  # 70% vector, 30% keyword
                    max_keyword = df['_keyword_score'].max() if df['_keyword_score'].max() > 0 else 1.0
                    df['_combined_score'] = (
                        alpha * df['_vector_score'] +
                        (1 - alpha) * (df['_keyword_score'] / max_keyword)
                    )

                    # Sort by combined score (descending) and take the top n_results
                    df = df.sort_values('_combined_score', ascending=False).head(n_results)

            # Build the result format (ChromaDB-like, for compatibility)
            results = {
                "ids": [[]],
                "distances": [[]],
                "metadatas": [[]],
                "documents": [[]],
            }

            for _, row in df.iterrows():
                results["ids"][0].append(row["node_id"])
                results["documents"][0].append(row["content"])
                results["distances"][0].append(float(row.get("_distance", 0)))

                results["metadatas"][0].append({
                    "id": row["node_id"],
                    "name": row["name"],
                    "content": row["content"],
                    "description": row["description"],
                    "path": row["path"],
                    "language": row["language"],
                    "node_type": row["node_type"],
                    "order_in_file": str(row["order_in_file"]),
                    "declared_entities": row["declared_entities"],
                    "called_entities": row["called_entities"],
                })

            return results

        except Exception as e:
            self.logger.error(f"Query failed: {e}", exc_info=True)
            raise

    def __del__(self):
        """Clean up resources (the embedded DB needs no explicit teardown)"""
        pass


# Factory function to create the appropriate CodeIndex
def CodeIndex(
    nodes: list,
    model_service,
    index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid',
    embedding_batch_size: int = 20,
    use_embed: bool = True,
    backend: Literal['weaviate', 'lancedb'] = 'weaviate',
    db_path: str = None,
    host: str = None,
    port: int = None,
    grpc_port: int = None
) -> BaseCodeIndex:
    """
    Factory function to create a CodeIndex instance.

    Args:
        nodes: List of nodes to index
        model_service: Service for embedding generation
        index_type: Type of search ('embedding-only', 'keyword-only', or 'hybrid')
        embedding_batch_size: Batch size for embedding generation
        use_embed: Whether to use pre-computed embeddings
        backend: Which backend to use ('weaviate' or 'lancedb')
        db_path: Path for LanceDB (only used with the 'lancedb' backend)
        host: Weaviate host (only used with the 'weaviate' backend)
        port: Weaviate port (only used with the 'weaviate' backend)
        grpc_port: Weaviate gRPC port (only used with the 'weaviate' backend)

    Returns:
        BaseCodeIndex: Either a WeaviateCodeIndex or LanceDBCodeIndex instance
    """
    if backend == 'lancedb':
        return LanceDBCodeIndex(
            nodes=nodes,
            model_service=model_service,
            index_type=index_type,
            embedding_batch_size=embedding_batch_size,
            use_embed=use_embed,
            db_path=db_path
        )
    # Weaviate is the default backend; the explicit 'weaviate' branch and the
    # fallback branch were identical, so they are collapsed into one.
    return WeaviateCodeIndex(
        nodes=nodes,
        model_service=model_service,
        index_type=index_type,
        embedding_batch_size=embedding_batch_size,
        use_embed=use_embed,
        host=host,
        port=port,
        grpc_port=grpc_port
    )
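For orientation, here is a minimal usage sketch of the `CodeIndex` factory above. The `chunk_nodes` and `model_service` objects are stand-ins, not part of this file: the sketch assumes node objects exposing the attributes the indexers read (`node_type`, `id`, `name`, `content`, `embedding`, `get_field_to_embed()`, etc.) and a model service exposing `embed_chunk_code_batch()` and `embed_query()`, which is all this module actually calls.

```python
# Hypothetical usage sketch; chunk_nodes and model_service are assumed to
# satisfy the interfaces CodeIndex.py calls into (see lead-in above).
from RepoKnowledgeGraphLib.CodeIndex import CodeIndex

index = CodeIndex(
    nodes=chunk_nodes,            # objects with .node_type == 'chunk', .id, .content, ...
    model_service=model_service,  # provides embed_chunk_code_batch() / embed_query()
    index_type='hybrid',
    backend='lancedb',            # embedded backend; no Weaviate server required
    db_path='./local_code_index_db',
)

results = index.query("parse a python file into chunks", n_results=5)
for node_id, doc in zip(results['ids'][0], results['documents'][0]):
    print(node_id, doc[:80])
```

The results dict mirrors ChromaDB's shape (`ids`, `distances`, `metadatas`, `documents`, each wrapped in an outer list), so either backend can be swapped in behind the same caller.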
RepoKnowledgeGraphLib/CodeParser.py ADDED
@@ -0,0 +1,70 @@
import logging
import os
from dotenv import load_dotenv
from langchain_text_splitters import (
    Language,
    RecursiveCharacterTextSplitter,
)

from .utils.logger_utils import setup_logger

load_dotenv()


LOGGER_NAME = 'CODE_PARSER_LOGGER'
CODE_CHUNK_OVERLAP = int(os.getenv('CODE_CHUNK_OVERLAP', 0))
CODE_CHUNK_SIZE = int(os.getenv('CODE_CHUNK_SIZE', 2000))


class CodeParser:
    def __init__(self):
        setup_logger(LOGGER_NAME)
        self.logger = logging.getLogger(LOGGER_NAME)

        self.extension_mapping = {
            'c': Language.C,
            'h': Language.C,
            'cpp': Language.CPP,
            'cc': Language.CPP,
            'cxx': Language.CPP,
            'hpp': Language.CPP,
            'hh': Language.CPP,
            'hxx': Language.CPP,
            'go': Language.GO,
            'java': Language.JAVA,
            'py': Language.PYTHON,
            'pyw': Language.PYTHON,
            'js': Language.JS,
            'mjs': Language.JS,
            'cjs': Language.JS,
            'md': Language.MARKDOWN,
            'markdown': Language.MARKDOWN,
            'html': Language.HTML,
        }

    def parse(self, file_name: str, file_content: str) -> list:
        file_extension = file_name.split('.')[-1]
        # Default so a failed parse returns an empty list instead of raising
        # a NameError on the final return (docs was previously unbound there).
        docs = []

        try:
            self.logger.debug(f'Parsing file: {file_name}')
            if file_extension not in self.extension_mapping:
                self.logger.debug(f'File extension not supported: {file_extension}')
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=CODE_CHUNK_SIZE,
                    chunk_overlap=CODE_CHUNK_OVERLAP,
                    length_function=len,
                    is_separator_regex=False,
                )
                docs = text_splitter.create_documents([file_content])
            else:
                self.logger.debug(f'File extension supported: {file_extension}')
                code_splitter = RecursiveCharacterTextSplitter.from_language(
                    language=self.extension_mapping[file_extension],
                    chunk_size=CODE_CHUNK_SIZE,
                    chunk_overlap=CODE_CHUNK_OVERLAP,
                )
                docs = code_splitter.create_documents([file_content])
        except Exception as e:
            self.logger.error(f'Error when parsing code: {e}')
        return [doc.page_content for doc in docs]
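A short sketch of how `CodeParser` is meant to be called; the file name and source string here are illustrative only.

```python
from RepoKnowledgeGraphLib.CodeParser import CodeParser

parser = CodeParser()
source = "def add(a, b):\n    return a + b\n"

# 'py' is in extension_mapping, so the language-aware splitter is used;
# an unknown extension would fall back to the plain character splitter.
chunks = parser.parse("example.py", source)
print(len(chunks), repr(chunks[0][:40]))
```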
RepoKnowledgeGraphLib/Entity.py ADDED
@@ -0,0 +1,188 @@
from typing import Optional, List
from dataclasses import dataclass, field, asdict

# Helper for dynamic class lookup
ENTITY_TYPE_MAP = {}


def register_entity(cls):
    ENTITY_TYPE_MAP[cls.__name__] = cls
    return cls


def _entity_to_dict(obj):
    if isinstance(obj, list):
        return [_entity_to_dict(item) for item in obj]
    elif isinstance(obj, dict):
        return {(_entity_to_dict(k) if isinstance(k, Entity) else k): _entity_to_dict(v) for k, v in obj.items()}
    elif isinstance(obj, Entity):
        return obj.to_dict()
    elif hasattr(obj, 'to_dict'):
        return obj.to_dict()
    else:
        return obj


def _entity_from_dict(data):
    if isinstance(data, list):
        return [_entity_from_dict(item) for item in data]
    elif isinstance(data, dict) and 'entity_type' in data:
        # Prefer the serialized class name. A naive capitalize() breaks
        # multi-word types: 'function_call'.capitalize() is 'Function_call',
        # which never matches the registered 'FunctionCall'.
        class_name = data.get('__class__')
        if class_name is None:
            class_name = ''.join(part.capitalize() for part in data['entity_type'].split('_'))
        cls = ENTITY_TYPE_MAP.get(class_name, Entity)
        return cls.from_dict(data)
    else:
        return data


@register_entity
@dataclass
class Entity:
    entity_type: str
    entity_name: str
    defined_chunk_id: str
    entity_dtype: str

    def to_dict(self):
        d = asdict(self)
        d['entity_type'] = self.entity_type
        d['__class__'] = self.__class__.__name__
        return d

    @classmethod
    def from_dict(cls, data):
        # Remove __class__ if present
        data = dict(data)
        data.pop('__class__', None)
        return cls(**data)


@register_entity
@dataclass
class Variable(Entity):
    entity_type = 'variable'

    def to_dict(self):
        d = super().to_dict()
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        return super().from_dict(data)


@register_entity
@dataclass
class Parameter(Entity):
    entity_type = 'parameter'
    # entity_dtype is inherited from Entity; the duplicate annotation was removed

    def to_dict(self):
        d = super().to_dict()
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        return super().from_dict(data)


@register_entity
@dataclass
class Method(Entity):
    entity_type = 'method'
    parameters: List['Parameter'] = field(default_factory=list)
    associated_class: Optional['Class'] = None

    def to_dict(self):
        d = super().to_dict()
        d['parameters'] = _entity_to_dict(self.parameters)
        d['associated_class'] = self.associated_class.to_dict() if self.associated_class else None
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        params = [_entity_from_dict(p) for p in data.get('parameters', [])]
        assoc_cls = Class.from_dict(data['associated_class']) if data.get('associated_class') else None
        # '__class__' is serialization metadata, not a constructor argument
        base = {k: v for k, v in data.items()
                if k not in ['parameters', 'parameters_pairs', 'associated_class', '__class__']}
        return cls(parameters=params, associated_class=assoc_cls, **base)


@register_entity
@dataclass
class Class(Entity):
    entity_type = 'class'
    defined_methods: List['Method'] = field(default_factory=list)

    def to_dict(self):
        d = super().to_dict()
        d['defined_methods'] = _entity_to_dict(self.defined_methods)
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        methods = [_entity_from_dict(m) for m in data.get('defined_methods', [])]
        base = {k: v for k, v in data.items() if k not in ['defined_methods', '__class__']}
        return cls(defined_methods=methods, **base)


@register_entity
@dataclass
class Function(Entity):
    entity_type = 'function'
    parameters: List[Parameter] = field(default_factory=list)
    parameters_pairs: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)

    def to_dict(self):
        d = super().to_dict()
        d['parameters'] = _entity_to_dict(self.parameters)
        d['parameters_pairs'] = [(p.to_dict(), v.to_dict()) for p, v in self.parameters_pairs]
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        params = [_entity_from_dict(p) for p in data.get('parameters', [])]
        parameters_pairs = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('parameters_pairs', [])]
        base = {k: v for k, v in data.items() if k not in ['parameters', 'parameters_pairs', '__class__']}
        return cls(parameters=params, parameters_pairs=parameters_pairs, **base)


@register_entity
@dataclass
class FunctionCall(Entity):
    entity_type: str = 'function_call'
    entity_name: str = ''
    defined_chunk_id: str = ''
    entity_dtype: str = ''
    arguments: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)
    # Fixed: the default is a list, so the type is List[Function], not Optional[Function]
    associated_functions: List[Function] = field(default_factory=list)

    def to_dict(self):
        d = super().to_dict()
        d['arguments'] = [(p.to_dict(), v.to_dict()) for p, v in self.arguments]
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        arguments = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('arguments', [])]
        base = {k: v for k, v in data.items() if k not in ['arguments', '__class__']}
        return cls(arguments=arguments, **base)


@register_entity
@dataclass
class MethodCall(Entity):
    entity_type: str = 'method_call'
    entity_name: str = ''
    defined_chunk_id: str = ''
    entity_dtype: str = ''
    arguments: List[tuple] = field(default_factory=list)  # List of (Parameter, Variable)
    associated_class: Optional[Class] = None
    associated_method: Optional[Method] = None

    def to_dict(self):
        d = super().to_dict()
        d['arguments'] = [(p.to_dict(), v.to_dict()) for p, v in self.arguments]
        d['associated_class'] = self.associated_class.to_dict() if self.associated_class else None
        d['entity_type'] = self.entity_type
        return d

    @classmethod
    def from_dict(cls, data):
        arguments = [(Parameter.from_dict(p), Variable.from_dict(v)) for p, v in data.get('arguments', [])]
        assoc_cls = Class.from_dict(data['associated_class']) if data.get('associated_class') else None
        base = {k: v for k, v in data.items() if k not in ['arguments', 'associated_class', '__class__']}
        return cls(arguments=arguments, associated_class=assoc_cls, **base)
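The dataclasses above are designed to round-trip through `to_dict`/`from_dict`; a minimal sketch (it relies on the `__class__` handling in `from_dict` shown above):

```python
from RepoKnowledgeGraphLib.Entity import Class, Method, Parameter

method = Method(
    entity_type='method',
    entity_name='run',
    defined_chunk_id='chunk_3',
    entity_dtype='None',
    parameters=[Parameter(entity_type='parameter', entity_name='x',
                          defined_chunk_id='chunk_3', entity_dtype='int')],
)
cls_entity = Class(entity_type='class', entity_name='Runner',
                   defined_chunk_id='chunk_3', entity_dtype='',
                   defined_methods=[method])

data = cls_entity.to_dict()       # plain, JSON-serializable dict
restored = Class.from_dict(data)  # rebuilds nested Method/Parameter objects
assert restored.defined_methods[0].entity_name == 'run'
```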
RepoKnowledgeGraphLib/EntityChunkMapper.py ADDED
@@ -0,0 +1,517 @@
import logging
import re
from typing import List, Tuple, Dict, Any, Set, Optional
from enum import Enum


class Language(Enum):
    """Supported programming languages"""
    PYTHON = "python"
    C = "c"
    CPP = "cpp"
    JAVA = "java"


class EntityChunkMapper:
    """Maps entities from file-level extraction back to their respective chunks"""

    def __init__(self):
        self.logger = logging.getLogger("ENTITY_CHUNK_MAPPER")
        self.extension_to_language = {
            'py': Language.PYTHON,
            'pyw': Language.PYTHON,
            'c': Language.C,
            'h': Language.C,
            'cpp': Language.CPP,
            'cc': Language.CPP,
            'cxx': Language.CPP,
            'hpp': Language.CPP,
            'hh': Language.CPP,
            'hxx': Language.CPP,
            'java': Language.JAVA,
        }

    def _detect_language(self, file_name: Optional[str] = None) -> Language:
        """
        Detect the programming language from the file extension

        Args:
            file_name: Name of the file (optional)

        Returns:
            Language enum value, defaults to PYTHON if not detected
        """
        if file_name:
            extension = file_name.split('.')[-1].lower()
            return self.extension_to_language.get(extension, Language.PYTHON)
        return Language.PYTHON

    def _is_comment_or_docstring(self, line: str, in_docstring: bool, language: Language) -> Tuple[bool, bool]:
        """
        Check if a line is a comment or part of a docstring/multi-line comment

        Args:
            line: The line to check
            in_docstring: Whether we're currently inside a docstring/multi-line comment
            language: The programming language

        Returns:
            Tuple of (is_comment_or_docstring, new_in_docstring_state)
        """
        stripped = line.strip()

        if language == Language.PYTHON:
            # Check for single-line comments
            if stripped.startswith('#'):
                return True, in_docstring

            # Check for docstring delimiters (triple double or single quotes)
            triple_double = '"""'
            triple_single = "'''"

            # Count occurrences of triple quotes
            if triple_double in stripped or triple_single in stripped:
                # Check if it's a single-line docstring
                if (stripped.count(triple_double) >= 2 or
                        stripped.count(triple_single) >= 2):
                    # Single-line docstring
                    return True, in_docstring
                else:
                    # Toggle the docstring state
                    return True, not in_docstring

            # If we're in a docstring, this line is part of it
            if in_docstring:
                return True, in_docstring

        elif language in [Language.C, Language.CPP, Language.JAVA]:
            # Check for single-line comments
            if stripped.startswith('//'):
                return True, in_docstring

            # Check for multi-line comment delimiters /* */
            if '/*' in line and '*/' in line:
                # A multi-line comment opened and closed on one line
                return True, in_docstring
            elif '/*' in line:
                # Start of a multi-line comment
                return True, True
            elif '*/' in line:
                # End of a multi-line comment
                return True, False

            # If we're inside a multi-line comment
            if in_docstring:
                return True, in_docstring

        return False, in_docstring

    def _get_code_lines(self, chunk_lines: List[str], language: Language) -> List[str]:
        """
        Filter out comments and docstrings from chunk lines

        Args:
            chunk_lines: List of lines in the chunk
            language: The programming language

        Returns:
            List of lines that are actual code (not comments or docstrings)
        """
        code_lines = []
        in_docstring = False

        for line in chunk_lines:
            is_doc, in_docstring = self._is_comment_or_docstring(line, in_docstring, language)
            if not is_doc:
                code_lines.append(line)

        return code_lines

    def _is_valid_identifier_match(self, text: str, identifier: str, position: int) -> bool:
        """
        Check if an identifier match at a position is valid (not part of another word)

        Args:
            text: The text containing the identifier
            identifier: The identifier to check
            position: The position where the identifier was found

        Returns:
            True if this is a valid standalone identifier match
        """
        # Check the character before (if it exists)
        if position > 0:
            char_before = text[position - 1]
            if char_before.isalnum() or char_before == '_':
                return False

        # Check the character after (if it exists)
        end_pos = position + len(identifier)
        if end_pos < len(text):
            char_after = text[end_pos]
            if char_after.isalnum() or char_after == '_':
                return False

        return True

    def _contains_identifier(self, line: str, identifier: str) -> bool:
        """
        Check if a line contains an identifier as a standalone word (not part of another word)

        Args:
            line: The line to check
            identifier: The identifier to find

        Returns:
            True if the identifier appears as a standalone word
        """
        # Use a word-boundary regex for precise matching
        pattern = r'\b' + re.escape(identifier) + r'\b'
        return bool(re.search(pattern, line))

    def find_entity_in_chunks(self, entity_name: str, chunks: List[str], entity_type: str = None,
                              file_name: Optional[str] = None) -> Set[int]:
        """
        Find which chunks contain a specific entity declaration or call

        Args:
            entity_name: Name of the entity to find
            chunks: List of code chunks
            entity_type: Type of entity (class, function, method, variable)
            file_name: Name of the file to detect language (optional)

        Returns:
            Set of chunk indices that contain this entity
        """
        matching_chunks = set()
        language = self._detect_language(file_name)

        # Split the entity name to handle nested entities like "ClassName.method".
        # For Java/C++, also handle the :: separator.
        if '::' in entity_name:
            parts = entity_name.split('::')
        else:
            parts = entity_name.split('.')
        base_name = parts[-1]  # The actual identifier

        for chunk_idx, chunk in enumerate(chunks):
            chunk_lines = chunk.strip().split('\n')

            # Look for different patterns based on entity type
            if self._entity_appears_in_chunk(entity_name, base_name, chunk, chunk_lines, entity_type, language):
                matching_chunks.add(chunk_idx)

        return matching_chunks

    def _entity_appears_in_chunk(self, full_name: str, base_name: str, chunk: str, chunk_lines: List[str],
                                 entity_type: str, language: Language) -> bool:
        """Check if an entity appears in a specific chunk (excluding comments and docstrings)"""

        # Filter out comments and docstrings
        code_lines = self._get_code_lines(chunk_lines, language)

        # If no code lines remain, the entity doesn't appear in actual code
        if not code_lines:
            return False

        # Language-specific entity matching
        if language == Language.PYTHON:
            return self._entity_appears_in_python(full_name, base_name, code_lines, entity_type)
        elif language in [Language.C, Language.CPP]:
            return self._entity_appears_in_c_cpp(full_name, base_name, code_lines, entity_type)
        elif language == Language.JAVA:
            return self._entity_appears_in_java(full_name, base_name, code_lines, entity_type)

        return False

    def _entity_appears_in_python(self, full_name: str, base_name: str, code_lines: List[str],
                                  entity_type: str) -> bool:
        """Check if the entity appears in Python code"""

        if entity_type == "class":
            # Look for a class definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'class\s+{re.escape(base_name)}[\s:(]', stripped):
                    return True

        elif entity_type == "api_endpoint":
            # Look for an API endpoint definition: the function decorated with
            # @app.get, @app.post, etc. We look for the function definition itself.
            for line in code_lines:
                stripped = line.strip()
                # Match the function definition with the endpoint name
                if re.match(rf'(async\s+)?def\s+{re.escape(base_name)}\s*\(', stripped):
                    return True
                # Also check for decorators that might reference the endpoint
                if re.search(rf'@\w+\.(get|post|put|delete|patch|options|head)\s*\(', stripped):
                    return True

        elif entity_type == "function":
            # Look for a function definition (not a method)
            for line in code_lines:
                stripped = line.strip()
                # Check it's not indented (not a method)
                if not line.startswith(" ") and not line.startswith("\t"):
                    if re.match(rf'(async\s+)?def\s+{re.escape(base_name)}\s*\(', stripped):
                        return True

        elif entity_type == "method":
            # Look for a method definition (indented def)
            method_name = full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                # Check it's indented (is a method)
                if line.startswith(" ") or line.startswith("\t"):
                    if re.match(rf'(async\s+)?def\s+{re.escape(method_name)}\s*\(', stripped):
                        return True

        elif entity_type == "variable":
            # Look for a variable assignment or usage
            if "." in full_name:
                parts = full_name.split('.')
                attr_name = parts[-1]
                for line in code_lines:
                    if re.search(rf'\.\s*{re.escape(attr_name)}\b', line):
                        return True
            else:
                for line in code_lines:
                    stripped = line.strip()
                    if re.match(rf'{re.escape(base_name)}\s*[=:]', stripped):
                        return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        # General usage as an identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    return True

        return False

    def _extract_using_namespace_directives(self, code_lines: List[str]) -> List[str]:
        """
        Extract using namespace directives from C++ code.
        Returns a list of namespace names that are being imported.
        """
        namespaces = []
        for line in code_lines:
            stripped = line.strip()
            # Match "using namespace <name>;"
            match = re.match(r'using\s+namespace\s+([a-zA-Z_][a-zA-Z0-9_:]*)\s*;', stripped)
            if match:
                namespaces.append(match.group(1))
        return namespaces

    def _entity_appears_in_c_cpp(self, full_name: str, base_name: str, code_lines: List[str],
                                 entity_type: str) -> bool:
        """Check if the entity appears in C/C++ code"""

        # Extract using namespace directives
        using_namespaces = self._extract_using_namespace_directives(code_lines)

        # Check if the full_name matches any imported namespace + base_name.
        # E.g., if full_name is "math::Calculator" and we have "using namespace math",
        # then "Calculator" in the code should match.
        namespace_match = False
        if '::' in full_name:
            for ns in using_namespaces:
                # Check if full_name starts with this namespace
                if full_name.startswith(ns + '::'):
                    namespace_match = True
                    break

        if entity_type == "class":
            # Look for a class/struct definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'(class|struct)\s+{re.escape(base_name)}[\s:{{]', stripped):
                    return True

        elif entity_type == "function":
            # Look for a function definition or declaration
            for line in code_lines:
                stripped = line.strip()
                # Match function patterns: return_type function_name(
                # Also handle constructors and destructors
                if (re.search(rf'\b{re.escape(base_name)}\s*\(', stripped) and
                        not stripped.startswith('//')):
                    # Additional check: likely a function if followed by parameters
                    return True

        elif entity_type == "method":
            # Look for a method definition (with class scope)
            method_name = full_name.split('::')[-1] if '::' in full_name else full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                # Match ClassName::methodName( or just methodName( inside a class
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True

        elif entity_type == "variable":
            # Look for a variable declaration or usage
            for line in code_lines:
                stripped = line.strip()
                # Match variable declarations and assignments
                if re.search(rf'\b{re.escape(base_name)}\b', stripped):
                    return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            # Look for instantiation or usage
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\b', line):
                    # If we found base_name and there's a namespace match, this is a match
                    if namespace_match:
                        return True
                    # If full_name doesn't have a namespace, it's a direct match
                    if '::' not in full_name:
                        return True

        # General usage as an identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    # If we found base_name and there's a namespace match, this is a match
                    if namespace_match:
                        return True
                    # If full_name doesn't have a namespace, it's a direct match
                    if '::' not in full_name:
                        return True

        return False

    def _entity_appears_in_java(self, full_name: str, base_name: str, code_lines: List[str],
                                entity_type: str) -> bool:
        """Check if the entity appears in Java code"""

        if entity_type == "class":
            # Look for a class/interface/enum definition
            for line in code_lines:
                stripped = line.strip()
                if re.match(rf'(public|private|protected)?\s*(class|interface|enum)\s+{re.escape(base_name)}[\s<{{]', stripped):
                    return True
                # Without a modifier
                if re.match(rf'(class|interface|enum)\s+{re.escape(base_name)}[\s<{{]', stripped):
                    return True

        elif entity_type == "api_endpoint":
            # Look for an API endpoint definition: the method with Spring annotations.
            # Extract just the method name from the fully qualified name
            # (e.g., "com.example.Controller::method" -> "method").
            method_name = base_name.split('::')[-1] if '::' in base_name else base_name
            for line in code_lines:
                stripped = line.strip()
                # Match the method definition
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True
                # Also check for Spring annotations
                if re.search(r'@(GetMapping|PostMapping|PutMapping|DeleteMapping|PatchMapping|RequestMapping)', stripped):
                    return True

        elif entity_type == "function":
            # In Java, functions are methods
            for line in code_lines:
                stripped = line.strip()
                # Match method signature patterns
                if re.search(rf'\b{re.escape(base_name)}\s*\(', stripped):
                    return True

        elif entity_type == "method":
            # Look for a method definition
            method_name = full_name.split('.')[-1]
            for line in code_lines:
                stripped = line.strip()
                if re.search(rf'\b{re.escape(method_name)}\s*\(', stripped):
                    return True

        elif entity_type == "variable":
            # Look for a variable declaration or usage
            for line in code_lines:
                stripped = line.strip()
                if re.search(rf'\b{re.escape(base_name)}\b', stripped):
                    return True

        # For called entities, look for usage patterns
        if entity_type in ["function", "method"] or entity_type is None:
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\s*\(', line):
                    return True

        if entity_type == "class" or entity_type is None:
            # Look for instantiation (new ClassName) or usage
            for line in code_lines:
                if re.search(rf'\b{re.escape(base_name)}\b', line):
                    return True

        # General usage as an identifier
        if entity_type is None or entity_type == "variable":
            for line in code_lines:
                if self._contains_identifier(line, base_name):
                    return True

        return False

    def map_entities_to_chunks(self, declared_entities: List[Dict[str, Any]],
                               called_entities: List[str],
                               chunks: List[str],
                               file_name: Optional[str] = None) -> Tuple[Dict[int, List[Dict[str, Any]]],
                                                                         Dict[int, List[str]]]:
        """
        Map file-level entities back to their respective chunks

        Args:
            declared_entities: List of declared entities from file-level extraction
            called_entities: List of called entities from file-level extraction
            chunks: List of code chunks
            file_name: Name of the file to detect language (optional)

        Returns:
            Tuple of (chunk_declared_entities, chunk_called_entities)
            - chunk_declared_entities: Dict mapping chunk_index -> list of declared entities
            - chunk_called_entities: Dict mapping chunk_index -> list of called entities
        """
        chunk_declared = {}
        chunk_called = {}

        # Initialize empty lists for all chunks
        for i in range(len(chunks)):
            chunk_declared[i] = []
            chunk_called[i] = []

        # Map declared entities to chunks
        for entity in declared_entities:
            entity_name = entity.get("name", "")
            entity_type = entity.get("type", "")

            matching_chunks = self.find_entity_in_chunks(entity_name, chunks, entity_type, file_name)

            # Add the entity to matching chunks
            for chunk_idx in matching_chunks:
                chunk_declared[chunk_idx].append(entity)

        # Map called entities to chunks
        for called_entity in called_entities:
            matching_chunks = self.find_entity_in_chunks(called_entity, chunks, None, file_name)

            # Add the called entity to matching chunks
            for chunk_idx in matching_chunks:
                if called_entity not in chunk_called[chunk_idx]:
                    chunk_called[chunk_idx].append(called_entity)

        return chunk_declared, chunk_called
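A minimal sketch of `map_entities_to_chunks` in action; the chunks and entity dicts below are made-up inputs shaped like the file-level extraction output this class expects.

```python
from RepoKnowledgeGraphLib.EntityChunkMapper import EntityChunkMapper

mapper = EntityChunkMapper()
chunks = [
    "def load(path):\n    return open(path).read()\n",
    "class Loader:\n    def run(self):\n        return load('x.txt')\n",
]
declared = [{"name": "load", "type": "function"},
            {"name": "Loader", "type": "class"}]
called = ["load"]

by_chunk_declared, by_chunk_called = mapper.map_entities_to_chunks(
    declared, called, chunks, file_name="loader.py")

# Matching is heuristic and deliberately permissive: call-site patterns count
# too, so 'load' shows up for both chunks, not just where it is defined.
print(by_chunk_declared)
print(by_chunk_called)
```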
RepoKnowledgeGraphLib/EntityExtractor.py ADDED
@@ -0,0 +1,2032 @@
1
+ import ast
2
+ import os
3
+ import logging
4
+ import tempfile
5
+ from typing import List, Dict, Any, Tuple, Optional
6
+ from clang import cindex
7
+ import javalang
8
+ import javalang.tree as T
9
+ import esprima
10
+ from bs4 import BeautifulSoup
11
+ import tree_sitter_rust as ts_rust
12
+ from tree_sitter import Language, Parser
13
+ import re
14
+ from .utils.path_utils import generate_entity_aliases
15
+
16
+
17
+
18
+ LOGGER_NAME = "AST_ENTITY_EXTRACTOR"
19
+ logger = logging.getLogger(LOGGER_NAME)
20
+
21
+
22
+ class BaseASTEntityExtractor:
23
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
24
+ """
25
+ Extract entities from source code.
26
+
27
+ Args:
28
+ code: Source code as string
29
+ file_path: Optional path to the source file (for better context and include resolution)
30
+
31
+ Returns:
32
+ Tuple of (declared_entities, called_entities)
33
+ """
34
+ raise NotImplementedError
35
+
36
+
37
+ # Add a reset contract so extractors can be reused safely
38
+ def reset(self) -> None:
39
+ """
40
+ Reset internal state so the extractor instance can be reused.
41
+ Concrete extractors should override this to clear their buffers.
42
+ """
43
+ raise NotImplementedError
44
+
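+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # Shows the (declared_entities, called_entities) contract that every extractor
+ # below follows. The `_demo_extractor_contract` helper and its sample HTML are
+ # hypothetical and exist only to illustrate the returned shape.
+ def _demo_extractor_contract():
+     html = '<div id="app" class="main"><button onclick="initApp()">Go</button></div>'
+     declared, called = HTMLEntityExtractor().extract_entities(html)
+     # declared -> e.g. [{"name": "app", "type": "element"}, {"name": "main", "type": "class"}, ...]
+     # called   -> e.g. ["initApp"]
+     return declared, called
+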
45
+ class HTMLEntityExtractor(BaseASTEntityExtractor):
46
+ """
47
+ Hybrid HTML AST-based entity extractor.
48
+
49
+ Responsibilities:
50
+ - Parse HTML into a tree
+ - Extract declared DOM entities (ids, names, classes)
+ - Extract JavaScript calls from inline event handlers
+ - Extract JS entities from <script> tags
+ - Integrate cleanly with the hybrid AST graph linker
55
+ """
56
+
57
+ EVENT_ATTR_PREFIX = "on" # e.g., onclick, onsubmit, etc.
58
+
59
+ def __init__(self):
60
+ self.js_extractor = JavaScriptEntityExtractor()
61
+ self.reset()
62
+
63
+ # --------------------------------------
64
+ # Core interface
65
+ # --------------------------------------
66
+ def reset(self):
67
+ self.declared_entities: List[Dict[str, str]] = []
68
+ self.called_entities: List[str] = []
69
+
70
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, str]], List[str]]:
71
+ """Main entry point: parse HTML and extract entities."""
72
+ self.reset()
73
+ try:
74
+ soup = BeautifulSoup(code, "html.parser")
75
+ except Exception as e:
76
+ print(f"[HTMLEntityExtractor] Parsing error: {e}")
77
+ return [], []
78
+
79
+ # --- DOM element declarations ---
80
+ for tag in soup.find_all(True):
81
+ self._handle_tag_declaration(tag)
82
+ self._handle_event_attributes(tag)
83
+
84
+ # --- <script> tags (inline + external) ---
85
+ for script in soup.find_all("script"):
86
+ self._handle_script(script)
87
+
88
+ # --- Deduplication ---
89
+ self.declared_entities = self._deduplicate_dicts(self.declared_entities)
90
+ self.called_entities = self._deduplicate_list(self.called_entities)
91
+
92
+ return self.declared_entities, self.called_entities
93
+
94
+ # --------------------------------------
95
+ # Tag & attribute handlers
96
+ # --------------------------------------
97
+ def _handle_tag_declaration(self, tag):
98
+ """Extract declared DOM elements (id, name, class)."""
99
+ if tag.has_attr("id"):
100
+ self.declared_entities.append({"name": tag["id"], "type": "element"})
101
+
102
+ if tag.has_attr("name"):
103
+ self.declared_entities.append({"name": tag["name"], "type": "element"})
104
+
105
+ if tag.has_attr("class"):
106
+ classes = tag["class"]
107
+ if isinstance(classes, list):
108
+ for c in classes:
109
+ self.declared_entities.append({"name": c, "type": "class"})
110
+ elif isinstance(classes, str):
111
+ self.declared_entities.append({"name": classes, "type": "class"})
112
+
113
+ def _handle_event_attributes(self, tag):
114
+ """Extract JS calls from inline event attributes."""
115
+ if not self.js_extractor:
116
+ return
117
+ for attr, value in tag.attrs.items():
118
+ if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str):
119
+ try:
120
+ _, called = self.js_extractor.extract_entities(value)
121
+ self.called_entities.extend(called)
122
+ except Exception as e:
123
+ print(f"[HTMLEntityExtractor] JS parse error in {attr}: {e}")
124
+
125
+ def _handle_script(self, script):
126
+ """Extract JS entities from <script> blocks or src attributes."""
127
+ if script.has_attr("src"):
128
+ src = script["src"]
129
+ self.called_entities.append(src)
130
+ return
131
+
132
+ if not self.js_extractor:
133
+ return
134
+
135
+ js_code = (script.string or "").strip()
136
+ if js_code:
137
+ try:
138
+ declared, called = self.js_extractor.extract_entities(js_code)
139
+ self.declared_entities.extend(declared)
140
+ self.called_entities.extend(called)
141
+ except Exception as e:
142
+ print(f"[HTMLEntityExtractor] JS parse error in <script>: {e}")
143
+
144
+ # --------------------------------------
145
+ # Helpers
146
+ # --------------------------------------
147
+ @staticmethod
148
+ def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]:
149
+ seen = set()
150
+ result = []
151
+ for d in dicts:
152
+ key = tuple(sorted(d.items()))
153
+ if key not in seen:
154
+ seen.add(key)
155
+ result.append(d)
156
+ return result
157
+
158
+ @staticmethod
159
+ def _deduplicate_list(items: List[str]) -> List[str]:
160
+ seen = set()
161
+ result = []
162
+ for i in items:
163
+ if i not in seen:
164
+ seen.add(i)
165
+ result.append(i)
166
+ return result
167
+
168
+
169
+ class JavaEntityExtractor(BaseASTEntityExtractor):
170
+ """
171
+ Extract declared and called entities from Java code using javalang.
172
+ Produces the same (declared_entities, called_entities) structure as other extractors.
173
+ """
174
+
175
+ def __init__(self):
176
+ self.reset()
177
+
178
+ def reset(self) -> None:
179
+ self.declared_entities: List[Dict[str, Any]] = []
180
+ self.called_entities: List[str] = []
181
+ self.current_package: Optional[str] = None
182
+ self.scope_stack: List[str] = []
183
+ self.api_endpoints: List[Dict[str, Any]] = [] # Track API endpoint definitions
184
+ self.current_class_base_path: Optional[str] = None # For @RequestMapping on class
185
+
186
+ # -----------------------------------------------------------
187
+ # Helpers
188
+ # -----------------------------------------------------------
189
+
190
+ def _qualified(self, name: str) -> str:
191
+ if not name:
192
+ return ""
193
+ scope = "::".join(self.scope_stack)
194
+ return f"{scope}::{name}" if scope else name
195
+
196
+ def _walk_type(self, t):
197
+ """Return string representation of a type node."""
198
+ if not t:
199
+ return "unknown"
200
+ if isinstance(t, str):
201
+ return t
202
+ if hasattr(t, "name"):
203
+ name = t.name
204
+ if getattr(t, "arguments", None):
205
+ args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")]
206
+ name += "<" + ", ".join(args) + ">"
207
+ return name
208
+ return "unknown"
209
+
210
+ # -----------------------------------------------------------
211
+ # Main AST traversal
212
+ # -----------------------------------------------------------
213
+
214
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
215
+ self.reset()
216
+
217
+ try:
218
+ tree = javalang.parse.parse(code)
219
+ except javalang.parser.JavaSyntaxError as e:
220
+ logger.error(f"Syntax error in Java code: {e}")
221
+ return [], []
222
+ except Exception as e:
223
+ logger.error(f"Error parsing Java code: {e}", exc_info=True)
224
+ return [], []
225
+
226
+ # --- Package ---
227
+ if tree.package:
228
+ self.current_package = tree.package.name
229
+
230
+ # --- Imports ---
231
+ for imp in tree.imports:
232
+ self.called_entities.append(imp.path)
233
+
234
+ # --- Types (classes, interfaces, enums) ---
235
+ for type_decl in tree.types:
236
+ self._visit_type(type_decl)
237
+
238
+ # Deduplicate
239
+ seen_decl = set()
240
+ unique_declared = []
241
+ for e in self.declared_entities:
242
+ key = (e.get("name"), e.get("type"), e.get("dtype"))
243
+ if key not in seen_decl:
244
+ unique_declared.append(e)
245
+ seen_decl.add(key)
246
+
247
+ unique_called = list(dict.fromkeys(self.called_entities))
248
+ return unique_declared, unique_called
249
+
250
+ # -----------------------------------------------------------
251
+ # Visitors for different node types
252
+ # -----------------------------------------------------------
253
+
254
+ def _visit_type(self, node):
255
+ if isinstance(node, javalang.tree.ClassDeclaration):
256
+ self._visit_class(node)
257
+ elif isinstance(node, javalang.tree.InterfaceDeclaration):
258
+ self._visit_interface(node)
259
+ elif isinstance(node, javalang.tree.EnumDeclaration):
260
+ self._visit_enum(node)
261
+
262
+ def _visit_class(self, node):
263
+ full_name = node.name
264
+ if self.current_package:
265
+ full_name = f"{self.current_package}.{node.name}"
266
+ qualified = self._qualified(full_name)
267
+
268
+ self.declared_entities.append({"name": qualified, "type": "class"})
269
+
270
+ # Check for REST controller annotations and extract base path
271
+ old_base_path = self.current_class_base_path
272
+ if node.annotations:
273
+ for annotation in node.annotations:
274
+ if annotation.name in {'RestController', 'Controller'}:
275
+ # Controller marker only; it carries no path itself, so nothing to extract
276
+ pass
277
+ elif annotation.name == 'RequestMapping':
278
+ # Extract base path from class-level @RequestMapping
279
+ self.current_class_base_path = self._extract_path_from_annotation(annotation)
280
+
281
+ # Inheritance
282
+ if node.extends:
283
+ self.called_entities.append(self._walk_type(node.extends))
284
+ for impl in node.implements or []:
285
+ self.called_entities.append(self._walk_type(impl))
286
+
287
+ self.scope_stack.append(full_name)
288
+ for member in node.body:
289
+ self._visit_member(member)
290
+ self.scope_stack.pop()
291
+
292
+ # Restore the previous base path
293
+ self.current_class_base_path = old_base_path
294
+
295
+ def _visit_interface(self, node):
296
+ full_name = node.name
297
+ if self.current_package:
298
+ full_name = f"{self.current_package}.{node.name}"
299
+ qualified = self._qualified(full_name)
300
+ self.declared_entities.append({"name": qualified, "type": "interface"})
301
+
302
+ for impl in node.extends or []:
303
+ self.called_entities.append(self._walk_type(impl))
304
+
305
+ self.scope_stack.append(full_name)
306
+ for member in node.body:
307
+ self._visit_member(member)
308
+ self.scope_stack.pop()
309
+
310
+ def _visit_enum(self, node):
311
+ full_name = node.name
312
+ if self.current_package:
313
+ full_name = f"{self.current_package}.{node.name}"
314
+ qualified = self._qualified(full_name)
315
+ self.declared_entities.append({"name": qualified, "type": "enum"})
316
+
317
+ def _visit_member(self, node):
318
+
319
+ # --- Method ---
320
+ if isinstance(node, T.MethodDeclaration):
321
+ method_name = self._qualified(node.name)
322
+
323
+ # Check for API endpoint annotations
324
+ api_info = self._extract_api_endpoint_from_annotations(node)
325
+ if api_info:
326
+ self.declared_entities.append({
327
+ "name": method_name,
328
+ "type": "api_endpoint",
329
+ "endpoint": api_info.get("endpoint"),
330
+ "methods": api_info.get("methods")
331
+ })
332
+ self.api_endpoints.append({**api_info, "function": method_name})
333
+ else:
334
+ self.declared_entities.append({"name": method_name, "type": "method"})
335
+
336
+ for param in node.parameters:
337
+ ptype = self._walk_type(param.type)
338
+ pname = f"{method_name}.{param.name}"
339
+ self.declared_entities.append({
340
+ "name": pname,
341
+ "type": "variable",
342
+ "dtype": ptype
343
+ })
344
+
345
+ # Look for method calls in the body
346
+ if node.body:
347
+ self._find_calls(node.body)
348
+
349
+ # --- Constructor ---
350
+ elif isinstance(node, T.ConstructorDeclaration):
351
+ ctor_name = self._qualified(node.name)
352
+ self.declared_entities.append({"name": ctor_name, "type": "constructor"})
353
+ for param in node.parameters:
354
+ ptype = self._walk_type(param.type)
355
+ pname = f"{ctor_name}.{param.name}"
356
+ self.declared_entities.append({
357
+ "name": pname,
358
+ "type": "variable",
359
+ "dtype": ptype
360
+ })
361
+ if node.body:
362
+ self._find_calls(node.body)
363
+
364
+ # --- Field ---
365
+ elif isinstance(node, T.FieldDeclaration):
366
+ dtype = self._walk_type(node.type)
367
+ for decl in node.declarators:
368
+ var_name = self._qualified(decl.name)
369
+ self.declared_entities.append({
370
+ "name": var_name,
371
+ "type": "variable",
372
+ "dtype": dtype
373
+ })
374
+
375
+ # --- Nested class/interface ---
376
+ elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration)):
377
+ self._visit_type(node)
378
+
379
+ # -----------------------------------------------------------
380
+ # API Endpoint Detection
381
+ # -----------------------------------------------------------
382
+
383
+ def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]:
384
+ """
385
+ Extract API endpoint information from Spring Boot method annotations.
386
+ Handles: @GetMapping, @PostMapping, @RequestMapping, etc.
387
+ """
388
+ if not method.annotations:
389
+ return None
390
+
391
+ for annotation in method.annotations:
392
+ annotation_name = annotation.name
393
+
394
+ if annotation_name in {'GetMapping', 'PostMapping', 'PutMapping', 'PatchMapping', 'DeleteMapping'}:
395
+ # Extract HTTP method from annotation name
396
+ http_method = annotation_name.replace('Mapping', '').upper()
397
+ path = self._extract_path_from_annotation(annotation)
398
+
399
+ if path:
400
+ # Combine with class-level base path if present
401
+ full_path = self._combine_paths(self.current_class_base_path, path)
402
+ return {
403
+ "endpoint": full_path,
404
+ "methods": [http_method],
405
+ "type": "api_endpoint_definition"
406
+ }
407
+
408
+ elif annotation_name == 'RequestMapping':
409
+ # @RequestMapping can specify multiple methods
410
+ path = self._extract_path_from_annotation(annotation)
411
+ methods = self._extract_methods_from_annotation(annotation)
412
+
413
+ if path:
414
+ full_path = self._combine_paths(self.current_class_base_path, path)
415
+ return {
416
+ "endpoint": full_path,
417
+ "methods": methods if methods else ['GET'], # Default to GET
418
+ "type": "api_endpoint_definition"
419
+ }
420
+
421
+ return None
422
+
423
+ def _extract_path_from_annotation(self, annotation) -> Optional[str]:
424
+ """Extract path/value from Spring annotation."""
425
+ if not annotation.element:
426
+ return None
427
+
428
+ # Handle @GetMapping("/path") - single value
429
+ if isinstance(annotation.element, T.Literal):
430
+ return annotation.element.value.strip('"')
431
+
432
+ # Handle @RequestMapping(value = "/path") or @RequestMapping(path = "/path")
433
+ if isinstance(annotation.element, list):
434
+ for elem in annotation.element:
435
+ if isinstance(elem, T.ElementValuePair):
436
+ if elem.name in {'value', 'path'}:
437
+ if isinstance(elem.value, T.Literal):
438
+ return elem.value.value.strip('"')
439
+ elif isinstance(elem.value, T.ElementArrayValue):
440
+ # Handle array: value = {"/path1", "/path2"}
441
+ if elem.value.values:
442
+ first_val = elem.value.values[0]
443
+ if isinstance(first_val, T.Literal):
444
+ return first_val.value.strip('"')
445
+
446
+ return None
447
+
448
+ def _extract_methods_from_annotation(self, annotation) -> List[str]:
449
+ """Extract HTTP methods from @RequestMapping annotation."""
450
+ methods = []
451
+
452
+ if isinstance(annotation.element, list):
453
+ for elem in annotation.element:
454
+ if isinstance(elem, T.ElementValuePair):
455
+ if elem.name == 'method':
456
+ # Handle method = RequestMethod.GET or method = {RequestMethod.GET, RequestMethod.POST}
457
+ if hasattr(elem.value, 'member'):
458
+ # Single method: RequestMethod.GET
459
+ methods.append(elem.value.member)
460
+ elif isinstance(elem.value, T.ElementArrayValue):
461
+ # Multiple methods: {RequestMethod.GET, RequestMethod.POST}
462
+ for val in elem.value.values:
463
+ if hasattr(val, 'member'):
464
+ methods.append(val.member)
465
+
466
+ return methods
467
+
468
+ def _combine_paths(self, base_path: Optional[str], path: str) -> str:
469
+ """Combine base path from class annotation with method path."""
470
+ if not base_path:
471
+ return path
472
+
473
+ # Normalize paths
474
+ base = base_path.rstrip('/')
475
+ path = path.lstrip('/')
476
+
477
+ return f"{base}/{path}" if path else base
478
+
479
+ # -----------------------------------------------------------
480
+ # Find method invocations
481
+ # -----------------------------------------------------------
482
+
483
+ def _find_calls(self, statements):
484
+ """Recursively find method and constructor calls inside Java AST nodes."""
485
+
486
+ def _recurse(node):
487
+ if isinstance(node, T.MethodInvocation):
488
+ if node.qualifier:
489
+ self.called_entities.append(f"{node.qualifier}.{node.member}")
490
+ else:
491
+ self.called_entities.append(node.member)
492
+ elif isinstance(node, T.ClassCreator):
493
+ self.called_entities.append(self._walk_type(node.type))
494
+
495
+ # Recurse into all children
496
+ if hasattr(node, '__dict__'):
497
+ for attr, val in vars(node).items():
498
+ if isinstance(val, list):
499
+ for child in val:
500
+ if isinstance(child, T.Node):
501
+ _recurse(child)
502
+ elif isinstance(val, T.Node):
503
+ _recurse(val)
504
+
505
+ if not statements:
506
+ return
507
+
508
+ if isinstance(statements, list):
509
+ for stmt in statements:
510
+ _recurse(stmt)
511
+ else:
512
+ _recurse(statements)
513
+
514
+
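+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # Demonstrates how a class-level @RequestMapping base path combines with a
+ # method-level @GetMapping. The helper name and Java snippet are hypothetical;
+ # parsing requires the javalang package.
+ def _demo_java_endpoint():
+     java_src = '''
+ package demo;
+ import org.springframework.web.bind.annotation.*;
+
+ @RestController
+ @RequestMapping("/api")
+ public class UserController {
+     @GetMapping("/users")
+     public String listUsers() { return "ok"; }
+ }
+ '''
+     declared, _ = JavaEntityExtractor().extract_entities(java_src)
+     # Expect an entry roughly like {"name": "demo.UserController::listUsers",
+     #   "type": "api_endpoint", "endpoint": "/api/users", "methods": ["GET"]}
+     return declared
+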
515
+ class JavaScriptEntityExtractor(BaseASTEntityExtractor):
516
+ """
517
+ Extract declared and called entities from JavaScript code using esprima.
518
+ Handles ES6+ syntax including classes, arrow functions, imports/exports.
519
+ Also detects API endpoint calls (fetch, axios, etc.).
520
+ """
521
+
522
+ # Common HTTP methods to detect
523
+ HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}
524
+
525
+ # API call patterns to detect
526
+ API_PATTERNS = {
527
+ 'fetch', # fetch('/api/users')
528
+ 'axios', # axios.get('/api/users')
529
+ '$http', # Angular $http
530
+ 'request', # request library
531
+ 'superagent', # superagent library
532
+ }
533
+
534
+ def __init__(self):
535
+ self.reset()
536
+
537
+ def reset(self) -> None:
538
+ self.declared_entities: List[Dict[str, Any]] = []
539
+ self.called_entities: List[str] = []
540
+ self.scope_stack: List[str] = []
541
+ self.api_calls: List[Dict[str, Any]] = [] # Track API endpoint calls
542
+
543
+ def _qualified(self, name: str) -> str:
544
+ """Return fully qualified name using current scope stack."""
545
+ if not name:
546
+ return ""
547
+ scope = ".".join(self.scope_stack)
548
+ return f"{scope}.{name}" if scope else name
549
+
550
+ def _get_function_name(self, node) -> Optional[str]:
551
+ """Extract function name from various function node types."""
552
+ if hasattr(node, 'id') and node.id:
553
+ return node.id.name
554
+ return None
555
+
556
+ def _walk_node(self, node):
557
+ """Recursively walk the AST and extract entities."""
558
+ if not node or not hasattr(node, 'type'):
559
+ return
560
+
561
+ node_type = node.type
562
+
563
+ # --- Function Declaration ---
564
+ if node_type == 'FunctionDeclaration':
565
+ func_name = self._get_function_name(node)
566
+ if func_name:
567
+ qualified = self._qualified(func_name)
568
+ self.declared_entities.append({"name": qualified, "type": "function"})
569
+
570
+ # Extract parameters
571
+ if hasattr(node, 'params'):
572
+ for param in node.params:
573
+ param_name = self._extract_pattern_name(param)
574
+ if param_name:
575
+ self.declared_entities.append({
576
+ "name": f"{qualified}.{param_name}",
577
+ "type": "variable",
578
+ "dtype": "unknown"
579
+ })
580
+
581
+ self.scope_stack.append(func_name)
582
+ if hasattr(node, 'body'):
583
+ self._walk_node(node.body)
584
+ self.scope_stack.pop()
585
+
586
+ # --- Arrow Function Expression ---
587
+ elif node_type == 'ArrowFunctionExpression':
588
+ # Arrow functions are usually bound through the enclosing
+ # VariableDeclarator, so their parameters cannot be qualified
+ # here; only the body is walked.
593
+ if hasattr(node, 'body'):
594
+ self._walk_node(node.body)
595
+
596
+ # --- Function Expression ---
597
+ elif node_type == 'FunctionExpression':
598
+ func_name = self._get_function_name(node)
599
+ if func_name:
600
+ qualified = self._qualified(func_name)
601
+ self.declared_entities.append({"name": qualified, "type": "function"})
602
+ self.scope_stack.append(func_name)
603
+
604
+ if hasattr(node, 'params'):
605
+ for param in node.params:
606
+ param_name = self._extract_pattern_name(param)
607
+ if param_name and func_name:
608
+ self.declared_entities.append({
609
+ "name": f"{self._qualified(func_name)}.{param_name}",
610
+ "type": "variable",
611
+ "dtype": "unknown"
612
+ })
613
+
614
+ if hasattr(node, 'body'):
615
+ self._walk_node(node.body)
616
+
617
+ if func_name:
618
+ self.scope_stack.pop()
619
+
620
+ # --- Class Declaration ---
621
+ elif node_type == 'ClassDeclaration':
622
+ class_name = node.id.name if hasattr(node, 'id') and node.id else None
623
+ if class_name:
624
+ qualified = self._qualified(class_name)
625
+ self.declared_entities.append({"name": qualified, "type": "class"})
626
+
627
+ # Handle inheritance
628
+ if hasattr(node, 'superClass') and node.superClass:
629
+ if hasattr(node.superClass, 'name'):
630
+ self.called_entities.append(node.superClass.name)
631
+
632
+ self.scope_stack.append(class_name)
633
+ if hasattr(node, 'body') and hasattr(node.body, 'body'):
634
+ for method in node.body.body:
635
+ self._walk_node(method)
636
+ self.scope_stack.pop()
637
+
638
+ # --- Method Definition ---
639
+ elif node_type == 'MethodDefinition':
640
+ method_name = node.key.name if hasattr(node, 'key') and hasattr(node.key, 'name') else None
641
+ if method_name:
642
+ qualified = self._qualified(method_name)
643
+ self.declared_entities.append({"name": qualified, "type": "method"})
644
+
645
+ if hasattr(node, 'value') and hasattr(node.value, 'params'):
646
+ for param in node.value.params:
647
+ param_name = self._extract_pattern_name(param)
648
+ if param_name:
649
+ self.declared_entities.append({
650
+ "name": f"{qualified}.{param_name}",
651
+ "type": "variable",
652
+ "dtype": "unknown"
653
+ })
654
+
655
+ if hasattr(node, 'value'):
656
+ self._walk_node(node.value)
657
+
658
+ # --- Variable Declaration ---
659
+ elif node_type == 'VariableDeclaration':
660
+ if hasattr(node, 'declarations'):
661
+ for decl in node.declarations:
662
+ self._walk_node(decl)
663
+
664
+ # --- Variable Declarator ---
665
+ elif node_type == 'VariableDeclarator':
666
+ var_name = self._extract_pattern_name(node.id) if hasattr(node, 'id') else None
667
+ if var_name:
668
+ qualified = self._qualified(var_name)
669
+
670
+ # Check if it's a function assignment
671
+ if hasattr(node, 'init') and node.init:
672
+ if node.init.type in ('FunctionExpression', 'ArrowFunctionExpression'):
673
+ self.declared_entities.append({"name": qualified, "type": "function"})
674
+ self.scope_stack.append(var_name)
675
+ self._walk_node(node.init)
676
+ self.scope_stack.pop()
677
+ else:
678
+ self.declared_entities.append({
679
+ "name": qualified,
680
+ "type": "variable",
681
+ "dtype": "unknown"
682
+ })
683
+ self._walk_node(node.init)
684
+ else:
685
+ self.declared_entities.append({
686
+ "name": qualified,
687
+ "type": "variable",
688
+ "dtype": "unknown"
689
+ })
690
+
691
+ # --- Call Expression ---
692
+ elif node_type == 'CallExpression':
693
+ callee_name = self._extract_callee_name(node.callee) if hasattr(node, 'callee') else None
694
+ if callee_name:
695
+ self.called_entities.append(callee_name)
696
+
697
+ # Detect API endpoint calls
698
+ self._detect_api_call(node, callee_name)
699
+
700
+ # Walk arguments
701
+ if hasattr(node, 'arguments'):
702
+ for arg in node.arguments:
+ self._walk_node(arg)
+
+ # Also walk the callee so chained calls such as
+ # axios.get('/api/users').then(...) still expose the inner call
+ if hasattr(node, 'callee'):
+ self._walk_node(node.callee)
704
+
705
+ # --- Member Expression ---
706
+ elif node_type == 'MemberExpression':
707
+ # Don't record as call, just traverse
708
+ if hasattr(node, 'object'):
709
+ self._walk_node(node.object)
710
+ if hasattr(node, 'property'):
711
+ self._walk_node(node.property)
712
+
713
+ # --- Import/Export ---
714
+ elif node_type == 'ImportDeclaration':
715
+ if hasattr(node, 'source') and hasattr(node.source, 'value'):
716
+ self.called_entities.append(node.source.value)
717
+
718
+ elif node_type == 'ExportNamedDeclaration':
719
+ if hasattr(node, 'declaration'):
720
+ self._walk_node(node.declaration)
721
+
722
+ elif node_type == 'ExportDefaultDeclaration':
723
+ if hasattr(node, 'declaration'):
724
+ self._walk_node(node.declaration)
725
+
726
+ # --- Recursive traversal for other nodes ---
727
+ else:
728
+ if hasattr(node, '__dict__'):
729
+ for attr, val in vars(node).items():
730
+ if isinstance(val, list):
731
+ for item in val:
732
+ if hasattr(item, 'type'):
733
+ self._walk_node(item)
734
+ elif hasattr(val, 'type'):
735
+ self._walk_node(val)
736
+
737
+ def _extract_pattern_name(self, pattern) -> Optional[str]:
738
+ """Extract name from various pattern types (Identifier, ObjectPattern, etc.)."""
739
+ if not pattern:
740
+ return None
741
+ if hasattr(pattern, 'type'):
742
+ if pattern.type == 'Identifier':
743
+ return pattern.name if hasattr(pattern, 'name') else None
744
+ elif pattern.type == 'RestElement':
745
+ return self._extract_pattern_name(pattern.argument) if hasattr(pattern, 'argument') else None
746
+ return None
747
+
748
+ def _extract_callee_name(self, callee) -> Optional[str]:
749
+ """Extract the name of the function being called."""
750
+ if not callee:
751
+ return None
752
+
753
+ if hasattr(callee, 'type'):
754
+ if callee.type == 'Identifier':
755
+ return callee.name if hasattr(callee, 'name') else None
756
+ elif callee.type == 'MemberExpression':
757
+ obj = self._extract_callee_name(callee.object) if hasattr(callee, 'object') else ""
758
+ prop = callee.property.name if hasattr(callee, 'property') and hasattr(callee.property, 'name') else ""
759
+ if obj and prop:
760
+ return f"{obj}.{prop}"
761
+ return prop or obj
762
+ return None
763
+
764
+ def _detect_api_call(self, call_node, callee_name: Optional[str]) -> None:
765
+ """
766
+ Detect API endpoint calls in JavaScript code.
767
+ Handles patterns like:
768
+ - fetch('/api/users')
769
+ - axios.get('/api/users')
770
+ - axios.post('/api/users', data)
771
+ - request.get('/api/users')
772
+ """
773
+ if not callee_name or not hasattr(call_node, 'arguments'):
774
+ return
775
+
776
+ # Split callee name to check for patterns
777
+ parts = callee_name.split('.')
778
+ base = parts[0]
779
+ method = parts[-1].lower() if len(parts) > 1 else None
780
+
781
+ # Check if this is an API call
782
+ is_api_call = False
783
+ http_method = 'unknown'
784
+
785
+ # Pattern 1: fetch('/api/...')
786
+ if base == 'fetch':
787
+ is_api_call = True
788
+ http_method = 'GET' # Default for fetch
789
+
790
+ # Pattern 2: axios.get('/api/...'), request.post(...), etc.
791
+ elif base in self.API_PATTERNS and method in self.HTTP_METHODS:
792
+ is_api_call = True
793
+ http_method = method.upper()
794
+
795
+ # Pattern 3: axios('/api/...', {method: 'POST'})
796
+ elif base in self.API_PATTERNS and method is None:
797
+ is_api_call = True
798
+ http_method = 'GET' # Default
799
+
800
+ if not is_api_call:
801
+ return
802
+
803
+ # Extract the endpoint URL from arguments
804
+ if call_node.arguments:
805
+ first_arg = call_node.arguments[0]
806
+ endpoint = self._extract_string_literal(first_arg)
807
+
808
+ if endpoint:
809
+ # Store as a called entity with special type
810
+ self.called_entities.append(f"API:{http_method}:{endpoint}")
811
+
812
+ # Also track in api_calls for easier filtering
813
+ self.api_calls.append({
814
+ "endpoint": endpoint,
815
+ "method": http_method,
816
+ "type": "api_call"
817
+ })
818
+
819
+ def _extract_string_literal(self, node) -> Optional[str]:
820
+ """Extract string value from a Literal/TemplateLiteral node."""
821
+ if not node or not hasattr(node, 'type'):
822
+ return None
823
+
824
+ if node.type == 'Literal' and isinstance(node.value, str):
825
+ return node.value
826
+ elif node.type == 'TemplateLiteral':
827
+ # For template literals, we try to extract the quasi parts
828
+ # e.g., `/api/${version}/users` -> /api/{version}/users
829
+ if hasattr(node, 'quasis'):
830
+ parts = []
831
+ for i, quasi in enumerate(node.quasis):
832
+ if hasattr(quasi, 'value') and hasattr(quasi.value, 'raw'):
833
+ parts.append(quasi.value.raw)
834
+ if i < len(node.quasis) - 1:
835
+ parts.append('{param}')
836
+ return ''.join(parts)
837
+
838
+ return None
839
+
840
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
841
+ self.reset()
842
+
843
+ try:
844
+ tree = esprima.parseScript(code, {'tolerant': True, 'loc': False})
845
+ except Exception as e:
846
+ # Try parsing as module if script fails
847
+ try:
848
+ tree = esprima.parseModule(code, {'tolerant': True, 'loc': False})
849
+ except Exception as e2:
850
+ logger.error(f"Failed to parse JavaScript code: {e2}")
851
+ return [], []
852
+
853
+ if hasattr(tree, 'body'):
854
+ for node in tree.body:
855
+ self._walk_node(node)
856
+
857
+ # Deduplicate
858
+ seen_decl = set()
859
+ unique_declared = []
860
+ for e in self.declared_entities:
861
+ key = (e.get("name"), e.get("type"), e.get("dtype"))
862
+ if key not in seen_decl:
863
+ unique_declared.append(e)
864
+ seen_decl.add(key)
865
+
866
+ unique_called = list(dict.fromkeys(self.called_entities))
867
+ return unique_declared, unique_called
868
+
869
+
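+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # Shows API-call detection on an axios-style chained call; the inner call is
+ # reached via the callee walk in the CallExpression branch above. Helper name
+ # and snippet are hypothetical; parsing requires the esprima package.
+ def _demo_js_api_detection():
+     js = "axios.get('/api/users').then(handleUsers);"
+     ex = JavaScriptEntityExtractor()
+     declared, called = ex.extract_entities(js)
+     # `called` should contain 'axios.get' and the marker 'API:GET:/api/users';
+     # ex.api_calls should hold {"endpoint": "/api/users", "method": "GET", ...}
+     return ex.api_calls
+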
870
+ class CEntityExtractor(BaseASTEntityExtractor):
871
+ """
872
+ Extract declared and called entities from C code using clang.cindex (libclang),
873
+ with filtering to ignore system headers.
874
+ """
875
+
876
+ def __init__(self):
877
+ self.index = cindex.Index.create()
878
+
879
+ def reset(self) -> None:
880
+ """No persistent state to reset, but method provided for interface consistency."""
881
+ pass
882
+
883
+ def _walk_cursor(self, cursor, declared, called, source_file):
884
+ """Recursively walk a clang Cursor, restricted to the main file."""
885
+ for c in cursor.get_children():
886
+ # --- Include directives ---
887
+ # Note: INCLUSION_DIRECTIVE nodes are at the root level and need special handling
888
+ if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
889
+ # Get the included file name
890
+ included_file = c.displayname
891
+ if included_file:
892
+ called.append(included_file)
893
+ continue
894
+
895
+ loc = c.location
896
+ if not loc.file or not source_file:
897
+ continue
898
+
899
+ # Skip system / external headers for other nodes
900
+ if os.path.abspath(loc.file.name) != os.path.abspath(source_file):
901
+ continue
902
+
903
+ # --- Declarations ---
904
+ if c.kind.is_declaration():
905
+ if c.kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE):
906
+ name = c.spelling or c.displayname
907
+ declared.append({"name": name, "type": "function"})
908
+ for p in c.get_arguments():
909
+ declared.append({
910
+ "name": f"{name}.{p.spelling}",
911
+ "type": "variable",
912
+ "dtype": p.type.spelling
913
+ })
914
+ elif c.kind == cindex.CursorKind.VAR_DECL:
915
+ declared.append({
916
+ "name": c.spelling,
917
+ "type": "variable",
918
+ "dtype": c.type.spelling
919
+ })
920
+
921
+ # Add the variable's type to called entities
922
+ # This captures struct references like "struct Point p;"
923
+ if c.type.spelling:
924
+ # Extract the base type name (remove const, &, *, struct keyword, etc.)
925
+ type_name = c.type.spelling.strip()
926
+ # Remove common qualifiers and keywords
927
+ type_name = type_name.replace('const', '').replace('&', '').replace('*', '').replace('struct', '').strip()
928
+ if type_name and type_name not in {'int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed', 'size_t'}:
929
+ called.append(type_name)
930
+ elif c.kind == cindex.CursorKind.STRUCT_DECL:
931
+ declared.append({"name": c.spelling or c.displayname, "type": "struct"})
932
+ elif c.kind == cindex.CursorKind.TYPEDEF_DECL:
933
+ declared.append({"name": c.spelling, "type": "typedef"})
934
+
935
+ # --- Calls ---
936
+ if c.kind == cindex.CursorKind.CALL_EXPR:
937
+ callee = None
938
+ for child in c.get_children():
939
+ if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
940
+ callee = child.spelling
941
+ break
942
+ if callee:
943
+ called.append(callee)
944
+ else:
945
+ called.append(c.displayname or c.spelling)
946
+
947
+ # --- Recurse ---
948
+ self._walk_cursor(c, declared, called, source_file)
949
+
950
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
951
+ declared, called = [], []
952
+
953
+ # If file_path is provided, use it directly for better include resolution
954
+ # Otherwise, create a temporary file
955
+ tf_name = None
956
+ temp_file = False
957
+
958
+ if file_path and os.path.exists(file_path):
959
+ tf_name = file_path
960
+ temp_file = False
961
+ else:
962
+ with tempfile.NamedTemporaryFile(suffix=".c", mode="w+", delete=False) as tf:
963
+ tf_name = tf.name
964
+ tf.write(code)
965
+ tf.flush()
966
+ temp_file = True
967
+
968
+ # Get the directory containing the file for include paths
969
+ include_dir = os.path.dirname(tf_name) if tf_name else None
970
+ args = ['-std=c11']
971
+ if include_dir:
972
+ args.append(f'-I{include_dir}')
973
+
974
+ try:
975
+ tu = self.index.parse(
976
+ tf_name,
977
+ args=args,
978
+ options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
979
+ )
980
+ except Exception as e:
+ # Remove the temporary file before propagating the parse failure
+ if temp_file:
+ try:
+ os.unlink(tf_name)
+ except OSError:
+ pass
+ raise RuntimeError(f"libclang failed to parse translation unit: {e}")
982
+
983
+ self._walk_cursor(tu.cursor, declared, called, tf_name)
984
+
985
+ # Deduplicate
986
+ seen_decl = set()
987
+ unique_declared = []
988
+ for e in declared:
989
+ key = (e.get("name"), e.get("type"), e.get("dtype", None))
990
+ if key not in seen_decl:
991
+ unique_declared.append(e)
992
+ seen_decl.add(key)
993
+
994
+ unique_called = list(dict.fromkeys(called))
995
+
996
+ # Only delete if we created a temp file
997
+ if temp_file:
998
+ try:
999
+ os.unlink(tf_name)
1000
+ except Exception:
1001
+ pass
1002
+
1003
+ return unique_declared, unique_called
1004
+
1005
+
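+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # With no file_path, the extractor writes the snippet to a temporary .c file
+ # before parsing. Assumes libclang is installed and discoverable by clang.cindex;
+ # the helper name and snippet are hypothetical.
+ def _demo_c_extraction():
+     c_src = "int add(int a, int b) { return a + b; }\nint main(void) { return add(1, 2); }\n"
+     declared, called = CEntityExtractor().extract_entities(c_src)
+     # declared should include {"name": "add", "type": "function"} plus its
+     # parameters; called should include 'add' from the call site in main()
+     return declared, called
+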
1006
+ class CppEntityExtractor(BaseASTEntityExtractor):
1007
+ """
1008
+ Extract declared and called entities from C++ code using clang.cindex (libclang),
1009
+ including classes, namespaces, and methods.
1010
+ """
1011
+
1012
+ def __init__(self):
1013
+ self.index = cindex.Index.create()
1014
+ self.reset()
1015
+
1016
+ def reset(self) -> None:
1017
+ self.declared_entities = []
1018
+ self.called_entities = []
1019
+ self.scope_stack = []
1020
+
1021
+ def _qualified(self, name: str) -> str:
1022
+ """Return fully qualified name using current scope stack."""
1023
+ if not name:
1024
+ return ""
1025
+ if not self.scope_stack:
1026
+ return name
1027
+ return "::".join(self.scope_stack + [name])
1028
+
1029
+ def _walk_cursor(self, cursor, source_file: str):
1030
+ for c in cursor.get_children():
1031
+ # --- Include directives ---
1032
+ # Note: INCLUSION_DIRECTIVE nodes are at the root level and need special handling
1033
+ if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
1034
+ # Get the included file name
1035
+ included_file = c.displayname
1036
+ if included_file:
1037
+ self.called_entities.append(included_file)
1038
+ continue
1039
+
1040
+ kind = c.kind
1041
+
1042
+ # --- Namespace --- (process before location check)
1043
+ if kind == cindex.CursorKind.NAMESPACE:
1044
+ if c.spelling: # Only add non-empty namespace names
1045
+ self.scope_stack.append(c.spelling)
1046
+ self._walk_cursor(c, source_file)
1047
+ if c.spelling:
1048
+ self.scope_stack.pop()
1049
+ continue
1050
+
1051
+ # Check location for other node types
1052
+ loc = c.location
1053
+ # Skip nodes from other files, but allow nodes without location info
1054
+ if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file):
1055
+ continue
1056
+
1057
+ # --- Class / Struct ---
1058
+ if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL):
1059
+ # Only process if it has a name
1060
+ if c.spelling:
1061
+ # Check if it's a definition (not a forward declaration)
1062
+ is_def = c.is_definition() if hasattr(c, 'is_definition') else True
1063
+ if is_def:
1064
+ full_name = self._qualified(c.spelling)
1065
+ self.declared_entities.append({"name": full_name, "type": "class"})
1066
+
1067
+ # Handle base classes (inheritance)
1068
+ for base in c.get_children():
1069
+ if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER:
1070
+ if base.spelling:
1071
+ self.called_entities.append(base.spelling)
1072
+
1073
+ self.scope_stack.append(c.spelling)
1074
+ self._walk_cursor(c, source_file)
1075
+ self.scope_stack.pop()
1076
+ continue
1077
+
1078
+ # --- Methods ---
1079
+ if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR):
1080
+ if c.spelling: # Only process if it has a name
1081
+ full_name = self._qualified(c.spelling)
1082
+ self.declared_entities.append({"name": full_name, "type": "method"})
1083
+
1084
+ for p in c.get_arguments():
1085
+ if p.spelling: # Only add parameters with names
1086
+ self.declared_entities.append({
1087
+ "name": f"{full_name}.{p.spelling}",
1088
+ "type": "variable",
1089
+ "dtype": p.type.spelling
1090
+ })
1091
+
1092
+ self._walk_cursor(c, source_file)
1093
+ continue
1094
+
1095
+ # --- Free functions ---
1096
+ if kind == cindex.CursorKind.FUNCTION_DECL:
1097
+ if c.spelling: # Only process if it has a name
1098
+ full_name = self._qualified(c.spelling)
1099
+ self.declared_entities.append({"name": full_name, "type": "function"})
1100
+ for p in c.get_arguments():
1101
+ if p.spelling: # Only add parameters with names
1102
+ self.declared_entities.append({
1103
+ "name": f"{full_name}.{p.spelling}",
1104
+ "type": "variable",
1105
+ "dtype": p.type.spelling
1106
+ })
1107
+ self._walk_cursor(c, source_file)
1108
+ continue
1109
+
1110
+ # --- Variables ---
1111
+ if kind == cindex.CursorKind.VAR_DECL:
1112
+ full_name = self._qualified(c.spelling)
1113
+ self.declared_entities.append({
1114
+ "name": full_name,
1115
+ "type": "variable",
1116
+ "dtype": c.type.spelling
1117
+ })
1118
+
1119
+ # Look for TYPE_REF children which explicitly reference the type
1120
+ # This is more reliable than c.type.spelling when includes aren't resolved
1121
+ type_ref_found = False
1122
+ for child in c.get_children():
1123
+ if child.kind == cindex.CursorKind.TYPE_REF:
1124
+ # TYPE_REF.spelling gives us the fully qualified type name
1125
+ # It may have 'class ' or 'struct ' prefix, so strip it
1126
+ if child.spelling:
1127
+ type_name = child.spelling.replace('class ', '').replace('struct ', '').strip()
1128
+ if type_name:
1129
+ # TYPE_REF gives us the canonical name from the definition,
1130
+ # which includes namespace qualifiers if present.
1131
+ # We only add this canonical name and rely on alias resolution
1132
+ # to match unqualified usage (e.g., 'Calculator' -> 'math::Calculator')
1133
+ self.called_entities.append(type_name)
1134
+ type_ref_found = True
1135
+ break
1136
+
1137
+ # Fallback: use c.type.spelling if no TYPE_REF found
1138
+ # Note: c.type.spelling may give us the name as written in source code,
1139
+ # which could be unqualified even if it refers to a namespaced type
1140
+ if not type_ref_found and c.type.spelling:
1141
+ # Extract the base type name (remove const, &, *, etc.)
1142
+ type_name = c.type.spelling.strip()
1143
+ # Remove common qualifiers
1144
+ type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip()
1145
+ if type_name and type_name not in {'int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed'}:
1146
+ # Only add if not already added via TYPE_REF
1147
+ # c.type.spelling might give unqualified name even for namespaced types
1148
+ # We'll add it and let alias resolution handle it
1149
+ self.called_entities.append(type_name)
1150
+
1151
+ # --- Calls ---
1152
+ if kind == cindex.CursorKind.CALL_EXPR:
1153
+ callee = None
1154
+ for child in c.get_children():
1155
+ if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
1156
+ callee = child.spelling
1157
+ break
1158
+ if callee:
1159
+ self.called_entities.append(callee)
1160
+ else:
1161
+ self.called_entities.append(c.displayname or c.spelling)
1162
+
1163
+ # Recurse
1164
+ self._walk_cursor(c, source_file)
1165
+
1166
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
1167
+ self.reset()
1168
+
1169
+ # If file_path is provided, use it directly for better include resolution
1170
+ # Otherwise, create a temporary file
1171
+ tf_name = None
1172
+ temp_file = False
1173
+
1174
+ if file_path and os.path.exists(file_path):
1175
+ tf_name = file_path
1176
+ temp_file = False
1177
+ else:
1178
+ with tempfile.NamedTemporaryFile(suffix=".cpp", mode="w+", delete=False) as tf:
1179
+ tf_name = tf.name
1180
+ tf.write(code)
1181
+ tf.flush()
1182
+ temp_file = True
1183
+
1184
+ # Get the directory containing the file for include paths
1185
+ include_dir = os.path.dirname(tf_name) if tf_name else None
1186
+ args = ['-std=c++17', '-xc++']
1187
+ if include_dir:
1188
+ args.append(f'-I{include_dir}')
1189
+
1190
+ try:
1191
+ tu = self.index.parse(
1192
+ tf_name,
1193
+ args=args,
1194
+ options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
1195
+ )
1196
+ except Exception as e:
+ # Remove the temporary file before propagating the parse failure
+ if temp_file:
+ try:
+ os.unlink(tf_name)
+ except OSError:
+ pass
+ raise RuntimeError(f"libclang failed to parse C++ translation unit: {e}")
1198
+
1199
+ self._walk_cursor(tu.cursor, tf_name)
1200
+
1201
+ # Deduplicate
1202
+ seen_decl = set()
1203
+ unique_declared = []
1204
+ for e in self.declared_entities:
1205
+ key = (e.get("name"), e.get("type"), e.get("dtype", None))
1206
+ if key not in seen_decl:
1207
+ unique_declared.append(e)
1208
+ seen_decl.add(key)
1209
+
1210
+ unique_called = list(dict.fromkeys(self.called_entities))
1211
+
1212
+ # Only delete if we created a temp file
1213
+ if temp_file:
1214
+ try:
1215
+ os.unlink(tf_name)
1216
+ except Exception:
1217
+ pass
1218
+
1219
+ return unique_declared, unique_called
1220
+
1221
+
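+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # Namespaces and classes qualify declared names with '::'. Assumes libclang
+ # is available; the helper name and snippet are hypothetical.
+ def _demo_cpp_extraction():
+     cpp_src = "namespace math { class Calculator { public: int add(int a, int b); }; }\n"
+     declared, _ = CppEntityExtractor().extract_entities(cpp_src)
+     # Expect entries such as {"name": "math::Calculator", "type": "class"} and
+     # {"name": "math::Calculator::add", "type": "method"}
+     return declared
+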
1222
+ class RustEntityExtractor(BaseASTEntityExtractor):
1223
+ """
1224
+ Extract declared and called entities from Rust code using tree-sitter.
1225
+ Handles structs, enums, traits, functions, methods, and modules.
1226
+ Also detects API endpoint definitions (Actix-web, Rocket, Axum, Warp).
1227
+ """
1228
+
1229
+ # HTTP method route macros for Rust web frameworks
1230
+ ROUTE_MACROS = {
1231
+ 'get', 'post', 'put', 'patch', 'delete', 'head', 'options', # Actix-web, Rocket
1232
+ 'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options', # Alternative casing
1233
+ }
1234
+
1235
+ # Route-related macros and functions
1236
+ ROUTE_PATTERNS = {
1237
+ 'route', # Generic route macro
1238
+ 'web::get', 'web::post', 'web::put', 'web::delete', # Actix-web with web::
1239
+ }
1240
+
1241
+ def __init__(self):
1242
+
1243
+ self.parser = Parser()
1244
+ self.parser.language = Language(ts_rust.language())
1245
+ self.reset()
1246
+
1247
+ def reset(self) -> None:
1248
+ self.declared_entities = []
1249
+ self.called_entities = []
1250
+ self.scope_stack = []
1251
+ self.api_endpoints: List[Dict[str, Any]] = [] # Track API endpoint definitions
1252
+
1253
+ def _qualified(self, name: str) -> str:
1254
+ """Return fully qualified name using current scope stack."""
1255
+ if not name:
1256
+ return ""
1257
+ if not self.scope_stack:
1258
+ return name
1259
+ return "::".join(self.scope_stack + [name])
1260
+
1261
+ def _get_node_text(self, node, code_bytes: bytes) -> str:
1262
+ """Extract text content of a node."""
1263
+ return code_bytes[node.start_byte:node.end_byte].decode('utf8')
1264
+
1265
+ def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
1266
+ """
1267
+ Extract API endpoint information from Rust function attributes.
1268
+ Handles patterns like:
1269
+ - #[get("/users")] # Actix-web, Rocket
1270
+ - #[post("/users")] # Actix-web, Rocket
1271
+ - #[route("/users", method="GET")] # Generic route
1272
+
1273
+ Note: In tree-sitter Rust AST, attributes appear as PREVIOUS SIBLINGS
1274
+ of the function_item node, not as children.
1275
+ """
1276
+
1278
+ # Get the parent node to access siblings
1279
+ parent = node.parent
1280
+ if not parent:
1281
+ return None
1282
+
1283
+ # Find the index of current node in parent's children
1284
+ node_index = None
1285
+ for i, child in enumerate(parent.children):
1286
+ if child == node:
1287
+ node_index = i
1288
+ break
1289
+
1290
+ if node_index is None:
1291
+ return None
1292
+
1293
+ # Look backwards through previous siblings for attribute_item nodes
1294
+ for i in range(node_index - 1, -1, -1):
1295
+ sibling = parent.children[i]
1296
+
1297
+ # Stop if we hit a non-attribute node (except comments/whitespace)
1298
+ if sibling.type not in ['attribute_item', 'line_comment', 'block_comment']:
1299
+ break
1300
+
1301
+ if sibling.type == 'attribute_item':
1302
+ attr_text = self._get_node_text(sibling, code_bytes)
1303
+
1304
+ # Match HTTP method macros: #[get("/path")], #[post("/path")], #[post("/path", data = "<var>")], etc.
1305
+ # The pattern now allows optional additional parameters after the path
1306
+ method_pattern = r'#\[(get|post|put|patch|delete|head|options)\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]'
1307
+ match = re.search(method_pattern, attr_text, re.IGNORECASE)
1308
+
1309
+ if match:
1310
+ http_method = match.group(1).upper()
1311
+ endpoint_path = match.group(2)
1312
+ return {
1313
+ "endpoint": endpoint_path,
1314
+ "methods": [http_method],
1315
+ "type": "api_endpoint_definition"
1316
+ }
1317
+
1318
+ # Match generic route macro: #[route("/path", method="GET")]
1319
+ route_pattern = r'#\[route\s*\(\s*"([^"]+)"(?:.*?method\s*=\s*"([^"]+)")?\s*\)\]'
1320
+ match = re.search(route_pattern, attr_text, re.IGNORECASE)
1321
+
1322
+ if match:
1323
+ endpoint_path = match.group(1)
1324
+ http_method = match.group(2).upper() if match.group(2) else "GET"
1325
+ return {
1326
+ "endpoint": endpoint_path,
1327
+ "methods": [http_method],
1328
+ "type": "api_endpoint_definition"
1329
+ }
1330
+
1331
+ return None
1332
+
1333
+ def _walk_tree(self, node, code_bytes: bytes):
1334
+ """Recursively walk the tree-sitter AST."""
1335
+ node_type = node.type
1336
+
1337
+ # --- Module declarations ---
1338
+ if node_type == 'mod_item':
1339
+ # mod my_module { ... }
1340
+ name_node = node.child_by_field_name('name')
1341
+ if name_node:
1342
+ mod_name = self._get_node_text(name_node, code_bytes)
1343
+ qualified = self._qualified(mod_name)
1344
+ self.declared_entities.append({"name": qualified, "type": "module"})
1345
+
1346
+ self.scope_stack.append(mod_name)
1347
+ body = node.child_by_field_name('body')
1348
+ if body:
1349
+ for child in body.children:
1350
+ self._walk_tree(child, code_bytes)
1351
+ self.scope_stack.pop()
1352
+ return
1353
+
1354
+ # --- Struct declarations ---
1355
+ elif node_type == 'struct_item':
1356
+ name_node = node.child_by_field_name('name')
1357
+ if name_node:
1358
+ struct_name = self._get_node_text(name_node, code_bytes)
1359
+ qualified = self._qualified(struct_name)
1360
+ self.declared_entities.append({"name": qualified, "type": "struct"})
1361
+
1362
+ # Check for generic parameters
1363
+ type_params = node.child_by_field_name('type_parameters')
1364
+ if type_params:
1365
+ self._walk_tree(type_params, code_bytes)
1366
+
1367
+ self.scope_stack.append(struct_name)
1368
+ # Process fields
1369
+ body = node.child_by_field_name('body')
1370
+ if body:
1371
+ for child in body.children:
1372
+ if child.type == 'field_declaration':
1373
+ field_name_node = child.child_by_field_name('name')
1374
+ field_type_node = child.child_by_field_name('type')
1375
+ if field_name_node:
1376
+ field_name = self._get_node_text(field_name_node, code_bytes)
1377
+ field_type = self._get_node_text(field_type_node, code_bytes) if field_type_node else "unknown"
1378
+ self.declared_entities.append({
1379
+ "name": f"{qualified}.{field_name}",
1380
+ "type": "field",
1381
+ "dtype": field_type
1382
+ })
1383
+ self.scope_stack.pop()
1384
+ return
1385
+
1386
+ # --- Enum declarations ---
1387
+ elif node_type == 'enum_item':
1388
+ name_node = node.child_by_field_name('name')
1389
+ if name_node:
1390
+ enum_name = self._get_node_text(name_node, code_bytes)
1391
+ qualified = self._qualified(enum_name)
1392
+ self.declared_entities.append({"name": qualified, "type": "enum"})
1393
+
1394
+ self.scope_stack.append(enum_name)
1395
+ body = node.child_by_field_name('body')
1396
+ if body:
1397
+ for child in body.children:
1398
+ if child.type == 'enum_variant':
1399
+ variant_name_node = child.child_by_field_name('name')
1400
+ if variant_name_node:
1401
+ variant_name = self._get_node_text(variant_name_node, code_bytes)
1402
+ self.declared_entities.append({
1403
+ "name": f"{qualified}::{variant_name}",
1404
+ "type": "enum_variant"
1405
+ })
1406
+ self.scope_stack.pop()
1407
+ return
1408
+
1409
+ # --- Trait declarations ---
1410
+ elif node_type == 'trait_item':
1411
+ name_node = node.child_by_field_name('name')
1412
+ if name_node:
1413
+ trait_name = self._get_node_text(name_node, code_bytes)
1414
+ qualified = self._qualified(trait_name)
1415
+ self.declared_entities.append({"name": qualified, "type": "trait"})
1416
+
1417
+ self.scope_stack.append(trait_name)
1418
+ body = node.child_by_field_name('body')
1419
+ if body:
1420
+ for child in body.children:
1421
+ self._walk_tree(child, code_bytes)
1422
+ self.scope_stack.pop()
1423
+ return
1424
+
1425
+ # --- Implementation blocks ---
1426
+ elif node_type == 'impl_item':
1427
+ # impl MyStruct { ... } or impl Trait for MyStruct { ... }
1428
+ type_node = node.child_by_field_name('type')
1429
+ trait_node = node.child_by_field_name('trait')
1430
+
1431
+ impl_name = None
1432
+ if type_node:
1433
+ impl_name = self._get_node_text(type_node, code_bytes)
1434
+
1435
+ if trait_node:
1436
+ trait_name = self._get_node_text(trait_node, code_bytes)
1437
+ self.called_entities.append(trait_name)
1438
+
1439
+ if impl_name:
1440
+ self.scope_stack.append(impl_name)
1441
+
1442
+ body = node.child_by_field_name('body')
1443
+ if body:
1444
+ for child in body.children:
1445
+ self._walk_tree(child, code_bytes)
1446
+
1447
+ if impl_name:
1448
+ self.scope_stack.pop()
1449
+ return
1450
+
1451
+ # --- Function declarations ---
1452
+ elif node_type == 'function_item':
1453
+ name_node = node.child_by_field_name('name')
1454
+ if name_node:
1455
+ func_name = self._get_node_text(name_node, code_bytes)
1456
+ qualified = self._qualified(func_name)
1457
+
1458
+ # Check for API endpoint attributes (e.g., #[get("/users")])
1459
+ api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)
1460
+
1461
+ if api_info:
1462
+ # This is an API endpoint handler
1463
+ self.declared_entities.append({
1464
+ "name": qualified,
1465
+ "type": "api_endpoint",
1466
+ "endpoint": api_info.get("endpoint"),
1467
+ "methods": api_info.get("methods")
1468
+ })
1469
+ self.api_endpoints.append({**api_info, "function": qualified})
1470
+ entity_type = "api_endpoint"
1471
+ else:
1472
+ # Determine if this is a method (inside impl block) or free function
1473
+ entity_type = "method" if len(self.scope_stack) > 0 else "function"
1474
+ self.declared_entities.append({"name": qualified, "type": entity_type})
1475
+
1476
+ # Extract parameters
1477
+ params = node.child_by_field_name('parameters')
1478
+ if params:
1479
+ for child in params.children:
1480
+ if child.type == 'parameter':
1481
+ pattern = child.child_by_field_name('pattern')
1482
+ type_node = child.child_by_field_name('type')
1483
+ if pattern:
1484
+ param_name = self._get_node_text(pattern, code_bytes)
1485
+ param_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
1486
+ # Skip 'self' parameters
1487
+ if param_name not in ['self', '&self', '&mut self', 'mut self']:
1488
+ self.declared_entities.append({
1489
+ "name": f"{qualified}.{param_name}",
1490
+ "type": "variable",
1491
+ "dtype": param_type
1492
+ })
1493
+
1494
+ # Walk the function body to find calls
1495
+ body = node.child_by_field_name('body')
1496
+ if body:
1497
+ self._walk_tree(body, code_bytes)
1498
+ return
1499
+
1500
+ # --- Type alias ---
1501
+ elif node_type == 'type_item':
1502
+ name_node = node.child_by_field_name('name')
1503
+ if name_node:
1504
+ type_name = self._get_node_text(name_node, code_bytes)
1505
+ qualified = self._qualified(type_name)
1506
+ self.declared_entities.append({"name": qualified, "type": "type_alias"})
1507
+ return
1508
+
1509
+ # --- Constant declarations ---
1510
+ elif node_type == 'const_item':
1511
+ name_node = node.child_by_field_name('name')
1512
+ type_node = node.child_by_field_name('type')
1513
+ if name_node:
1514
+ const_name = self._get_node_text(name_node, code_bytes)
1515
+ const_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
1516
+ qualified = self._qualified(const_name)
1517
+ self.declared_entities.append({
1518
+ "name": qualified,
1519
+ "type": "constant",
1520
+ "dtype": const_type
1521
+ })
1522
+
1523
+ # --- Static declarations ---
1524
+ elif node_type == 'static_item':
1525
+ name_node = node.child_by_field_name('name')
1526
+ type_node = node.child_by_field_name('type')
1527
+ if name_node:
1528
+ static_name = self._get_node_text(name_node, code_bytes)
1529
+ static_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
1530
+ qualified = self._qualified(static_name)
1531
+ self.declared_entities.append({
1532
+ "name": qualified,
1533
+ "type": "static",
1534
+ "dtype": static_type
1535
+ })
1536
+
1537
+ # --- Let bindings (local variables) ---
1538
+ elif node_type == 'let_declaration':
1539
+ # Function-local `let` bindings are intentionally skipped to avoid
+ # clutter; module-level items (const/static) are tracked instead.
+ pass
1546
+
1547
+ # --- Use declarations (imports) ---
1548
+ elif node_type == 'use_declaration':
1549
+ # Extract imported items
1550
+ use_text = self._get_node_text(node, code_bytes)
1551
+ self.called_entities.append(use_text)
1552
+
1553
+ # --- Call expressions ---
1554
+ elif node_type == 'call_expression':
1555
+ function = node.child_by_field_name('function')
1556
+ if function:
1557
+ func_text = self._get_node_text(function, code_bytes)
1558
+ # Clean up function call to get just the name/path
1559
+ # Handle method calls like obj.method() and path calls like std::vec::Vec::new()
1560
+ self.called_entities.append(func_text)
1561
+
1562
+ # --- Macro invocations ---
1563
+ elif node_type == 'macro_invocation':
1564
+ macro_node = node.child_by_field_name('macro')
1565
+ if macro_node:
1566
+ macro_name = self._get_node_text(macro_node, code_bytes)
1567
+ self.called_entities.append(f"{macro_name}!")
1568
+
1569
+ # --- Field expressions (method calls or field access) ---
1570
+ elif node_type == 'field_expression':
1571
+ # A field access or method call; without receiver type information it
+ # cannot be resolved reliably, so nothing is recorded here. The generic
+ # recursion below still walks the children.
+ pass
1576
+
1577
+ # Recursively walk all children
1578
+ for child in node.children:
1579
+ self._walk_tree(child, code_bytes)
1580
+
1581
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
1582
+ """Extract entities from Rust code using tree-sitter."""
1583
+ self.reset()
1584
+
1585
+ code_bytes = code.encode('utf8')
1586
+ tree = self.parser.parse(code_bytes)
1587
+
1588
+ # Walk the AST
1589
+ self._walk_tree(tree.root_node, code_bytes)
1590
+
1591
+ # Deduplicate
1592
+ seen_decl = set()
1593
+ unique_declared = []
1594
+ for e in self.declared_entities:
1595
+ key = (e.get("name"), e.get("type"), e.get("dtype", None))
1596
+ if key not in seen_decl:
1597
+ unique_declared.append(e)
1598
+ seen_decl.add(key)
1599
+
1600
+ unique_called = list(dict.fromkeys(self.called_entities))
1601
+
1602
+ return unique_declared, unique_called
1603
+
1604
+
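+ # --- Illustrative usage (hedged sketch, not part of the module API) ---
+ # An Actix/Rocket-style route attribute is read from the attribute_item
+ # sibling preceding the function_item. Requires the tree-sitter and
+ # tree-sitter-rust packages; helper name and snippet are hypothetical.
+ def _demo_rust_endpoint():
+     rust_src = '#[get("/users")]\nfn list_users() -> String { String::new() }\n'
+     ex = RustEntityExtractor()
+     ex.extract_entities(rust_src)
+     # ex.api_endpoints should hold roughly:
+     # [{"endpoint": "/users", "methods": ["GET"],
+     #   "type": "api_endpoint_definition", "function": "list_users"}]
+     return ex.api_endpoints
+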
1605
+ class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
1606
+ """
1607
+ AST-based entity extractor for Python code.
1608
+ Also detects API endpoint definitions (FastAPI, Flask, Django REST Framework).
1609
+ """
1610
+
1611
+ # Common HTTP decorators/patterns for Python web frameworks
1612
+ API_DECORATORS = {
1613
+ 'route', # Flask @app.route
1614
+ 'get', 'post', 'put', 'patch', 'delete', 'head', 'options', # FastAPI/Flask methods
1615
+ 'api_view', # DRF @api_view
1616
+ }
1617
+
1618
+ def __init__(self):
1619
+ self.declared_entities: List[Dict[str, Any]] = []
1620
+ self.called_entities: List[str] = []
1621
+ self.current_class: Optional[str] = None
1622
+ self.current_function: Optional[str] = None
1623
+ self.api_endpoints: List[Dict[str, Any]] = [] # Track API endpoint definitions
1624
+
1625
+ def reset(self) -> None:
1626
+ """Clear previous extraction state including context"""
1627
+ self.declared_entities = []
1628
+ self.called_entities = []
1629
+ self.current_class = None
1630
+ self.current_function = None
1631
+ self.api_endpoints = []
1632
+
1633
+ def _get_type_annotation(self, node: ast.AST) -> str:
1634
+ """Extract type annotation from AST node"""
1635
+ if isinstance(node, ast.Name):
1636
+ return node.id
1637
+ elif isinstance(node, ast.Constant):
1638
+ return type(node.value).__name__
1639
+ elif isinstance(node, ast.Attribute):
1640
+ return f"{self._get_type_annotation(node.value)}.{node.attr}"
1641
+ elif isinstance(node, ast.Subscript):
1642
+ # Handle generic types like List[str], Dict[str, int]
1643
+ base = self._get_type_annotation(node.value)
1644
+ if isinstance(node.slice, ast.Tuple):
1645
+ args = [self._get_type_annotation(elt) for elt in node.slice.elts]
1646
+ return f"{base}[{', '.join(args)}]"
1647
+ else:
1648
+ arg = self._get_type_annotation(node.slice)
1649
+ return f"{base}[{arg}]"
1650
+ return "unknown"
1651
+
1652
+ def _infer_type_from_value(self, node: ast.AST) -> str:
1653
+ """Infer type from assigned value"""
1654
+ if isinstance(node, ast.Constant):
1655
+ return type(node.value).__name__
1656
+ elif isinstance(node, ast.List):
1657
+ return "list"
1658
+ elif isinstance(node, ast.Dict):
1659
+ return "dict"
1660
+ elif isinstance(node, ast.Set):
1661
+ return "set"
1662
+ elif isinstance(node, ast.Tuple):
1663
+ return "tuple"
1664
+ elif isinstance(node, ast.Call):
1665
+ if isinstance(node.func, ast.Name):
1666
+ return node.func.id # Constructor call
1667
+ elif isinstance(node.func, ast.Attribute):
1668
+ return "unknown"
1669
+ elif isinstance(node, ast.Name):
1670
+ return "unknown" # Reference to another variable
1671
+ return "unknown"
1672
+
1673
+ def visit_ClassDef(self, node: ast.ClassDef):
1674
+ """Visit class definitions"""
1675
+ old_class = self.current_class
1676
+ self.current_class = node.name
1677
+
1678
+ # Add class to declared entities
1679
+ self.declared_entities.append({
1680
+ "name": node.name,
1681
+ "type": "class"
1682
+ })
1683
+
1684
+ # Record base classes as called entities
1685
+ for base in node.bases:
1686
+ if isinstance(base, ast.Name):
1687
+ self.called_entities.append(base.id)
1688
+ elif isinstance(base, ast.Attribute):
1689
+ self.called_entities.append(self._get_type_annotation(base))
1690
+
1691
+ # Continue visiting child nodes
1692
+ self.generic_visit(node)
1693
+ self.current_class = old_class
1694
+
1695
+ def visit_FunctionDef(self, node: ast.FunctionDef):
1696
+ """Visit function/method definitions and detect API endpoints"""
1697
+ old_function = self.current_function
1698
+
1699
+ if self.current_class:
1700
+ # This is a method
1701
+ full_name = f"{self.current_class}.{node.name}"
1702
+ entity_type = "method"
1703
+ else:
1704
+ # This is a function
1705
+ full_name = node.name
1706
+ entity_type = "function"
1707
+
1708
+ self.current_function = full_name
1709
+
1710
+ # Check for API endpoint decorators
1711
+ api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
1712
+ if api_info:
1713
+ # Mark this as an API endpoint
1714
+ self.declared_entities.append({
1715
+ "name": full_name,
1716
+ "type": "api_endpoint",
1717
+ "endpoint": api_info.get("endpoint"),
1718
+ "methods": api_info.get("methods")
1719
+ })
1720
+ self.api_endpoints.append(api_info)
1721
+ else:
1722
+ self.declared_entities.append({
1723
+ "name": full_name,
1724
+ "type": entity_type
1725
+ })
1726
+
1727
+ # Process parameters
1728
+ for arg in node.args.args:
1729
+ if arg.arg == 'self' and self.current_class:
1730
+ continue # Skip self parameter
1731
+
1732
+ dtype = "unknown"
1733
+ if arg.annotation:
1734
+ dtype = self._get_type_annotation(arg.annotation)
1735
+
1736
+ param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
1737
+ self.declared_entities.append({
1738
+ "name": param_name,
1739
+ "type": "variable",
1740
+ "dtype": dtype
1741
+ })
1742
+
1743
+ # Continue visiting child nodes
1744
+ self.generic_visit(node)
1745
+ self.current_function = old_function
1746
+
1747
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
1748
+ """Visit async function/method definitions"""
1749
+ # Treat async functions the same as regular functions
1750
+ self.visit_FunctionDef(node)
1751
+
1752
+ def visit_Assign(self, node: ast.Assign):
1753
+ """Visit assignment statements"""
1754
+ # Infer type from the assigned value
1755
+ dtype = self._infer_type_from_value(node.value)
1756
+
1757
+ for target in node.targets:
1758
+ if isinstance(target, ast.Name):
1759
+ # Simple variable assignment
1760
+ var_name = target.id
1761
+ if self.current_class and self.current_function and self.current_function.startswith(self.current_class):
1762
+ # Local variable in method
1763
+ pass # Could add local variables if needed
1764
+ else:
1765
+ # Module-level variable
1766
+ self.declared_entities.append({
1767
+ "name": var_name,
1768
+ "type": "variable",
1769
+ "dtype": dtype
1770
+ })
1771
+
1772
+ elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
1773
+ # Attribute assignment like self.name = value
1774
+ if target.value.id == 'self' and self.current_class:
1775
+ attr_name = f"{self.current_class}.{target.attr}"
1776
+ self.declared_entities.append({
1777
+ "name": attr_name,
1778
+ "type": "variable",
1779
+ "dtype": dtype
1780
+ })
1781
+
1782
+ # Continue visiting to catch function calls in the assignment
1783
+ self.generic_visit(node)
1784
+
1785
+ def visit_AnnAssign(self, node: ast.AnnAssign):
1786
+ """Visit annotated assignment statements (PEP 526)"""
1787
+ if isinstance(node.target, ast.Name):
1788
+ dtype = self._get_type_annotation(node.annotation)
1789
+ var_name = node.target.id
1790
+
1791
+ self.declared_entities.append({
1792
+ "name": var_name,
1793
+ "type": "variable",
1794
+ "dtype": dtype
1795
+ })
1796
+
1797
+ elif isinstance(node.target, ast.Attribute) and isinstance(node.target.value, ast.Name):
1798
+ if node.target.value.id == 'self' and self.current_class:
1799
+ dtype = self._get_type_annotation(node.annotation)
1800
+ attr_name = f"{self.current_class}.{node.target.attr}"
1801
+ self.declared_entities.append({
1802
+ "name": attr_name,
1803
+ "type": "variable",
1804
+ "dtype": dtype
1805
+ })
1806
+
1807
+ # Continue visiting
1808
+ if node.value:
1809
+ self.generic_visit(node)
1810
+
1811
+ def visit_Import(self, node: ast.Import):
1812
+ """Visit import statements"""
1813
+ for alias in node.names:
1814
+ # Record the imported module/package
1815
+ self.called_entities.append(alias.name)
1816
+ self.generic_visit(node)
1817
+
1818
+ def visit_ImportFrom(self, node: ast.ImportFrom):
1819
+ """Visit from...import statements"""
1820
+ if node.module:
1821
+ # Record the module being imported from
1822
+ self.called_entities.append(node.module)
1823
+ # Optionally, also record specific imports as module.name
1824
+ for alias in node.names:
1825
+ if alias.name != '*':
1826
+ self.called_entities.append(f"{node.module}.{alias.name}")
1827
+ else:
1828
+ # Relative imports without module (from . import x)
1829
+ for alias in node.names:
1830
+ if alias.name != '*':
1831
+ self.called_entities.append(alias.name)
1832
+ self.generic_visit(node)
1833
+
1834
+ def visit_Call(self, node: ast.Call):
1835
+ """Visit function/method calls"""
1836
+ if isinstance(node.func, ast.Name):
1837
+ # Simple function call
1838
+ self.called_entities.append(node.func.id)
1839
+
1840
+ elif isinstance(node.func, ast.Attribute):
1841
+ # Method call or attribute access
1842
+ if isinstance(node.func.value, ast.Name):
1843
+ # obj.method() - we need to infer the class of obj
1844
+ # For now, just record the method name
1845
+ method_name = node.func.attr
1846
+ # Try to find the variable type from our declared entities
1847
+ obj_name = node.func.value.id
1848
+ obj_class = self._find_variable_type(obj_name)
1849
+ if obj_class and obj_class != "unknown":
1850
+ self.called_entities.append(f"{obj_class}.{method_name}")
1851
+ else:
1852
+ # Fallback: just record the method call
1853
+ self.called_entities.append(method_name)
1854
+
1855
+ elif isinstance(node.func.value, ast.Attribute):
1856
+ # Nested attribute access like module.Class.method()
1857
+ full_name = self._get_type_annotation(node.func)
1858
+ self.called_entities.append(full_name)
1859
+
1860
+ # Continue visiting child nodes
1861
+ self.generic_visit(node)
1862
+
1863
+ def _find_variable_type(self, var_name: str) -> str:
1864
+ """Find the type of a variable from declared entities"""
1865
+ for entity in self.declared_entities:
1866
+ if entity["name"] == var_name and entity["type"] == "variable":
1867
+ return entity.get("dtype", "unknown")
1868
+ return "unknown"
1869
+
1870
+ def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
1871
+ """
1872
+ Extract API endpoint information from function decorators.
1873
+ Handles patterns like:
1874
+ - @app.route("/api/users", methods=["GET", "POST"]) # Flask
1875
+ - @app.get("/api/users") # FastAPI
1876
+ - @router.post("/api/users") # FastAPI with router
1877
+ - @api_view(['GET', 'POST']) # Django REST Framework
1878
+ """
1879
+ for decorator in decorators:
1880
+ # Handle @app.route(...) or @app.get(...)
1881
+ if isinstance(decorator, ast.Call):
1882
+ if isinstance(decorator.func, ast.Attribute):
1883
+ # e.g., app.route, app.get, router.post
1884
+ method_name = decorator.func.attr.lower()
1885
+
1886
+ if method_name in self.API_DECORATORS:
1887
+ endpoint = None
1888
+ http_methods = []
1889
+
1890
+ # Extract endpoint from first positional argument
1891
+ if decorator.args and isinstance(decorator.args[0], ast.Constant):
1892
+ endpoint = decorator.args[0].value
1893
+
1894
+ # For FastAPI-style decorators (@app.get, @app.post)
1895
+ if method_name in {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}:
1896
+ http_methods = [method_name.upper()]
1897
+
1898
+ # For Flask-style @app.route with methods kwarg
1899
+ elif method_name == 'route':
1900
+ for keyword in decorator.keywords:
1901
+ if keyword.arg == 'methods':
1902
+ if isinstance(keyword.value, ast.List):
1903
+ http_methods = [
1904
+ elt.value for elt in keyword.value.elts
1905
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
1906
+ ]
1907
+ if not http_methods:
1908
+ http_methods = ['GET'] # Flask default
1909
+
1910
+ # For DRF @api_view(['GET', 'POST'])
1911
+ elif method_name == 'api_view':
1912
+ if decorator.args and isinstance(decorator.args[0], ast.List):
1913
+ http_methods = [
1914
+ elt.value for elt in decorator.args[0].elts
1915
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
1916
+ ]
1917
+
1918
+ if endpoint:
1919
+ return {
1920
+ "function": function_name,
1921
+ "endpoint": endpoint,
1922
+ "methods": http_methods,
1923
+ "type": "api_endpoint_definition"
1924
+ }
1925
+
1926
+ return None
1927
+
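As a quick illustration of the decorator shapes this method recognizes, a hedged sketch; the snippet below is hypothetical and only parsed, never executed, so the undefined `app` name is harmless:

    import textwrap

    sample = textwrap.dedent('''
        @app.get("/api/users")
        def list_users():
            return []
    ''')
    ex = PythonASTEntityExtractor()
    ex.extract_entities(sample)
    # ex.api_endpoints would then contain
    # {"function": "list_users", "endpoint": "/api/users",
    #  "methods": ["GET"], "type": "api_endpoint_definition"}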
1928
+ def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
1929
+ """
1930
+ Extract entities from Python code using AST parsing
1931
+
1932
+ Args:
1933
+ code: Python source code as string
1934
+ file_path: Optional path to the source file (for context)
1935
+
1936
+ Returns:
1937
+ Tuple of (declared_entities, called_entities)
1938
+ """
1939
+ # Ensure fresh state on each extraction
1940
+ self.reset()
1941
+
1942
+ try:
1943
+ tree = ast.parse(code)
1944
+ self.visit(tree)
1945
+
1946
+ # Remove duplicates while preserving order
1947
+ seen_declared = set()
1948
+ unique_declared = []
1949
+ for entity in self.declared_entities:
1950
+ key = (entity["name"], entity["type"], entity.get("dtype"))
1951
+ if key not in seen_declared:
1952
+ unique_declared.append(entity)
1953
+ seen_declared.add(key)
1954
+
1955
+ unique_called = list(dict.fromkeys(self.called_entities)) # Remove duplicates
1956
+
1957
+ return unique_declared, unique_called
1958
+
1959
+ except SyntaxError as e:
1960
+ logger.error(f"Syntax error in Python code: {e}")
1961
+ return [], []
1962
+ except Exception as e:
1963
+ logger.error(f"Error parsing Python code: {e}", exc_info=True)
1964
+ return [], []
1965
+
1966
+
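A broader usage sketch for the extractor as a whole, assuming only this module and the standard library:

    import textwrap

    src = textwrap.dedent('''
        import os

        class Greeter:
            def greet(self, name: str) -> str:
                return f"hello {name}"

        g = Greeter()
        g.greet("world")
    ''')
    declared, called = PythonASTEntityExtractor().extract_entities(src)
    # declared includes {"name": "Greeter", "type": "class"},
    # {"name": "Greeter.greet", "type": "method"} and the annotated parameter
    # {"name": "Greeter.greet.name", "type": "variable", "dtype": "str"};
    # called includes "os", "Greeter" and "Greeter.greet" (the receiver `g`
    # is resolved to Greeter through its inferred dtype)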
1967
+ class HybridEntityExtractor:
1968
+ """
1969
+ Hybrid entity extractor that dispatches to AST-based extractors for known languages;
1970
+ unsupported languages raise so the caller can fall back to LLM extraction
1971
+ """
1972
+
1973
+ def __init__(self):
1974
+ self.extractors = {
1975
+ 'py': PythonASTEntityExtractor(),
1976
+ 'c': CEntityExtractor(),
1977
+ 'h': CppEntityExtractor(), # C/C++ headers
1978
+ 'cpp': CppEntityExtractor(),
1979
+ 'cc': CppEntityExtractor(),
1980
+ 'cxx': CppEntityExtractor(),
1981
+ 'hpp': CppEntityExtractor(),
1982
+ 'hxx': CppEntityExtractor(),
1983
+ 'hh': CppEntityExtractor(),
1984
+ 'java': JavaEntityExtractor(),
1985
+ 'js': JavaScriptEntityExtractor(), # ✅ NEW
1986
+ 'jsx': JavaScriptEntityExtractor(), # ✅ NEW
1987
+ 'ts': JavaScriptEntityExtractor(), # TypeScript uses similar AST
1988
+ 'tsx': JavaScriptEntityExtractor(), # TSX similar to JSX
1989
+ 'rs': RustEntityExtractor(),
1990
+ 'html': HTMLEntityExtractor()
1991
+ }
1992
+
1993
+ def _get_language_from_filename(self, file_name: str) -> str:
1994
+ ext = file_name.split('.')[-1].lower()
1995
+ return ext
1996
+
1997
+ def extract_entities(self, code: str, file_name: str):
1998
+
1999
+ lang = self._get_language_from_filename(file_name)
2000
+ extractor = self.extractors.get(lang)
2001
+
2002
+ if extractor:
2003
+ # Reset the shared extractor instance to ensure no state is carried over
2004
+ try:
2005
+ extractor.reset()
2006
+ except Exception:
2007
+ # If extractor doesn't implement reset for some reason, ignore and proceed
2008
+ pass
2009
+
2010
+ logger.info(f"Using AST extraction for {lang.upper()} file: {file_name}")
2011
+ try:
2012
+ # Try to pass file_name if the extractor supports it (C++ extractor does)
2013
+ try:
2014
+ declared_entities, called_entities = extractor.extract_entities(code, file_path=file_name)
2015
+ except TypeError:
2016
+ # Fallback for extractors that don't accept file_path parameter
2017
+ declared_entities, called_entities = extractor.extract_entities(code)
2018
+
2019
+ # Add aliases to each declared entity based on file path
2020
+ for entity in declared_entities:
2021
+ entity_name = entity.get('name', '')
2022
+ if entity_name:
2023
+ aliases = generate_entity_aliases(entity_name, file_name)
2024
+ entity['aliases'] = aliases
2025
+ logger.debug(f"Generated aliases for entity '{entity_name}': {aliases}")
2026
+
2027
+ return declared_entities, called_entities
2028
+ except Exception as e:
2029
+ logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
2030
+ return [], []
2031
+ else:
2032
+ raise Exception(f"Using LLM extraction for unsupported language: {file_name}")
RepoKnowledgeGraphLib/KnowledgeGraphMCPServer.py ADDED
@@ -0,0 +1,1107 @@
1
+ import os
2
+ from typing import Optional, Annotated
3
+ from fastmcp import FastMCP
4
+ from langfuse import get_client, observe
5
+
6
+ from .RepoKnowledgeGraph import RepoKnowledgeGraph
7
+
8
+
9
+ # Custom Exceptions
10
+ class MCPServerError(Exception):
11
+ """Base exception for MCP server errors"""
12
+ pass
13
+
14
+
15
+ class NodeNotFoundError(MCPServerError):
16
+ """Raised when a node is not found"""
17
+ pass
18
+
19
+
20
+ class EntityNotFoundError(MCPServerError):
21
+ """Raised when an entity is not found"""
22
+ pass
23
+
24
+
25
+ class InvalidInputError(MCPServerError):
26
+ """Raised when input validation fails"""
27
+ pass
28
+
29
+
30
+ class KnowledgeGraphMCPServer:
31
+ """
32
+ MCP Server for interacting with a codebase knowledge graph.
33
+
34
+ Attributes:
35
+ knowledge_graph (RepoKnowledgeGraph): The loaded knowledge graph object.
36
+ app (FastMCP): The FastMCP application instance for tool registration and serving.
37
+ """
38
+ def __init__(self, knowledge_graph: Optional[RepoKnowledgeGraph] = None, knowledge_graph_path: Optional[str] = None, server_name: str = "knowledge-graph-mcp-server"):
39
+ if knowledge_graph is not None:
40
+ self.knowledge_graph = knowledge_graph
41
+ else:
42
+ if knowledge_graph_path is None:
43
+ knowledge_graph_path = os.path.join(os.path.dirname(__file__), "knowledge_graph.json")
44
+ self.knowledge_graph = RepoKnowledgeGraph.load_graph_from_file(knowledge_graph_path)
45
+ self.langfuse = get_client()
46
+ self.app = FastMCP(server_name)
47
+ self.register_tools()
48
+
49
+ def _validate_node_exists(self, node_id: str) -> bool:
50
+ """Centralized node validation"""
51
+ if node_id not in self.knowledge_graph.graph:
52
+ raise NodeNotFoundError(f"Node '{node_id}' not found in knowledge graph")
53
+ return True
54
+
55
+ def _validate_entity_exists(self, entity_name: str) -> bool:
56
+ """Centralized entity validation"""
57
+ if entity_name not in self.knowledge_graph.entities:
58
+ raise EntityNotFoundError(f"Entity '{entity_name}' not found in knowledge graph")
59
+ return True
60
+
61
+ def _validate_positive_int(self, value: int, param_name: str) -> bool:
62
+ """Validate that an integer parameter is positive"""
63
+ if value <= 0:
64
+ raise InvalidInputError(f"{param_name} must be a positive integer, got {value}")
65
+ return True
66
+
67
+ def _sanitize_chunk_dict(self, chunk_dict: dict) -> dict:
68
+ """Remove embedding data from chunk dictionary before returning to user"""
69
+ sanitized = chunk_dict.copy()
70
+ sanitized.pop('embedding', None)
71
+ return sanitized
72
+
73
+ def _sanitize_node_dict(self, node_dict: dict) -> dict:
74
+ """Remove embedding data from node dictionary before returning to user"""
75
+ sanitized = node_dict.copy()
76
+ if 'data' in sanitized and isinstance(sanitized['data'], dict):
77
+ sanitized['data'] = sanitized['data'].copy()
78
+ sanitized['data'].pop('embedding', None)
79
+ sanitized.pop('embedding', None)
80
+ return sanitized
81
+
82
+ def _handle_error(self, error: Exception, context: str = "") -> dict:
83
+ """Centralized error handling with structured response"""
84
+ if isinstance(error, NodeNotFoundError):
85
+ return {
86
+ "error": str(error),
87
+ "error_type": "node_not_found",
88
+ "context": context
89
+ }
90
+ elif isinstance(error, EntityNotFoundError):
91
+ return {
92
+ "error": str(error),
93
+ "error_type": "entity_not_found",
94
+ "context": context
95
+ }
96
+ elif isinstance(error, InvalidInputError):
97
+ return {
98
+ "error": str(error),
99
+ "error_type": "invalid_input",
100
+ "context": context
101
+ }
102
+ else:
103
+ return {
104
+ "error": str(error),
105
+ "error_type": "internal_error",
106
+ "context": context
107
+ }
108
+
109
+ @classmethod
110
+ def from_path(cls, path: str, skip_dirs=None, index_nodes=True, describe_nodes=False, extract_entities=False, model_service_kwargs=None, code_index_kwargs=None, server_name: str = "knowledge-graph-mcp-server"):
111
+ """
112
+ Build a KnowledgeGraphMCPServer from a code repository path.
113
+ """
114
+ if skip_dirs is None:
115
+ skip_dirs = []
116
+ if model_service_kwargs is None:
117
+ model_service_kwargs = {}
118
+ kg = RepoKnowledgeGraph.from_path(path, skip_dirs=skip_dirs, index_nodes=index_nodes, describe_nodes=describe_nodes, extract_entities=extract_entities, model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
119
+ return cls(knowledge_graph=kg, server_name=server_name)
120
+
121
+ @classmethod
122
+ def from_file(cls, filepath: str, index_nodes=True, use_embed=True, model_service_kwargs=None, code_index_kwargs = None, server_name: str = "knowledge-graph-mcp-server"):
123
+ """
124
+ Build a KnowledgeGraphMCPServer from a serialized knowledge graph file.
125
+ """
126
+ if model_service_kwargs is None:
127
+ model_service_kwargs = {}
128
+ kg = RepoKnowledgeGraph.load_graph_from_file(filepath, index_nodes=index_nodes, use_embed=use_embed, model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
129
+ return cls(knowledge_graph=kg, server_name=server_name)
130
+
131
+ @classmethod
132
+ def from_repo(cls, repo_url: str, index_nodes=True, describe_nodes=False, model_service_kwargs=None, code_index_kwargs=None, server_name: str = "knowledge-graph-mcp-server", github_token=None, allow_unauthenticated_clone=True, skip_dirs=None, extract_entities=True):
133
+ if model_service_kwargs is None:
134
+ model_service_kwargs = {}
135
+ kg = RepoKnowledgeGraph.from_repo(repo_url=repo_url, describe_nodes=describe_nodes, index_nodes=index_nodes, model_service_kwargs=model_service_kwargs, github_token=github_token, allow_unauthenticated_clone=allow_unauthenticated_clone, skip_dirs=skip_dirs, extract_entities=extract_entities, code_index_kwargs=code_index_kwargs)
136
+ return cls(knowledge_graph=kg, server_name=server_name)
137
+
138
+
139
+ def register_tools(self):
140
+ @self.app.tool(
141
+ description="Get detailed information about a node in the knowledge graph, including its type, name, description, declared and called entities, and a content preview."
142
+ )
143
+ @observe(as_type='tool')
144
+ async def get_node_info(
145
+ node_id: Annotated[str, "The ID of the node to retrieve information for."]
146
+ ) -> dict:
147
+ try:
148
+ self._validate_node_exists(node_id)
149
+ node = self.knowledge_graph.graph.nodes[node_id]['data']
150
+
151
+ declared_entities = getattr(node, 'declared_entities', [])
152
+ called_entities = getattr(node, 'called_entities', [])
153
+ content = getattr(node, 'content', None)
154
+ content_preview = content[:200] + "..." if content and len(content) > 200 else content
155
+
156
+ return {
157
+ "node_id": node_id,
158
+ "class": node.__class__.__name__,
159
+ "name": getattr(node, 'name', 'Unknown'),
160
+ "type": getattr(node, 'node_type', 'Unknown'),
161
+ "description": getattr(node, 'description', None),
162
+ "declared_entities": declared_entities,
163
+ "called_entities": called_entities,
164
+ "content_preview": content_preview,
165
+ "text": f"Node {node_id} ({getattr(node, 'name', '?')}) β€” {getattr(node, 'node_type', '?')} with {len(declared_entities)} declared and {len(called_entities)} called entities."
166
+ }
167
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
168
+ return self._handle_error(e, "get_node_info")
169
+ except Exception as e:
170
+ return self._handle_error(e, "get_node_info")
171
+
172
+ @self.app.tool(
173
+ description="List all incoming and outgoing edges for a node, showing relationships to other nodes."
174
+ )
175
+ @observe(as_type='tool')
176
+ async def get_node_edges(
177
+ node_id: Annotated[str, "The ID of the node whose edges to list."]
178
+ ) -> dict:
179
+ try:
180
+ self._validate_node_exists(node_id)
181
+ g = self.knowledge_graph.graph
182
+
183
+ incoming = [
184
+ {"source": src, "target": tgt, "relation": data.get("relation", "?")}
185
+ for src, tgt, data in g.in_edges(node_id, data=True)
186
+ ]
187
+ outgoing = [
188
+ {"source": src, "target": tgt, "relation": data.get("relation", "?")}
189
+ for src, tgt, data in g.out_edges(node_id, data=True)
190
+ ]
191
+
192
+ return {
193
+ "node_id": node_id,
194
+ "incoming": incoming,
195
+ "outgoing": outgoing,
196
+ "incoming_count": len(incoming),
197
+ "outgoing_count": len(outgoing),
198
+ "text": f"Node '{node_id}' has {len(incoming)} incoming and {len(outgoing)} outgoing edges."
199
+ }
200
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
201
+ return self._handle_error(e, "get_node_edges")
202
+ except Exception as e:
203
+ return self._handle_error(e, "get_node_edges")
204
+
205
+ @self.app.tool(
206
+ description="Search for nodes in the knowledge graph by query string, using the code index semantic and keyword search."
207
+ )
208
+ @observe(as_type='tool')
209
+ async def search_nodes(
210
+ query: Annotated[str, "The search string to match against code index."],
211
+ limit: Annotated[int, "Maximum number of results to return."] = 10
212
+ ) -> dict:
213
+ try:
214
+ self._validate_positive_int(limit, "limit")
215
+
216
+ results = self.knowledge_graph.code_index.query(query, n_results=limit)
217
+ metadatas = results.get("metadatas", [[]])[0]
218
+
219
+ if not metadatas:
220
+ return {"query": query, "results": [], "text": f"No results found for '{query}'."}
221
+
222
+ structured_results = [
223
+ {
224
+ "id": res.get("id"),
225
+ "content": res.get("content"),
226
+ "declared_entities": res.get("declared_entities"),
227
+ "called_entities": res.get("called_entities")
228
+ }
229
+ for res in metadatas
230
+ ]
231
+
232
+ return {
233
+ "query": query,
234
+ "count": len(structured_results),
235
+ "results": structured_results,
236
+ "text": f"Found {len(structured_results)} result(s) for query '{query}'."
237
+ }
238
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
239
+ return self._handle_error(e, "search_nodes")
240
+ except Exception as e:
241
+ return self._handle_error(e, "search_nodes")
242
+
243
+ @self.app.tool(
244
+ description="Get overall statistics about the knowledge graph, including node and edge counts, types, and relations."
245
+ )
246
+ @observe(as_type='tool')
247
+ async def get_graph_stats() -> dict:
248
+ g = self.knowledge_graph.graph
249
+ num_nodes = g.number_of_nodes()
250
+ num_edges = g.number_of_edges()
251
+
252
+ node_types = {}
253
+ for _, node_attrs in g.nodes(data=True):
254
+ node_type = getattr(node_attrs['data'], 'node_type', 'Unknown')
255
+ node_types[node_type] = node_types.get(node_type, 0) + 1
256
+
257
+ edge_relations = {}
258
+ for _, _, attrs in g.edges(data=True):
259
+ relation = attrs.get('relation', 'Unknown')
260
+ edge_relations[relation] = edge_relations.get(relation, 0) + 1
261
+
262
+ return {
263
+ "total_nodes": num_nodes,
264
+ "total_edges": num_edges,
265
+ "node_types": node_types,
266
+ "edge_relations": edge_relations,
267
+ "text": f"Graph with {num_nodes} nodes, {num_edges} edges, {len(node_types)} node types, and {len(edge_relations)} relation types."
268
+ }
269
+
270
+ @self.app.tool(
271
+ description="List nodes of a specific type in the knowledge graph."
272
+ )
273
+ @observe(as_type='tool')
274
+ async def list_nodes_by_type(
275
+ node_type: Annotated[str, "The type of nodes to list (e.g., 'function', 'class', 'file')."],
276
+ limit: Annotated[int, "Maximum number of nodes to return."] = 20
277
+ ) -> dict:
278
+ g = self.knowledge_graph.graph
279
+ matching_nodes = [
280
+ {
281
+ "id": node_id,
282
+ "name": getattr(data['data'], 'name', 'Unknown')
283
+ }
284
+ for node_id, data in g.nodes(data=True)
285
+ if getattr(data['data'], 'node_type', None) == node_type
286
+ ][:limit]
287
+
288
+ if not matching_nodes:
289
+ return {"node_type": node_type, "results": [], "text": f"No nodes found of type '{node_type}'."}
290
+
291
+ return {
292
+ "node_type": node_type,
293
+ "count": len(matching_nodes),
294
+ "results": matching_nodes,
295
+ "text": f"Found {len(matching_nodes)} node(s) of type '{node_type}'."
296
+ }
297
+
298
+ @self.app.tool(
299
+ description="Get all nodes directly connected to a given node, including the relationship type."
300
+ )
301
+ @observe(as_type='tool')
302
+ async def get_neighbors(
303
+ node_id: Annotated[str, "The ID of the node whose neighbors to retrieve."]
304
+ ) -> dict:
305
+ """Get all nodes directly connected to this node, with their relationship types."""
306
+ try:
307
+ self._validate_node_exists(node_id)
308
+
309
+ neighbors = self.knowledge_graph.get_neighbors(node_id)
310
+ if not neighbors:
311
+ return {
312
+ "node_id": node_id,
313
+ "neighbors": [],
314
+ "text": f"No neighbors found for node '{node_id}'"
315
+ }
316
+
317
+ neighbor_list = []
318
+ for neighbor in neighbors[:20]:
319
+ neighbor_info = {
320
+ "id": neighbor.id,
321
+ "name": getattr(neighbor, 'name', 'Unknown'),
322
+ "type": neighbor.node_type,
323
+ "relation": None
324
+ }
325
+
326
+ if self.knowledge_graph.graph.has_edge(node_id, neighbor.id):
327
+ edge_data = self.knowledge_graph.graph.get_edge_data(node_id, neighbor.id)
328
+ neighbor_info["relation"] = edge_data.get('relation', 'Unknown')
329
+ neighbor_info["direction"] = "outgoing"
330
+ elif self.knowledge_graph.graph.has_edge(neighbor.id, node_id):
331
+ edge_data = self.knowledge_graph.graph.get_edge_data(neighbor.id, node_id)
332
+ neighbor_info["relation"] = edge_data.get('relation', 'Unknown')
333
+ neighbor_info["direction"] = "incoming"
334
+
335
+ neighbor_list.append(neighbor_info)
336
+
337
+ text = f"Neighbors of '{node_id}' ({len(neighbors)} total):\n\n"
338
+ for neighbor in neighbor_list:
339
+ text += f"- {neighbor['id']}: {neighbor['name']} ({neighbor['type']})\n"
340
+ if neighbor['relation']:
341
+ arrow = "β†’" if neighbor['direction'] == "outgoing" else "←"
342
+ text += f" {arrow} Relation: {neighbor['relation']}\n"
343
+
344
+ if len(neighbors) > 20:
345
+ text += f"\n... and {len(neighbors) - 20} more neighbors\n"
346
+
347
+ return {
348
+ "node_id": node_id,
349
+ "total_neighbors": len(neighbors),
350
+ "neighbors": neighbor_list,
351
+ "has_more": len(neighbors) > 20,
352
+ "text": text
353
+ }
354
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
355
+ return self._handle_error(e, "get_neighbors")
356
+ except Exception as e:
357
+ return self._handle_error(e, "get_neighbors")
358
+
359
+ @self.app.tool(
360
+ description="Find where an entity (function, class, variable, etc.) is declared or defined in the codebase."
361
+ )
362
+ @observe(as_type='tool')
363
+ async def go_to_definition(
364
+ entity_name: Annotated[str, "The name of the entity to find the definition for."]
365
+ ) -> dict:
366
+ """Find where an entity is declared/defined in the codebase."""
367
+ try:
368
+ self._validate_entity_exists(entity_name)
369
+
370
+ entity_info = self.knowledge_graph.entities[entity_name]
371
+ declaring_chunks = entity_info.get('declaring_chunk_ids', [])
372
+
373
+ if not declaring_chunks:
374
+ return {
375
+ "entity_name": entity_name,
376
+ "declarations": [],
377
+ "text": f"Entity '{entity_name}' found but no declarations identified."
378
+ }
379
+
380
+ declarations = []
381
+ for chunk_id in declaring_chunks[:5]:
382
+ if chunk_id in self.knowledge_graph.graph:
383
+ chunk = self.knowledge_graph.graph.nodes[chunk_id]['data']
384
+ content_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
385
+ declarations.append({
386
+ "chunk_id": chunk_id,
387
+ "file_path": chunk.path,
388
+ "order_in_file": chunk.order_in_file,
389
+ "content_preview": content_preview
390
+ })
391
+
392
+ text = f"Definition(s) for '{entity_name}':\n\n"
393
+ text += f"Type: {', '.join(entity_info.get('type', ['Unknown']))}\n"
394
+ if entity_info.get('dtype'):
395
+ text += f"Data Type: {entity_info['dtype']}\n"
396
+ text += f"\nDeclared in {len(declaring_chunks)} location(s):\n\n"
397
+
398
+ for decl in declarations:
399
+ text += f"- Chunk: {decl['chunk_id']}\n"
400
+ text += f" File: {decl['file_path']}\n"
401
+ text += f" Order: {decl['order_in_file']}\n"
402
+ text += f" Content: {decl['content_preview']}\n\n"
403
+
404
+ if len(declaring_chunks) > 5:
405
+ text += f"... and {len(declaring_chunks) - 5} more locations\n"
406
+
407
+ return {
408
+ "entity_name": entity_name,
409
+ "type": entity_info.get('type', []),
410
+ "dtype": entity_info.get('dtype'),
411
+ "total_declarations": len(declaring_chunks),
412
+ "declarations": declarations,
413
+ "has_more": len(declaring_chunks) > 5,
414
+ "text": text
415
+ }
416
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
417
+ return self._handle_error(e, "go_to_definition")
418
+ except Exception as e:
419
+ return self._handle_error(e, "go_to_definition")
420
+
421
+ @self.app.tool(
422
+ description="Find all usages or calls of an entity (function, class, variable, etc.) in the codebase."
423
+ )
424
+ @observe(as_type='tool')
425
+ async def find_usages(
426
+ entity_name: Annotated[str, "The name of the entity to find usages for."],
427
+ limit: Annotated[int, "Maximum number of usages to return."] = 20
428
+ ) -> dict:
429
+ """Find where an entity is used/called in the codebase."""
430
+ try:
431
+ self._validate_entity_exists(entity_name)
432
+ self._validate_positive_int(limit, "limit")
433
+
434
+ entity_info = self.knowledge_graph.entities[entity_name]
435
+ calling_chunks = entity_info.get('calling_chunk_ids', [])
436
+
437
+ if not calling_chunks:
438
+ return {
439
+ "entity_name": entity_name,
440
+ "usages": [],
441
+ "text": f"Entity '{entity_name}' found but no usages identified."
442
+ }
443
+
444
+ usages = []
445
+ for chunk_id in calling_chunks[:limit]:
446
+ if chunk_id in self.knowledge_graph.graph:
447
+ chunk = self.knowledge_graph.graph.nodes[chunk_id]['data']
448
+ content_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
449
+ usages.append({
450
+ "chunk_id": chunk_id,
451
+ "file_path": chunk.path,
452
+ "order_in_file": chunk.order_in_file,
453
+ "content_preview": content_preview
454
+ })
455
+
456
+ text = f"Usages of '{entity_name}' ({len(calling_chunks)} total):\n\n"
457
+ for usage in usages:
458
+ text += f"- {usage['file_path']} (chunk {usage['order_in_file']})\n"
459
+ text += f" Content: {usage['content_preview']}\n\n"
460
+
461
+ if len(calling_chunks) > limit:
462
+ text += f"\n... and {len(calling_chunks) - limit} more usages\n"
463
+
464
+ return {
465
+ "entity_name": entity_name,
466
+ "total_usages": len(calling_chunks),
467
+ "usages": usages,
468
+ "has_more": len(calling_chunks) > limit,
469
+ "text": text
470
+ }
471
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
472
+ return self._handle_error(e, "find_usages")
473
+ except Exception as e:
474
+ return self._handle_error(e, "find_usages")
475
+
476
+ @self.app.tool(
477
+ description="Get an overview of the structure of a file, including its chunks and declared entities."
478
+ )
479
+ @observe(as_type='tool')
480
+ async def get_file_structure(
481
+ file_path: Annotated[str, "The path of the file to get the structure for."]
482
+ ) -> dict:
483
+ """Get an overview of chunks and entities in a specific file."""
484
+ try:
485
+ self._validate_node_exists(file_path)
486
+
487
+ file_node = self.knowledge_graph.graph.nodes[file_path]['data']
488
+ chunks = self.knowledge_graph.get_chunks_of_file(file_path)
489
+
490
+ declared_entities = []
491
+ if hasattr(file_node, 'declared_entities') and file_node.declared_entities:
492
+ for entity in file_node.declared_entities[:15]:
493
+ if isinstance(entity, dict):
494
+ declared_entities.append({
495
+ "name": entity.get('name', '?'),
496
+ "type": entity.get('type', '?')
497
+ })
498
+ else:
499
+ declared_entities.append({"name": str(entity), "type": "Unknown"})
500
+
501
+ chunk_list = []
502
+ for chunk in chunks[:10]:
503
+ chunk_list.append({
504
+ "id": chunk.id,
505
+ "order_in_file": chunk.order_in_file,
506
+ "description": chunk.description[:80] + "..." if chunk.description and len(chunk.description) > 80 else chunk.description
507
+ })
508
+
509
+ text = f"File Structure: {file_node.name}\n"
510
+ text += f"Path: {file_path}\n"
511
+ text += f"Language: {getattr(file_node, 'language', 'Unknown')}\n"
512
+ text += f"Total Chunks: {len(chunks)}\n\n"
513
+
514
+ if declared_entities:
515
+ text += f"Declared Entities ({len(file_node.declared_entities)}):\n"
516
+ for entity in declared_entities:
517
+ text += f" - {entity['name']} ({entity['type']})\n"
518
+ if len(file_node.declared_entities) > 15:
519
+ text += f" ... and {len(file_node.declared_entities) - 15} more\n"
520
+
521
+ text += f"\nChunks:\n"
522
+ for chunk_info in chunk_list:
523
+ text += f" [{chunk_info['order_in_file']}] {chunk_info['id']}\n"
524
+ if chunk_info['description']:
525
+ text += f" {chunk_info['description']}\n"
526
+
527
+ if len(chunks) > 10:
528
+ text += f" ... and {len(chunks) - 10} more chunks\n"
529
+
530
+ return {
531
+ "file_path": file_path,
532
+ "file_name": file_node.name,
533
+ "language": getattr(file_node, 'language', 'Unknown'),
534
+ "total_chunks": len(chunks),
535
+ "total_declared_entities": len(file_node.declared_entities) if hasattr(file_node, 'declared_entities') else 0,
536
+ "declared_entities": declared_entities,
537
+ "chunks": chunk_list,
538
+ "has_more_entities": hasattr(file_node, 'declared_entities') and len(file_node.declared_entities) > 15,
539
+ "has_more_chunks": len(chunks) > 10,
540
+ "text": text
541
+ }
542
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
543
+ return self._handle_error(e, "get_file_structure")
544
+ except Exception as e:
545
+ return self._handle_error(e, "get_file_structure")
546
+
547
+ @self.app.tool(
548
+ description="Get chunks related to a given chunk by a specific relationship (e.g., 'calls', 'contains')."
549
+ )
550
+ @observe(as_type='tool')
551
+ async def get_related_chunks(
552
+ chunk_id: Annotated[str, "The ID of the chunk to find related chunks for."],
553
+ relation_type: Annotated[str, "The type of relationship to filter by (e.g., 'calls', 'contains')."] = "calls"
554
+ ) -> dict:
555
+ """Get chunks related to this chunk by a specific relationship (e.g., 'calls', 'contains')."""
556
+ try:
557
+ self._validate_node_exists(chunk_id)
558
+
559
+ related = []
560
+ for _, target, attrs in self.knowledge_graph.graph.out_edges(chunk_id, data=True):
561
+ if attrs.get('relation') == relation_type:
562
+ target_node = self.knowledge_graph.graph.nodes[target]['data']
563
+ related.append({
564
+ "id": target,
565
+ "file_path": getattr(target_node, 'path', 'Unknown'),
566
+ "entity_name": attrs.get('entity_name')
567
+ })
568
+
569
+ if not related:
570
+ return {
571
+ "chunk_id": chunk_id,
572
+ "relation_type": relation_type,
573
+ "related_chunks": [],
574
+ "text": f"No chunks found with '{relation_type}' relationship from '{chunk_id}'"
575
+ }
576
+
577
+ text = f"Chunks related to '{chunk_id}' via '{relation_type}' ({len(related)} total):\n\n"
578
+ for chunk in related[:15]:
579
+ text += f"- {chunk['id']}\n"
580
+ text += f" File: {chunk['file_path']}\n"
581
+ if chunk['entity_name']:
582
+ text += f" Entity: {chunk['entity_name']}\n"
583
+
584
+ if len(related) > 15:
585
+ text += f"\n... and {len(related) - 15} more\n"
586
+
587
+ return {
588
+ "chunk_id": chunk_id,
589
+ "relation_type": relation_type,
590
+ "total_related": len(related),
591
+ "related_chunks": related[:15],
592
+ "has_more": len(related) > 15,
593
+ "text": text
594
+ }
595
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
596
+ return self._handle_error(e, "get_related_chunks")
597
+ except Exception as e:
598
+ return self._handle_error(e, "get_related_chunks")
599
+
600
+ @self.app.tool(
601
+ description="List all entities tracked in the knowledge graph, including their types, declaration, and usage counts."
602
+ )
603
+ @observe(as_type='tool')
604
+ async def list_all_entities(
605
+ limit: Annotated[int, "Maximum number of entities to return."] = 50
606
+ ) -> dict:
607
+ """List all entities tracked in the knowledge graph with their metadata."""
608
+ if not self.knowledge_graph.entities:
609
+ return {
610
+ "entities": [],
611
+ "text": "No entities found in the knowledge graph."
612
+ }
613
+
614
+ entities = []
615
+ for entity_name, info in list(self.knowledge_graph.entities.items())[:limit]:
616
+ entities.append({
617
+ "name": entity_name,
618
+ "types": info.get('type', ['Unknown']),
619
+ "declaration_count": len(info.get('declaring_chunk_ids', [])),
620
+ "usage_count": len(info.get('calling_chunk_ids', []))
621
+ })
622
+
623
+ text = f"All Entities ({len(self.knowledge_graph.entities)} total):\n\n"
624
+ for i, entity in enumerate(entities, 1):
625
+ text += f"{i}. {entity['name']}\n"
626
+ text += f" Types: {', '.join(entity['types'])}\n"
627
+ text += f" Declarations: {entity['declaration_count']}\n"
628
+ text += f" Usages: {entity['usage_count']}\n\n"
629
+
630
+ if len(self.knowledge_graph.entities) > limit:
631
+ text += f"... and {len(self.knowledge_graph.entities) - limit} more entities\n"
632
+
633
+ return {
634
+ "total_entities": len(self.knowledge_graph.entities),
635
+ "entities": entities,
636
+ "has_more": len(self.knowledge_graph.entities) > limit,
637
+ "text": text
638
+ }
639
+
640
+ # --- New Tools ---
641
+ @self.app.tool(
642
+ description="Show the diff between two code chunks or nodes by their IDs."
643
+ )
644
+ @observe(as_type='tool')
645
+ async def diff_chunks(
646
+ node_id_1: Annotated[str, "The ID of the first node/chunk."],
647
+ node_id_2: Annotated[str, "The ID of the second node/chunk."]
648
+ ) -> dict:
649
+ try:
650
+ import difflib
651
+ self._validate_node_exists(node_id_1)
652
+ self._validate_node_exists(node_id_2)
653
+
654
+ g = self.knowledge_graph.graph
655
+ content1 = getattr(g.nodes[node_id_1]['data'], 'content', None)
656
+ content2 = getattr(g.nodes[node_id_2]['data'], 'content', None)
657
+
658
+ if not content1 or not content2:
659
+ raise InvalidInputError("One or both nodes have no content.")
660
+
661
+ diff = list(difflib.unified_diff(
662
+ content1.splitlines(), content2.splitlines(),
663
+ fromfile=node_id_1, tofile=node_id_2, lineterm=""
664
+ ))
665
+
666
+ diff_text = "\n".join(diff) if diff else "No differences."
667
+
668
+ return {
669
+ "node_id_1": node_id_1,
670
+ "node_id_2": node_id_2,
671
+ "has_differences": bool(diff),
672
+ "diff": diff,
673
+ "text": diff_text
674
+ }
675
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
676
+ return self._handle_error(e, "diff_chunks")
677
+ except Exception as e:
678
+ return self._handle_error(e, "diff_chunks")
679
+
680
+ @self.app.tool(
681
+ description="Show a tree view of the repository or a subtree starting from a given node ID."
682
+ )
683
+ @observe(as_type='tool')
684
+ async def print_tree(
685
+ root_id: Annotated[Optional[str], "The node ID to start the tree from (default: repo root)."] = 'root',
686
+ max_depth: Annotated[int, "Maximum depth to show."] = 3
687
+ ) -> dict:
688
+ try:
689
+ g = self.knowledge_graph.graph
690
+
691
+ def build_tree(node_id, depth, tree_data):
692
+ if depth > max_depth:
693
+ return
694
+ node = g.nodes[node_id]['data']
695
+ node_info = {
696
+ "id": node_id,
697
+ "name": getattr(node, 'name', node_id),
698
+ "type": getattr(node, 'node_type', '?'),
699
+ "depth": depth,
700
+ "children": []
701
+ }
702
+ tree_data.append(node_info)
703
+ children = [t for s, t in g.out_edges(node_id)]
704
+ for child in children:
705
+ build_tree(child, depth + 1, node_info["children"])
706
+
707
+ def format_tree(tree_data):
708
+ # Render each root with the same recursive helper used for its children
709
+ return "".join(format_subtree(node) for node in tree_data)
714
+
715
+ def format_subtree(node):
716
+ result = " " * node["depth"] + f"- {node['name']} ({node['type']})\n"
717
+ for child in node["children"]:
718
+ result += format_subtree(child)
719
+ return result
720
+
721
+ if root_id is None:
722
+ roots = [n for n, d in g.nodes(data=True) if getattr(d['data'], 'node_type', None) in ('repo', 'directory', 'file')]
723
+ root_id = roots[0] if roots else list(g.nodes)[0]
724
+
725
+ self._validate_node_exists(root_id)
726
+
727
+ tree_data = []
728
+ build_tree(root_id, 0, tree_data)
729
+
730
+ return {
731
+ "root_id": root_id,
732
+ "max_depth": max_depth,
733
+ "tree": tree_data,
734
+ "text": format_tree(tree_data)
735
+ }
736
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
737
+ return self._handle_error(e, "print_tree")
738
+ except Exception as e:
739
+ return self._handle_error(e, "print_tree")
740
+
741
+ @self.app.tool(
742
+ description="Show all relationships (calls, contains, etc.) for a given entity or node."
743
+ )
744
+ @observe(as_type='tool')
745
+ async def entity_relationships(
746
+ node_id: Annotated[str, "The node/entity ID to explore relationships for."]
747
+ ) -> dict:
748
+ try:
749
+ self._validate_node_exists(node_id)
750
+ g = self.knowledge_graph.graph
751
+
752
+ incoming = []
753
+ outgoing = []
754
+
755
+ for source, target, data in g.in_edges(node_id, data=True):
756
+ incoming.append({
757
+ "source": source,
758
+ "target": target,
759
+ "relation": data.get('relation', '?')
760
+ })
761
+
762
+ for source, target, data in g.out_edges(node_id, data=True):
763
+ outgoing.append({
764
+ "source": source,
765
+ "target": target,
766
+ "relation": data.get('relation', '?')
767
+ })
768
+
769
+ text = f"Relationships for '{node_id}':\n"
770
+ for rel in incoming:
771
+ text += f"← {rel['source']} [{rel['relation']}]\n"
772
+ for rel in outgoing:
773
+ text += f"β†’ {rel['target']} [{rel['relation']}]\n"
774
+
775
+ if not incoming and not outgoing:
776
+ text = "No relationships found."
777
+
778
+ return {
779
+ "node_id": node_id,
780
+ "incoming": incoming,
781
+ "outgoing": outgoing,
782
+ "incoming_count": len(incoming),
783
+ "outgoing_count": len(outgoing),
784
+ "text": text
785
+ }
786
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
787
+ return self._handle_error(e, "entity_relationships")
788
+ except Exception as e:
789
+ return self._handle_error(e, "entity_relationships")
790
+
791
+ @self.app.tool(
792
+ description="Search for nodes/entities by type and name substring with fuzzy matching support. For entities, searches by entity_type (e.g., 'class', 'function', 'method'). For other nodes, searches by node_type (e.g., 'file', 'chunk', 'directory')."
793
+ )
794
+ @observe(as_type='tool')
795
+ async def search_by_type_and_name(
796
+ node_type: Annotated[str, "Type of node/entity (e.g., 'function', 'class', 'file', 'chunk', 'directory')."],
797
+ name_query: Annotated[str, "Substring to match in the name (case-insensitive, supports partial matches)."],
798
+ limit: Annotated[int, "Maximum results to return."] = 10,
799
+ fuzzy: Annotated[bool, "Enable fuzzy/partial matching (default: True)."] = True
800
+ ) -> dict:
801
+ import re
802
+ try:
803
+ self._validate_positive_int(limit, "limit")
804
+
805
+ g = self.knowledge_graph.graph
806
+ matches = []
807
+ query_lower = name_query.lower()
808
+
809
+ # Build regex pattern for fuzzy matching
810
+ if fuzzy:
811
+ fuzzy_pattern = '.*'.join(re.escape(c) for c in query_lower)
812
+ fuzzy_regex = re.compile(fuzzy_pattern, re.IGNORECASE)
813
+
814
+ for nid, n in g.nodes(data=True):
815
+ node = n['data']
816
+ node_name = getattr(node, 'name', '')
817
+
818
+ if not node_name:
819
+ continue
820
+
821
+ # Check if name matches the query
822
+ name_matches = False
823
+ if fuzzy:
824
+ if query_lower in node_name.lower() or fuzzy_regex.search(node_name):
825
+ name_matches = True
826
+ else:
827
+ if query_lower in node_name.lower():
828
+ name_matches = True
829
+
830
+ if not name_matches:
831
+ continue
832
+
833
+ # Check type based on node_type
834
+ current_node_type = getattr(node, 'node_type', None)
835
+
836
+ # For entity nodes, check entity_type instead of node_type
837
+ if current_node_type == 'entity':
838
+ entity_type = getattr(node, 'entity_type', '')
839
+
840
+ # Fallback: if entity_type is empty, check the entities dictionary
841
+ if not entity_type and nid in self.knowledge_graph.entities:
842
+ entity_types = self.knowledge_graph.entities[nid].get('type', [])
843
+ entity_type = entity_types[0] if entity_types else ''
844
+
845
+ if entity_type and entity_type.lower() == node_type.lower():
846
+ score = 0 if query_lower == node_name.lower() else (1 if query_lower in node_name.lower() else 2)
847
+ matches.append({
848
+ "id": nid,
849
+ "name": node_name,
850
+ "type": f"entity ({entity_type})",
851
+ "content": getattr(node, 'content', None),
852
+ "score": score
853
+ })
854
+ # For other nodes, check node_type directly
855
+ elif current_node_type == node_type:
856
+ score = 0 if query_lower == node_name.lower() else (1 if query_lower in node_name.lower() else 2)
857
+ matches.append({
858
+ "id": nid,
859
+ "name": node_name,
860
+ "type": current_node_type,
861
+ "content": getattr(node, 'content', None),
862
+ "score": score
863
+ })
864
+
865
+ # Sort by match score (best matches first) and limit results
866
+ matches.sort(key=lambda x: (x['score'], x['name'].lower()))
867
+ matches = matches[:limit]
868
+
869
+ if not matches:
870
+ return {
871
+ "node_type": node_type,
872
+ "name_query": name_query,
873
+ "matches": [],
874
+ "text": f"No matches for type '{node_type}' and name containing '{name_query}'."
875
+ }
876
+
877
+ text = f"Matches for type '{node_type}' and name '{name_query}' ({len(matches)} results):\n"
878
+ for match in matches:
879
+ text += f"- {match['id']}: {match['name']} [{match['type']}]\n"
880
+
881
+ return {
882
+ "node_type": node_type,
883
+ "name_query": name_query,
884
+ "count": len(matches),
885
+ "matches": matches,
886
+ "text": text
887
+ }
888
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
889
+ return self._handle_error(e, "search_by_type_and_name")
890
+ except Exception as e:
891
+ return self._handle_error(e, "search_by_type_and_name")
892
+
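The fuzzy matcher above builds a subsequence pattern: every character of the query must appear in the candidate name, in order, with anything allowed in between. A standalone sketch of that construction (the query and names are made up):

    import re

    query = "kgsrv"
    pattern = '.*'.join(re.escape(c) for c in query.lower())  # "k.*g.*s.*r.*v"
    regex = re.compile(pattern, re.IGNORECASE)

    assert regex.search("KnowledgeGraphServer")   # characters appear in order
    assert not regex.search("GraphServer")        # no leading 'k'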
893
+ @self.app.tool(
894
+ description="Get the full content of a code chunk along with its surrounding chunks (previous and next)."
895
+ )
896
+ @observe(as_type='tool')
897
+ async def get_chunk_context(
898
+ node_id: Annotated[str, "The node/chunk ID to get context for."]
899
+ ) -> dict:
900
+ from .utils.chunk_utils import organize_chunks_by_file_name, join_organized_chunks
901
+ try:
902
+ self._validate_node_exists(node_id)
903
+
904
+ g = self.knowledge_graph.graph
905
+ current_chunk = g.nodes[node_id]['data']
906
+ previous_chunk = self.knowledge_graph.get_previous_chunk(node_id)
907
+ next_chunk = self.knowledge_graph.get_next_chunk(node_id)
908
+
909
+ # Collect all chunks (previous, current, next)
910
+ chunks = []
911
+ prev_info = None
912
+ next_info = None
913
+ current_info = {
914
+ "id": node_id,
915
+ "content": getattr(current_chunk, 'content', '')
916
+ }
917
+
918
+ if previous_chunk:
919
+ prev_info = {
920
+ "id": previous_chunk.id,
921
+ "content": previous_chunk.content
922
+ }
923
+ chunks.append(previous_chunk)
924
+
925
+ chunks.append(current_chunk)
926
+
927
+ if next_chunk:
928
+ next_info = {
929
+ "id": next_chunk.id,
930
+ "content": next_chunk.content
931
+ }
932
+ chunks.append(next_chunk)
933
+
934
+ # Organize and join chunks
935
+ organized = organize_chunks_by_file_name(chunks)
936
+ full_content = join_organized_chunks(organized)
937
+
938
+ return {
939
+ "node_id": node_id,
940
+ "current_chunk": current_info,
941
+ "previous_chunk": prev_info,
942
+ "next_chunk": next_info,
943
+ "text": full_content
944
+ }
945
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
946
+ return self._handle_error(e, "get_chunk_context")
947
+ except Exception as e:
948
+ return self._handle_error(e, "get_chunk_context")
949
+
950
+ @self.app.tool(
951
+ description="Get statistics for a file or directory: number of entities, lines, chunks, etc."
952
+ )
953
+ @observe(as_type='tool')
954
+ async def get_file_stats(
955
+ path: Annotated[str, "The file or directory path to get statistics for."]
956
+ ) -> dict:
957
+ try:
958
+ g = self.knowledge_graph.graph
959
+ nodes = [n for n, d in g.nodes(data=True) if getattr(d['data'], 'path', None) == path]
960
+
961
+ if not nodes:
962
+ raise NodeNotFoundError(f"No nodes found for path '{path}'.")
963
+
964
+ stats = []
965
+ text = f"Statistics for '{path}':\n"
966
+
967
+ for node_id in nodes:
968
+ node = g.nodes[node_id]['data']
969
+ content = getattr(node, 'content', '')
970
+ declared = getattr(node, 'declared_entities', [])
971
+ called = getattr(node, 'called_entities', [])
972
+ chunks = [t for s, t in g.out_edges(node_id) if getattr(g.nodes[t]['data'], 'node_type', None) == 'chunk']
973
+
974
+ declared_list = []
975
+ for entity in declared[:10]:
976
+ if isinstance(entity, dict):
977
+ declared_list.append({
978
+ "name": entity.get('name', '?'),
979
+ "type": entity.get('type', '?')
980
+ })
981
+ else:
982
+ declared_list.append({"name": str(entity), "type": "Unknown"})
983
+
984
+ called_list = [str(entity) for entity in called[:10]]
985
+
986
+ node_stats = {
987
+ "node_id": node_id,
988
+ "node_type": getattr(node, 'node_type', '?'),
989
+ "lines": len(content.splitlines()) if content else 0,
990
+ "declared_entities_count": len(declared),
991
+ "declared_entities": declared_list,
992
+ "called_entities_count": len(called),
993
+ "called_entities": called_list,
994
+ "chunks_count": len(chunks),
995
+ "has_more_declared": len(declared) > 10,
996
+ "has_more_called": len(called) > 10
997
+ }
998
+ stats.append(node_stats)
999
+
1000
+ text += f"- Node: {node_id} ({node_stats['node_type']})\n"
1001
+ text += f" Lines: {node_stats['lines']}\n"
1002
+
1003
+ if declared_list:
1004
+ text += f" Declared entities ({len(declared)}):\n"
1005
+ for entity in declared_list:
1006
+ text += f" - {entity['name']} ({entity['type']})\n"
1007
+ if len(declared) > 10:
1008
+ text += f" ... and {len(declared) - 10} more\n"
1009
+ else:
1010
+ text += f" Declared entities: 0\n"
1011
+
1012
+ if called_list:
1013
+ text += f" Called entities ({len(called)}):\n"
1014
+ for entity in called_list:
1015
+ text += f" - {entity}\n"
1016
+ if len(called) > 10:
1017
+ text += f" ... and {len(called) - 10} more\n"
1018
+ else:
1019
+ text += f" Called entities: 0\n"
1020
+
1021
+ text += f" Chunks: {len(chunks)}\n"
1022
+
1023
+ return {
1024
+ "path": path,
1025
+ "nodes": stats,
1026
+ "text": text
1027
+ }
1028
+ except (NodeNotFoundError, InvalidInputError, EntityNotFoundError) as e:
1029
+ return self._handle_error(e, "get_file_stats")
1030
+ except Exception as e:
1031
+ return self._handle_error(e, "get_file_stats")
1032
+ # --- End New Tools ---
1033
+ @self.app.tool(
1034
+ description="Search for file names in the repository using a regular expression pattern."
1035
+ )
1036
+ @observe(as_type='tool')
1037
+ async def search_file_names_by_regex(
1038
+ pattern: Annotated[str, "The regular expression pattern to match file names."]
1039
+ ) -> dict:
1040
+ """Search for file names matching a regex pattern."""
1041
+ import re
1042
+ g = self.knowledge_graph.graph
1043
+
1044
+ try:
1045
+ regex = re.compile(pattern)
1046
+ except re.error as e:
1047
+ return {"error": f"Invalid regex pattern: {str(e)}"}
1048
+
1049
+ matches = []
1050
+ for node_id, node_attrs in g.nodes(data=True):
1051
+ node = node_attrs['data']
1052
+ if getattr(node, 'node_type', None) == 'file':
1053
+ file_name = getattr(node, 'name', '') or getattr(node, 'path', '')
1054
+ if regex.search(file_name):
1055
+ matches.append({
1056
+ "node_id": node_id,
1057
+ "file_name": file_name
1058
+ })
1059
+
1060
+ if not matches:
1061
+ return {
1062
+ "pattern": pattern,
1063
+ "matches": [],
1064
+ "text": f"No file names matched the pattern: '{pattern}'"
1065
+ }
1066
+
1067
+ text = f"Files matching pattern '{pattern}':\n"
1068
+ for match in matches[:20]:
1069
+ text += f"- {match['file_name']} (node ID: {match['node_id']})\n"
1070
+
1071
+ if len(matches) > 20:
1072
+ text += f"... and {len(matches) - 20} more\n"
1073
+
1074
+ return {
1075
+ "pattern": pattern,
1076
+ "count": len(matches),
1077
+ "matches": matches[:20],
1078
+ "has_more": len(matches) > 20,
1079
+ "text": text
1080
+ }
1081
+
1082
+ @self.app.tool(
1083
+ description="Find the shortest path between two nodes in the knowledge graph."
1084
+ )
1085
+ @observe(as_type='tool')
1086
+ async def find_path(
1087
+ source_id: Annotated[str, "The ID of the source node."],
1088
+ target_id: Annotated[str, "The ID of the target node."],
1089
+ max_depth: Annotated[int, "Maximum depth to search for a path."] = 5
1090
+ ) -> dict:
1091
+ """Find shortest path between two nodes."""
1092
+ return self.knowledge_graph.find_path(source_id, target_id, max_depth)
1093
+
1094
+ @self.app.tool(
1095
+ description="Extract a subgraph around a node up to a specified depth, optionally filtering by edge types."
1096
+ )
1097
+ @observe(as_type='tool')
1098
+ async def get_subgraph(
1099
+ node_id: Annotated[str, "The ID of the central node."],
1100
+ depth: Annotated[int, "The depth/radius of the subgraph to extract."] = 2,
1101
+ edge_types: Annotated[Optional[list], "Optional list of edge types to include (e.g., ['calls', 'contains'])."] = None
1102
+ ) -> dict:
1103
+ """Extract a subgraph around a node."""
1104
+ return self.knowledge_graph.get_subgraph(node_id, depth, edge_types)
1105
+
1106
+ def run(self, **kwargs):
1107
+ self.app.run(**kwargs)
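A minimal launch sketch for the tool server wrapped by `run()` above. The wrapper class name and constructor below are assumptions for illustration only (the class definition appears earlier in this file); the one grounded fact is that `run(**kwargs)` forwards to `self.app.run(**kwargs)`:

    # Hypothetical usage -- `ToolServer` and its constructor arguments are
    # assumed names, not taken from this diff.
    server = ToolServer(knowledge_graph=graph)  # assumed constructor
    server.run()  # kwargs pass straight through to the underlying app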
RepoKnowledgeGraphLib/ModelService.py ADDED
@@ -0,0 +1,424 @@
+ from abc import ABC, abstractmethod
+ from openai import OpenAI, AsyncOpenAI
+ from dotenv import load_dotenv
+ import os
+ import logging
+ from tenacity import retry, stop_after_attempt, wait_fixed
+ import httpx
+ from sentence_transformers import SentenceTransformer
+
+ # Optional torch import for CUDA detection
+ try:
+     import torch
+     _TORCH_AVAILABLE = True
+ except Exception:
+     torch = None
+     _TORCH_AVAILABLE = False
+
+ from .utils.logger_utils import setup_logger
+
+ # Load variables from a .env file (if present) before the defaults below are read
+ load_dotenv()
+
+ LOGGER_NAME = "MODEL_SERVICE_LOGGER"
+ # GENERATION ENV VARIABLES (defaults)
+ OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", 'http://0.0.0.0:8000/v1')
+ OPENAI_TOKEN = os.getenv("OPENAI_TOKEN", 'no-need')
+ MODEL_NAME = os.getenv('MODEL_NAME', "meta-llama/Llama-3.2-3B-Instruct")
+ # EMBED ENV VARIABLES (defaults)
+ OPENAI_EMBED_BASE_URL = os.getenv("OPENAI_EMBED_BASE_URL", 'http://0.0.0.0:8001/v1')
+ OPENAI_EMBED_TOKEN = os.getenv("OPENAI_EMBED_TOKEN", 'no-need')
+ EMBED_MODEL_NAME = os.getenv('EMBED_MODEL_NAME', "Alibaba-NLP/gte-Qwen2-1.5B-instruct")
+
+ # Additional ENV defaults requested
+ MAX_TOKENS = int(os.getenv("MAX_TOKENS", 2048))
+ TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2))
+ TOP_P = float(os.getenv("TOP_P", 0.95))
+ FREQUENCY_PENALTY = float(os.getenv("FREQUENCY_PENALTY", 0))
+ PRESENCE_PENALTY = float(os.getenv("PRESENCE_PENALTY", 0))
+ EMBEDDING_MODEL_URL = os.getenv("EMBEDDING_MODEL_URL", "")
+ EMBEDDING_MODEL_API_KEY = os.getenv("EMBEDDING_MODEL_API_KEY", "no_need")
+ EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv("EMBEDDING_NUMBER_DIMENSIONS", 1024))
+
+ STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5))
+ WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2))
+ REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", 240))
+
+ # Note: module-level clients remain for backward compatibility, but instances will create their own if the timeout is overridden.
+ long_timeout_client = httpx.Client(timeout=REQUEST_TIMEOUT)
+ long_timeout_async_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT)
+
+
+ class ModelServiceInterface(ABC):
+     """
+     Abstract base class defining the interface for model services.
+     All model services should implement these methods.
+     """
+
+     # accept model_kwargs so variables can be overridden at runtime
+     def __init__(self, model_name: str = None, model_kwargs: dict = None):
+         setup_logger(LOGGER_NAME)
+         self.logger = logging.getLogger(LOGGER_NAME)
+
+         model_kwargs = model_kwargs or {}
+
+         # allow overriding via model_kwargs; fall back to module-level defaults
+         self.openai_base_url = model_kwargs.get("OPENAI_BASE_URL", OPENAI_BASE_URL)
+         self.openai_token = model_kwargs.get("OPENAI_TOKEN", OPENAI_TOKEN)
+         # the model_name param takes precedence, then model_kwargs, then the env default
+         self.model_name = model_name or model_kwargs.get("MODEL_NAME", MODEL_NAME)
+
+         # embed defaults (may be overridden by subclasses or model_kwargs)
+         self.openai_embed_base_url = model_kwargs.get("OPENAI_EMBED_BASE_URL", OPENAI_EMBED_BASE_URL)
+         self.openai_embed_token = model_kwargs.get("OPENAI_EMBED_TOKEN", OPENAI_EMBED_TOKEN)
+         self.embed_model_name = model_kwargs.get("EMBED_MODEL_NAME", EMBED_MODEL_NAME)
+
+         # other configurable parameters
+         self.max_tokens = int(model_kwargs.get("MAX_TOKENS", MAX_TOKENS))
+         self.temperature = float(model_kwargs.get("TEMPERATURE", TEMPERATURE))
+         self.top_p = float(model_kwargs.get("TOP_P", TOP_P))
+         self.frequency_penalty = float(model_kwargs.get("FREQUENCY_PENALTY", FREQUENCY_PENALTY))
+         self.presence_penalty = float(model_kwargs.get("PRESENCE_PENALTY", PRESENCE_PENALTY))
+         self.embedding_model_url = model_kwargs.get("EMBEDDING_MODEL_URL", EMBEDDING_MODEL_URL)
+         self.embedding_model_api_key = model_kwargs.get("EMBEDDING_MODEL_API_KEY", EMBEDDING_MODEL_API_KEY)
+         self.embedding_number_dimensions = int(model_kwargs.get("EMBEDDING_NUMBER_DIMENSIONS", EMBEDDING_NUMBER_DIMENSIONS))
+
+         self.stop_after_attempt = int(model_kwargs.get("STOP_AFTER_ATTEMPT", STOP_AFTER_ATTEMPT))
+         self.wait_between_retries = int(model_kwargs.get("WAIT_BETWEEN_RETRIES", WAIT_BETWEEN_RETRIES))
+         request_timeout = int(model_kwargs.get("REQUEST_TIMEOUT", REQUEST_TIMEOUT))
+
+         # create per-instance httpx clients in case REQUEST_TIMEOUT was overridden
+         self.long_timeout_client = httpx.Client(timeout=request_timeout)
+         self.long_timeout_async_client = httpx.AsyncClient(timeout=request_timeout)
+
+         # Initialize query client (shared by all implementations)
+         self.client = OpenAI(
+             base_url=self.openai_base_url,
+             api_key=self.openai_token,
+             http_client=self.long_timeout_client,
+         )
+         self.async_client = AsyncOpenAI(
+             base_url=self.openai_base_url,
+             api_key=self.openai_token,
+             http_client=self.long_timeout_async_client,
+         )
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     def query(self, prompt: str, model_name: str = None) -> str:
+         """Query the model with a prompt."""
+         if model_name is None:
+             model_name = self.model_name
+         completion = self.client.chat.completions.create(
+             model=model_name,
+             messages=[
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         return completion.choices[0].message.content
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     def query_with_instructions(self, prompt: str, instructions: str, model_name: str = None) -> str:
+         """Query the model with additional system instructions."""
+         if model_name is None:
+             model_name = self.model_name
+         completion = self.client.chat.completions.create(
+             model=model_name,
+             messages=[
+                 {"role": "system", "content": instructions},
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         return completion.choices[0].message.content
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     async def query_async(self, prompt: str, model_name: str = None) -> str:
+         """Async version of query."""
+         if model_name is None:
+             model_name = self.model_name
+         completion = await self.async_client.chat.completions.create(
+             model=model_name,
+             messages=[
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         return completion.choices[0].message.content
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     async def query_with_instructions_async(self, prompt: str, instructions: str, model_name: str = None) -> str:
+         """Async version of query with instructions."""
+         if model_name is None:
+             model_name = self.model_name
+         completion = await self.async_client.chat.completions.create(
+             model=model_name,
+             messages=[
+                 {"role": "system", "content": instructions},
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         return completion.choices[0].message.content
+
+     @abstractmethod
+     def embed(self, text_to_embed: str) -> list:
+         """Embed text using the configured embedding model."""
+         pass
+
+     @abstractmethod
+     async def embed_async(self, text_to_embed: str) -> list:
+         """Async version of embed."""
+         pass
+
+     @abstractmethod
+     def embed_chunk_code(self, code_to_embed: str) -> list:
+         """Embed code chunk for storage/indexing."""
+         pass
+
+     @abstractmethod
+     def embed_query(self, query_to_embed: str) -> list:
+         """Embed query for retrieval."""
+         pass
+
+     @abstractmethod
+     def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
+         """Embed multiple texts in a batch for better performance."""
+         pass
+
+     @abstractmethod
+     def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
+         """Embed multiple code chunks in a batch for storage/indexing."""
+         pass
+
+
+ class OpenAIModelService(ModelServiceInterface):
+     """
+     Model service that uses the OpenAI client for both queries and embeddings.
+     """
+
+     def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None):
+         # forward model_kwargs to the base class so it can set instance-wide config
+         super().__init__(model_name=model_name, model_kwargs=model_kwargs)
+
+         # allow override of the embed model name via param or model_kwargs
+         model_kwargs = model_kwargs or {}
+         self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)
+
+         # the embed client should use the instance-level embed base/token
+         self.embed_client = OpenAI(
+             base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
+             api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
+             http_client=self.long_timeout_client,
+         )
+         self.async_embed_client = AsyncOpenAI(
+             base_url=model_kwargs.get("OPENAI_EMBED_BASE_URL", self.openai_embed_base_url),
+             api_key=model_kwargs.get("OPENAI_EMBED_TOKEN", self.openai_embed_token),
+             http_client=self.long_timeout_async_client,
+         )
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     def embed(self, text_to_embed: str) -> list:
+         """Embed text using the OpenAI embeddings API."""
+         response = self.embed_client.embeddings.create(
+             input=text_to_embed,
+             model=self.embed_model_name,
+         )
+         return response.data[0].embedding
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     async def embed_async(self, text_to_embed: str) -> list:
+         """Async version of embed using the OpenAI embeddings API."""
+         response = await self.async_embed_client.embeddings.create(
+             input=text_to_embed,
+             model=self.embed_model_name,
+         )
+         return response.data[0].embedding
+
+     def embed_chunk_code(self, code_to_embed: str) -> list:
+         """Embed a code chunk using the OpenAI embeddings API (same as embed)."""
+         return self.embed(code_to_embed)
+
+     def embed_query(self, query_to_embed: str) -> list:
+         """Embed a query using the OpenAI embeddings API (same as embed)."""
+         return self.embed(query_to_embed)
+
+     @retry(stop=stop_after_attempt(STOP_AFTER_ATTEMPT), wait=wait_fixed(WAIT_BETWEEN_RETRIES))
+     def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
+         """Embed multiple texts in a batch using the OpenAI embeddings API."""
+         if not texts_to_embed:
+             return []
+         response = self.embed_client.embeddings.create(
+             input=texts_to_embed,
+             model=self.embed_model_name,
+         )
+         return [item.embedding for item in response.data]
+
+     def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
+         """Embed multiple code chunks in a batch using the OpenAI embeddings API."""
+         return self.embed_batch(codes_to_embed)
+
+
+ class SentenceTransformersModelService(ModelServiceInterface):
+     """
+     Model service that uses the OpenAI client for queries and SentenceTransformers for embeddings.
+     Optimized for high-throughput batch embedding with GPU support.
+     """
+
+     def __init__(self, model_name: str = None, embed_model_name: str = None, model_kwargs: dict = None, skip_embedder: bool = False):
+         super().__init__(model_name=model_name, model_kwargs=model_kwargs)
+         model_kwargs = model_kwargs or {}
+         # embed_model_name may be overridden by model_kwargs
+         self.embed_model_name = embed_model_name or model_kwargs.get("EMBED_MODEL_NAME", self.embed_model_name)
+         self.skip_embedder = skip_embedder
+         self.embedding_model = None
+
+         if skip_embedder:
+             self.logger.info('Skipping embedder initialization (keyword-only mode)')
+             self.device = "cpu"
+             self.encode_batch_size = 32
+             return
+
+         # Debug GPU detection
+         self.logger.info(f'PyTorch available: {_TORCH_AVAILABLE}')
+         if _TORCH_AVAILABLE:
+             self.logger.info(f'CUDA available: {torch.cuda.is_available()}')
+             self.logger.info(f'CUDA device count: {torch.cuda.device_count()}')
+             if torch.cuda.is_available():
+                 self.logger.info(f'CUDA device name: {torch.cuda.get_device_name(0)}')
+
+         # Select device: prefer CUDA if available
+         self.device = "cuda" if (_TORCH_AVAILABLE and torch.cuda.is_available()) else "cpu"
+         self.logger.info(f'Initializing SentenceTransformer on device: {self.device}')
+
+         # Set batch size based on device and available memory
+         # Larger batch sizes significantly improve GPU throughput
+         self.encode_batch_size = int(model_kwargs.get("ENCODE_BATCH_SIZE", 64 if self.device == "cuda" else 32))
+
+         # Show CUDA memory info if available
+         if self.device == "cuda" and _TORCH_AVAILABLE:
+             try:
+                 gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+                 self.logger.info(f'GPU memory available: {gpu_memory:.2f} GB')
+                 # Adjust batch size based on available GPU memory
+                 if gpu_memory > 16:
+                     self.encode_batch_size = max(self.encode_batch_size, 128)
+                 elif gpu_memory > 8:
+                     self.encode_batch_size = max(self.encode_batch_size, 64)
+             except Exception as e:
+                 self.logger.warning(f'Could not get GPU memory info: {e}')
+
+         self.logger.info(f'Using encode batch size: {self.encode_batch_size}')
+
+         # Initialize the embedding model on the chosen device with performance optimizations
+         self.embedding_model = SentenceTransformer(
+             self.embed_model_name,
+             trust_remote_code=True,
+             device=self.device
+         )
+
+         # Enable half precision for faster inference on CUDA
+         if self.device == "cuda" and _TORCH_AVAILABLE:
+             try:
+                 # Check if the model supports half precision
+                 self.embedding_model.half()
+                 self.logger.info('Enabled half precision (FP16) for faster GPU inference')
+             except Exception as e:
+                 self.logger.warning(f'Could not enable half precision: {e}')
+
+     def _check_embedder(self):
+         """Check if the embedder is available; raise an error if not."""
+         if self.skip_embedder or self.embedding_model is None:
+             raise RuntimeError(
+                 "Embedding model not initialized. This model service was created with skip_embedder=True "
+                 "(keyword-only mode). To use embeddings, set index_type to 'hybrid' or 'embedding-only'."
+             )
+
+     def embed(self, text_to_embed: str) -> list:
+         """Embed text using SentenceTransformers."""
+         self._check_embedder()
+         embeddings = self.embedding_model.encode(
+             [text_to_embed],
+             convert_to_numpy=True,
+             show_progress_bar=False
+         )
+         return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])
+
+     async def embed_async(self, text_to_embed: str) -> list:
+         """
+         Async version of embed using SentenceTransformers.
+         Note: SentenceTransformers doesn't have native async support,
+         so this runs synchronously but maintains the async interface.
+         """
+         return self.embed(text_to_embed)
+
+     def embed_chunk_code(self, code_to_embed: str) -> list:
+         """Embed a code chunk using SentenceTransformers (no special prompt)."""
+         self._check_embedder()
+         self.logger.debug(f'Embedding code using {self.embed_model_name}')
+         embeddings = self.embedding_model.encode(
+             [code_to_embed],
+             convert_to_numpy=True,
+             show_progress_bar=False
+         )
+         return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])
+
+     def embed_query(self, query_to_embed: str) -> list:
+         """Embed a query using SentenceTransformers with a retrieval prompt."""
+         self._check_embedder()
+         self.logger.debug(f'Embedding query using {self.embed_model_name}')
+         embeddings = self.embedding_model.encode(
+             [query_to_embed],
+             prompt='Given this prompt, retrieve relevant content\n Query:',
+             convert_to_numpy=True,
+             show_progress_bar=False
+         )
+         return embeddings[0].tolist() if hasattr(embeddings[0], 'tolist') else list(embeddings[0])
+
+     def embed_batch(self, texts_to_embed: list[str]) -> list[list]:
+         """Embed multiple texts in a batch using SentenceTransformers with optimized settings."""
+         if not texts_to_embed:
+             return []
+         self._check_embedder()
+         self.logger.info(f'Batch embedding {len(texts_to_embed)} texts using {self.embed_model_name}')
+         embeddings = self.embedding_model.encode(
+             texts_to_embed,
+             batch_size=self.encode_batch_size,
+             convert_to_numpy=True,
+             show_progress_bar=len(texts_to_embed) > 100,  # Only show progress for large batches
+             normalize_embeddings=True  # Normalize for better similarity computation
+         )
+         return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]
+
+     def embed_chunk_code_batch(self, codes_to_embed: list[str]) -> list[list]:
+         """Embed multiple code chunks in a batch using SentenceTransformers with optimized settings."""
+         if not codes_to_embed:
+             return []
+         self._check_embedder()
+         self.logger.info(f'Batch embedding {len(codes_to_embed)} code chunks using {self.embed_model_name}')
+         embeddings = self.embedding_model.encode(
+             codes_to_embed,
+             batch_size=self.encode_batch_size,
+             convert_to_numpy=True,
+             show_progress_bar=len(codes_to_embed) > 100,  # Only show progress for large batches
+             normalize_embeddings=True  # Normalize for better similarity computation
+         )
+         return [emb.tolist() if hasattr(emb, 'tolist') else list(emb) for emb in embeddings]
+
+
+ def create_model_service(skip_embedder: bool = False, **kwargs) -> ModelServiceInterface:
+     """
+     Factory function to create the appropriate ModelService based on embedder_type.
+
+     Args:
+         skip_embedder (bool): If True, skip loading the embedding model (for keyword-only search).
+         **kwargs: Additional arguments including 'embedder_type' ('openai' or 'sentence-transformers')
+             and an optional 'model_kwargs' dict which can override any env var defaults.
+     Returns:
+         ModelServiceInterface: An instance of the appropriate ModelService.
+     """
+     model_kwargs = kwargs.pop('model_kwargs', None)
+     embedder_type = kwargs.pop('embedder_type', 'openai')
+
+     if embedder_type == 'openai':
+         return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
+     elif embedder_type == 'sentence-transformers':
+         return SentenceTransformersModelService(model_kwargs=model_kwargs, skip_embedder=skip_embedder, **kwargs)
+     else:
+         logging.getLogger(LOGGER_NAME).warning(
+             f'Unknown embedder type: {embedder_type}, defaulting to OpenAI'
+         )
+         return OpenAIModelService(model_kwargs=model_kwargs, **kwargs)
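A short usage sketch for the factory above. The override keys mirror the env-var names read at module load; the model IDs shown are simply this module's defaults, not recommendations, and the timeout value is illustrative:

    from RepoKnowledgeGraphLib.ModelService import create_model_service

    # model_kwargs entries take precedence over environment variables
    service = create_model_service(
        embedder_type='sentence-transformers',
        model_kwargs={
            'MODEL_NAME': 'meta-llama/Llama-3.2-3B-Instruct',
            'EMBED_MODEL_NAME': 'Alibaba-NLP/gte-Qwen2-1.5B-instruct',
            'REQUEST_TIMEOUT': 120,  # illustrative override
        },
    )
    vectors = service.embed_batch(['def add(a, b):', 'class Foo: ...'])
    reply = service.query('Summarize what this repository does.')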
RepoKnowledgeGraphLib/Node.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Optional, Dict, List
+ from dataclasses import dataclass, field, asdict
+
+ from .Entity import Entity
+
+
+ @dataclass
+ class Node:
+     id: str = ''
+     name: str = ''
+     node_type: str = ''
+     description: Optional[str] = None
+     declared_entities: List[dict] = field(default_factory=list)  # Classes, functions, variables
+     called_entities: List[str] = field(default_factory=list)  # Classes, functions, variables, but also external libraries
+
+     def dict(self):
+         return {k: str(v) for k, v in asdict(self).items()}
+
+
+ @dataclass
+ class DirectoryNode(Node):
+     path: str = ''
+     node_type: str = 'directory'
+
+
+ @dataclass
+ class FileNode(Node):
+     path: str = ''
+     content: str = ''
+     node_type: str = 'file'
+     language: str = ''
+
+
+ @dataclass
+ class ChunkNode(FileNode):
+     node_type: str = 'chunk'
+     order_in_file: int = 0
+     embedding: Optional[list] = None
+
+     def get_field_to_embed(self) -> Optional[str]:
+         # Use description if available, otherwise fall back to content
+         # This ensures we always have something meaningful to embed
+         if self.description and self.description.strip():
+             return self.description
+         return self.content
+
+
+ @dataclass
+ class EntityNode(Node):
+     entity_type: str = ''
+     declaring_chunk_ids: List[str] = field(default_factory=list)
+     calling_chunk_ids: List[str] = field(default_factory=list)
+     aliases: List[str] = field(default_factory=list)  # All possible aliases for this entity
+     node_type: str = 'entity'
+
+     def __post_init__(self):
+         # Use entity_name (stored in the name field) as the id if id is not set
+         if not self.id and self.name:
+             self.id = self.name
+
+     def dict(self):
+         return {k: str(v) for k, v in asdict(self).items()}
+
+     def get_field_to_embed(self) -> Optional[str]:
+         return self.name
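A small illustrative sketch of how these dataclasses compose (every field value below is made up):

    from RepoKnowledgeGraphLib.Node import ChunkNode

    chunk = ChunkNode(
        id='src/app.py::chunk_0',
        name='app.py',
        path='src/app.py',
        content='def main():\n    print("hello")',
        language='python',
        order_in_file=0,
    )
    # No description was set, so get_field_to_embed() falls back to the raw content
    print(chunk.get_field_to_embed())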
RepoKnowledgeGraphLib/QuestionMaker.py ADDED
@@ -0,0 +1,538 @@
+ import logging
+ import asyncio
+ from tqdm import tqdm
+
+ from .RepoKnowledgeGraph import RepoKnowledgeGraph
+ from .ModelService import create_model_service
+ from .utils.logger_utils import setup_logger
+ from .utils.chunk_utils import organize_chunks_by_file_name, join_organized_chunks, extract_filename_from_chunk
+ from .Node import ChunkNode
+
+ LOGGER_NAME = "QUESTION_MAKER_LOGGER"
+
+ class QuestionMaker:
+     """
+     The QuestionMaker class is responsible for generating code comprehension questions and answers
+     based on code chunks and knowledge graphs. It leverages a language model service to formulate
+     questions and answers that test deep understanding of code, focusing on mechanisms, design decisions,
+     and subtle behaviors. It supports generating questions for neighboring code chunks as well as for
+     specific entities (e.g., functions, classes) that are both declared and called in the codebase.
+     """
+     def __init__(self):
+         """
+         Initializes the QuestionMaker, sets up logging, and instantiates the model service.
+         """
+         setup_logger(LOGGER_NAME)
+         self.logger = logging.getLogger(LOGGER_NAME)
+         self.model_service = create_model_service()
+
+     def generate_questions_answers(self, candidate_chunks: dict) -> list:
+         """
+         Placeholder for generating questions and answers from candidate chunks.
+         Args:
+             candidate_chunks (dict): Dictionary mapping chunk groups to process.
+         Returns:
+             list: List of question-answer pairs.
+         """
+         pass
+
+     def test_chunk_sensibility(self, knowledge_graph: RepoKnowledgeGraph) -> list:
+         """
+         Placeholder for testing the sensibility of code chunks in the knowledge graph.
+         Args:
+             knowledge_graph (RepoKnowledgeGraph): The knowledge graph to test.
+         Returns:
+             list: List of results or metrics.
+         """
+         pass
+
+     async def make_n_neighbouring_chunk_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
+         """
+         Generates questions and answers for all possible groups of n directly neighboring code chunks
+         in each file of the knowledge graph. This helps assess understanding of code that spans multiple
+         adjacent chunks, such as related functions or code blocks.
+
+         Args:
+             knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
+         Returns:
+             list: A list of dictionaries, each containing a question, answer, the involved chunks, and category.
+         """
+         file_nodes = knowledge_graph.get_all_files()
+         # collect candidate chunk groups
+         candidate_chunks = []
+         for file_node in file_nodes:
+             self.logger.info(f"Processing file node: {file_node}")
+             chunks = knowledge_graph.get_chunks_of_file(file_node.id)
+             num_chunks = len(chunks)
+             # For each n, collect all n-sized tuples of directly neighbouring chunks
+             for n in range(2, num_chunks + 1):
+                 for i in range(num_chunks - n + 1):
+                     # Only directly neighbouring chunks
+                     candidate_chunks.append(list(chunks[i:i+n]))
+         # generate questions and answers from candidate chunks in parallel, in batches of 15
+
+         async def process_chunk_group(chunks):
+             """
+             Helper coroutine to generate a question and answer for a specific group of neighboring chunks.
+             Args:
+                 chunks (list): The list of code chunks to generate the question and answer from.
+             Returns:
+                 dict: Contains question, answer, chunks, and category.
+             """
+             question = await self._generate_neighboring_question_from_chunks_async(chunks)
+             answer = await self.answer_question_about_chunks_async(chunks, question)
+             return {
+                 'question': question,
+                 'clean_question': question,
+                 'answer': answer,
+                 'chunks': [chunk.dict() for chunk in chunks],
+                 'category': 'neighbouring_chunks'
+             }
+
+         # Batch processing in groups of 15 with tqdm
+         batch_size = 15
+         results = []
+         total = len(candidate_chunks)
+         for i in tqdm(range(0, total, batch_size), desc="Generating neighbouring chunk questions", unit="batch"):
+             batch = candidate_chunks[i:i+batch_size]
+             tasks = [process_chunk_group(chunks) for chunks in batch]
+             batch_results = []
+             for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
+                 batch_results.append(await coro)
+             results.extend(batch_results)
+         return results
+
+     async def make_entity_declaration_call_specific_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
+         """
+         Generates questions and answers about specific entities (e.g., functions, classes) that have both
+         a declaration and at least one call site in the knowledge graph. Focuses on cross-file references
+         by default.
+
+         Args:
+             knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
+         Returns:
+             list: A list of dictionaries, each containing a question, answer, entity, involved chunks, and category.
+         """
+         self.logger.info("Generating entity-specific questions.")
+         candidate_pairs = self.get_entities_with_declaration_and_calling(knowledge_graph)
+
+         async def process_entity_pair(pair):
+             """
+             Helper coroutine to generate a question and answer for a specific entity's declaration and call site.
+             Args:
+                 pair (dict): Contains entity name, declaring_chunk_id, and calling_chunk_id.
+             Returns:
+                 dict: Contains question, answer, entity, chunks, and category.
+             """
+             entity_name = pair['entity']
+             chunks = [knowledge_graph[pair['declaring_chunk_id']], knowledge_graph[pair['calling_chunk_id']]]
+             question = await self.make_entity_specific_question_async(chunks, entity_name)
+             answer = await self.answer_question_about_chunks_async(chunks, question)
+             return {
+                 'question': question,
+                 'clean_question': question,
+                 'answer': answer,
+                 'entity': entity_name,
+                 'chunks': [chunk.dict() for chunk in chunks],
+                 'category': 'entity_declaration_call_specific'
+             }
+
+         # Batch processing with tqdm
+         batch_size = 15
+         results = []
+         total = len(candidate_pairs)
+         for i in tqdm(range(0, total, batch_size), desc="Generating entity-specific questions", unit="batch"):
+             batch = candidate_pairs[i:i+batch_size]
+             tasks = [process_entity_pair(pair) for pair in batch]
+             batch_results = []
+             for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
+                 batch_results.append(await coro)
+             results.extend(batch_results)
+         return results
+
+     async def make_interacting_entities_specific_questions_async(self, entity_A: str, entity_B: str,
+                                                                  decl_chunk_A: ChunkNode, decl_chunk_B: ChunkNode,
+                                                                  call_chunk: ChunkNode) -> str:
+         """
+         Generates a question about two entities that interact in the same chunk.
+         Each entity has a declaration and at least one call site, and the question focuses on their interaction.
+
+         Args:
+             entity_A (str): Name of the first entity.
+             entity_B (str): Name of the second entity.
+             decl_chunk_A (ChunkNode): Chunk containing the declaration of entity A.
+             decl_chunk_B (ChunkNode): Chunk containing the declaration of entity B.
+             call_chunk (ChunkNode): Chunk where both entities interact.
+         Returns:
+             str: The generated question as plain text.
+         """
+         entity_A_definition_code = decl_chunk_A.content
+         entity_B_definition_code = decl_chunk_B.content
+         entity_interaction_code = call_chunk.content
+
+         prompt = f"""You are given two code entities, {entity_A} and {entity_B}, along with a snippet where they interact.
+ Your task is to write **one clear and concise question** about their relationship.
+
+ ### Input:
+ * {entity_A} Definition Code:
+ {entity_A_definition_code}
+
+ * {entity_B} Definition Code:
+ {entity_B_definition_code}
+
+ * Interaction Code (where they interact):
+ {entity_interaction_code}
+
+ ### Guidelines:
+ * Ask about design, abstraction, dependencies, or side effects.
+ * The question should highlight something a developer might consider when reviewing or improving the code.
+ * Keep the question short and direct so it can be answered briefly.
+ * Do not explain the code or provide answers.
+
+ ### Output:
+ **Question**: <your question here>
+ """
+
+         initial_question = await self.model_service.query_async(prompt=prompt)
+         return await self.extract_question_from_generated_text_async(generated_text=initial_question)
+
+     def get_all_candidate_pairs_and_triplets(self, knowledge_graph: RepoKnowledgeGraph) -> tuple:
+         candidate_triplets = []
+         candidate_pairs = []
+
+         interacting_entity_triplets = self.get_interacting_entity_triplets(knowledge_graph)
+         for triplet in interacting_entity_triplets:
+             chunks = [
+                 knowledge_graph[triplet['decl_chunk_A']],
+                 knowledge_graph[triplet['decl_chunk_B']],
+                 knowledge_graph[triplet['call_chunk']]
+             ]
+             candidate_triplets.append({
+                 'entities': (triplet['entity_A'], triplet['entity_B']),
+                 'chunks': [chunk.dict() for chunk in chunks],
+                 'category': 'interacting_entities'
+             })
+
+         declaration_calling_pairs = self.get_entities_with_declaration_and_calling(knowledge_graph)
+         for pair in declaration_calling_pairs:
+             chunks = [knowledge_graph[pair['declaring_chunk_id']], knowledge_graph[pair['calling_chunk_id']]]
+             candidate_pairs.append({
+                 'entity': pair['entity'],
+                 'chunks': [chunk.dict() for chunk in chunks],
+                 'category': 'entity_declaration_call_specific'
+             })
+
+         return candidate_pairs, candidate_triplets
+
+     async def make_interacting_entity_questions_async(self, knowledge_graph: RepoKnowledgeGraph) -> list:
+         """
+         Generates questions and answers about pairs of entities that interact in the same chunk.
+         Each entity has a declaration and at least one call site, and the question focuses on their interaction.
+
+         Args:
+             knowledge_graph (RepoKnowledgeGraph): The knowledge graph to generate questions from.
+         Returns:
+             list: A list of dictionaries, each containing a question, answer, entities, involved chunks, and category.
+         """
+         self.logger.info("Generating interacting entity questions.")
+         triplets = self.get_interacting_entity_triplets(knowledge_graph)
+
+         async def process_triplet(triplet):
+             """
+             Helper coroutine to generate a question and answer for a specific interacting entity triplet.
+             Args:
+                 triplet (dict): Contains entity_A, entity_B, decl_chunk_A, decl_chunk_B, and call_chunk.
+             Returns:
+                 dict: Contains question, answer, entities, chunks, and category.
+             """
+             chunks = [
+                 knowledge_graph[triplet['decl_chunk_A']],
+                 knowledge_graph[triplet['decl_chunk_B']],
+                 knowledge_graph[triplet['call_chunk']]
+             ]
+             question = await self.make_interacting_entities_specific_questions_async(entity_A=triplet['entity_A'],
+                                                                                      entity_B=triplet['entity_B'],
+                                                                                      decl_chunk_A=knowledge_graph[triplet['decl_chunk_A']],
+                                                                                      decl_chunk_B=knowledge_graph[triplet['decl_chunk_B']],
+                                                                                      call_chunk=knowledge_graph[triplet['call_chunk']])
+             answer = await self.answer_question_about_chunks_async(chunks, question)
+             return {
+                 'question': question,
+                 'clean_question': question,
+                 'answer': answer,
+                 'entities': (triplet['entity_A'], triplet['entity_B']),
+                 'chunks': [chunk.dict() for chunk in chunks],
+                 'category': 'interacting_entities'
+             }
+
+         # Batch processing with tqdm
+         batch_size = 15
+         results = []
+         total = len(triplets)
+         for i in tqdm(range(0, total, batch_size), desc="Generating interacting entity questions", unit="batch"):
+             batch = triplets[i:i+batch_size]
+             tasks = [process_triplet(triplet) for triplet in batch]
+             batch_results = []
+             for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Questions in batch", leave=False):
+                 batch_results.append(await coro)
+             results.extend(batch_results)
+         return results
+
+     async def _generate_neighboring_question_from_chunks_async(self, chunks: list) -> str:
+         """
+         Generates a single code comprehension question for a group of code chunks using the model service.
+         The question is designed to probe deep understanding of the code's mechanisms, design, or pitfalls.
+
+         Args:
+             chunks (list): The list of code chunks to generate the question from.
+         Returns:
+             str: The generated question as plain text.
+         """
+         organized_chunks = organize_chunks_by_file_name(chunks)
+         joined_chunks = join_organized_chunks(organized_chunks)
+
+         system_prompt = """You are an expert in evaluating code comprehension. The user will provide, in the next message, the content of a code submission (in any programming language). Your goal is to analyze this code, identify its critical, subtle, or obscure aspects, and generate **one relevant question in English** to ask someone in order to assess their understanding of the code.
+
+ This question should focus on:
+
+ * essential mechanisms of how the code works,
+ * important design decisions,
+ * potential pitfalls or unexpected behaviors,
+ * or any aspect that requires deep comprehension.
+
+ The goal is to test whether the person has **truly understood** the code, not just skimmed through it.
+
+ Respond with **only one question**, in plain text. Do not include any explanation, comment, or wrapper (e.g., no dictionaries, no lists).
+ """
+         initial_question = await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks)
+         return await self.extract_question_from_generated_text_async(generated_text=initial_question)
+
+     async def answer_question_about_chunks_async(self, chunks: list, question: str) -> str:
+         """
+         Generates an answer to a code comprehension question about a group of code chunks using the model service.
+         The answer should demonstrate deep understanding and cover mechanisms, design, and pitfalls.
+
+         Args:
+             chunks (list): The list of code chunks to answer the question about.
+             question (str): The question to answer.
+         Returns:
+             str: The generated answer as plain text.
+         """
+         organized_chunks = organize_chunks_by_file_name(chunks)
+         joined_chunks = join_organized_chunks(organized_chunks)
+
+         system_prompt = """You are an expert in evaluating code comprehension. The user will provide, in the next message, the content of a code submission (in any programming language) and a question about it. Your goal is to analyze this code, identify its critical, subtle, or obscure aspects, and generate **one relevant answer in English** to the question.
+
+ This answer should focus on:
+
+ * essential mechanisms of how the code works,
+ * important design decisions,
+ * potential pitfalls or unexpected behaviors,
+ * or any aspect that requires deep comprehension.
+ The goal is to provide a clear and thorough answer that demonstrates a deep understanding of the code.
+ """
+
+         return await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks + "\n\n" + question)
+
+     async def make_entity_specific_question_async(self, chunks: list, entity_name: str) -> str:
+         """
+         Generates a question about a specific entity (e.g., function, class) in the context of the provided code chunks.
+         The question is designed to probe understanding of the entity's purpose, behavior, and interactions.
+
+         Args:
+             chunks (list): The list of code chunks to generate the question from.
+             entity_name (str): The name of the entity to focus on.
+         Returns:
+             str: The generated question as plain text.
+         """
+         organized_chunks = organize_chunks_by_file_name(chunks)
+         joined_chunks = join_organized_chunks(organized_chunks)
+
+         system_prompt = f"""You will be given one or more code snippets, possibly from multiple files.
+
+ A specific entity (such as a class, function, or variable) will be identified.
+
+ ---
+
+ ## Entity of Focus: {entity_name}
+
+ ### Task:
+ * Write **one clear and concise question** about this entity.
+ * The question should highlight something a developer might consider, such as its purpose, behavior, interactions, or potential improvements.
+
+ ### Guidelines:
+ * Keep the question short and direct.
+ * Do not explain the code or give an answer.
+
+ ### Output:
+ **Question**: <your question here>
+ """
+
+         initial_question = await self.model_service.query_with_instructions_async(instructions=system_prompt, prompt=joined_chunks)
+         return await self.extract_question_from_generated_text_async(generated_text=initial_question)
+
+     def get_entities_with_declaration_and_calling(self, knowledge_graph: RepoKnowledgeGraph, cross_file_only: bool = True) -> list:
+         """
+         Finds all entities in the knowledge graph that have both a declaration and at least one call site.
+         Optionally restricts to cases where the declaration and call are in different files (cross-file).
+
+         Args:
+             knowledge_graph (RepoKnowledgeGraph): The knowledge graph to search in.
+             cross_file_only (bool): If True, only consider cross-file declaration/call pairs.
+         Returns:
+             list: List of dictionaries with 'entity', 'declaring_chunk_id', and 'calling_chunk_id'.
+         """
+         candidate_pairs = []
+         entities = knowledge_graph.entities
+         for entity_name in entities:
+             entity = entities[entity_name]
+             if len(entity['declaring_chunk_ids']) and len(entity['calling_chunk_ids']):
+                 found = False
+                 for declaring_chunk_id in entity['declaring_chunk_ids']:
+                     for calling_chunk_id in entity['calling_chunk_ids']:
+                         if declaring_chunk_id != calling_chunk_id:
+                             if cross_file_only and extract_filename_from_chunk(knowledge_graph[declaring_chunk_id]) == extract_filename_from_chunk(knowledge_graph[calling_chunk_id]):
+                                 continue
+                             else:
+                                 candidate_pairs.append({'entity': entity_name, 'declaring_chunk_id': declaring_chunk_id, 'calling_chunk_id': calling_chunk_id})
+                                 found = True
+                                 break
+                     if found:
+                         break
+         return candidate_pairs
+
+     def get_interacting_entity_triplets(self, knowledge_graph: RepoKnowledgeGraph) -> list:
+         """
+         Finds triplets of chunk ids such that:
+         - Two entities (A, B) are interacting in the same chunk (call_chunk)
+         - Each entity has a declaring chunk (decl_chunk_A, decl_chunk_B)
+         - Both entities have non-empty declaring_chunk_ids and calling_chunk_ids
+
+         Returns:
+             list of dicts with keys:
+                 'entity_A', 'entity_B', 'decl_chunk_A', 'decl_chunk_B', 'call_chunk'
+         """
+         triplets = []
+         seen_pairs = set()
+         entities = knowledge_graph.entities
+         for entity_A_name, entity_A in entities.items():
+             if not entity_A['declaring_chunk_ids'] or not entity_A['calling_chunk_ids']:
+                 continue
+             for entity_B_name, entity_B in entities.items():
+                 if entity_A_name == entity_B_name:
+                     continue
+                 if not entity_B['declaring_chunk_ids'] or not entity_B['calling_chunk_ids']:
+                     continue
+                 pair_key = (entity_A_name, entity_B_name)
+                 if pair_key in seen_pairs:
+                     continue
+                 # Find intersection of calling_chunk_ids
+                 call_chunks = set(entity_A['calling_chunk_ids']) & set(entity_B['calling_chunk_ids'])
+                 found = False
+                 for call_chunk in call_chunks:
+                     for decl_chunk_A in entity_A['declaring_chunk_ids']:
+                         for decl_chunk_B in entity_B['declaring_chunk_ids']:
+                             triplets.append({
+                                 'entity_A': entity_A_name,
+                                 'entity_B': entity_B_name,
+                                 'decl_chunk_A': decl_chunk_A,
+                                 'decl_chunk_B': decl_chunk_B,
+                                 'call_chunk': call_chunk
+                             })
+                             seen_pairs.add(pair_key)
+                             found = True
+                             break
+                         if found:
+                             break
+                     if found:
+                         break
+         return triplets
+
+     async def extract_question_from_generated_text_async(self, generated_text: str) -> str:
+         """
+         Extracts only the question from the generated text by prompting the model,
+         stripping any surrounding labels or commentary.
+
+         Args:
+             generated_text (str): The text generated by the model.
+         Returns:
+             str: The extracted question.
+         """
+         prompt = f"Extract only the question from the following text. Return the question exactly, with no extra words or labels:\n\n{generated_text}\n\n"
+         return await self.model_service.query_async(prompt=prompt)
+
+     def select_diverse_candidates(self, candidate_pairs, candidate_triplets, max_pairs=20, max_triplets=20):
+         """
+         Selects a limited number of pairs and triplets with maximum diversity in entity representation.
+         Args:
+             candidate_pairs (list): List of candidate pairs (dicts with 'entity', ...).
+             candidate_triplets (list): List of candidate triplets (dicts with 'entities', ...).
+             max_pairs (int): Maximum number of pairs to select.
+             max_triplets (int): Maximum number of triplets to select.
+         Returns:
+             (list, list): Selected pairs and triplets.
+         """
+         # Select pairs
+         selected_pairs = []
+         used_entities = set()
+         for pair in candidate_pairs:
+             entity = pair['entity']
+             if entity not in used_entities:
+                 selected_pairs.append(pair)
+                 used_entities.add(entity)
+             if len(selected_pairs) >= max_pairs:
+                 break
+         # Select triplets
+         selected_triplets = []
+         used_entities_triplets = set()
+         for triplet in candidate_triplets:
+             entities = set(triplet['entities'])
+             if not entities & used_entities_triplets:
+                 selected_triplets.append(triplet)
+                 used_entities_triplets.update(entities)
+             if len(selected_triplets) >= max_triplets:
+                 break
+         return selected_pairs, selected_triplets
+
+     async def transform_answser_into_mcq_answer_async(self, question, answer, chunks):
+         """
+         Transforms the question and answer into a format suitable for MCQ generation.
+         """
+         code = join_organized_chunks(organize_chunks_by_file_name(chunks))
+
+         prompt = f"""
+ You are an expert Python developer and technical writer. I will give you:
+
+ 1. A Python code snippet
+ 2. A question about that code
+ 3. A detailed answer to the question
+
+ Your task is to **sanitize** the answer. That means:
+
+ - Strip away all fluff, filler, and redundant explanation
+ - Focus only on what directly answers the question
+ - Make it **short, clear, and direct**, as if it were a correct MCQ answer
+ - Prefer concise phrases or a single clear sentence over paragraph explanations
+ - Keep any necessary technical detail, but no more than needed
+
+ Do **not** repeat the question. Do **not** rephrase the code. Just give the concise, final answer.
+
+ - **Input Code**:
+ {code}
+
+ - **Question**:
+ {question}
+
+ - **Original Answer**:
+ {answer}
+
+ - **Sanitized Answer**:
+ """
+         return await self.model_service.query_async(prompt)
+
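A minimal end-to-end sketch for the generators above, assuming a graph built with `RepoKnowledgeGraph.from_path` (the repository path is a placeholder, and entity extraction is enabled since the entity-based generators read `knowledge_graph.entities`):

    import asyncio
    from RepoKnowledgeGraphLib.RepoKnowledgeGraph import RepoKnowledgeGraph
    from RepoKnowledgeGraphLib.QuestionMaker import QuestionMaker

    # './my_repo' is a placeholder path
    graph = RepoKnowledgeGraph.from_path('./my_repo', extract_entities=True)
    qm = QuestionMaker()
    qa_pairs = asyncio.run(qm.make_entity_declaration_call_specific_questions_async(graph))
    for qa in qa_pairs[:3]:
        print(qa['question'], '->', qa['answer'])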
RepoKnowledgeGraphLib/RepoKnowledgeGraph.py ADDED
@@ -0,0 +1,1608 @@
1
+ import networkx as nx
2
+ import json
3
+ import os
4
+ import asyncio
5
+ import nest_asyncio
6
+ import tqdm
7
+ # from pathlib import Path
8
+ import os.path
9
+ import tempfile
10
+ import subprocess
11
+ from typing import List, Optional, Dict
12
+ import logging
13
+ import urllib.parse
14
+
15
+ from .ModelService import create_model_service
16
+ from .Node import Node, DirectoryNode, FileNode, ChunkNode, EntityNode
17
+ from .CodeParser import CodeParser
18
+ from .EntityExtractor import HybridEntityExtractor
19
+ from .CodeIndex import CodeIndex
20
+ from .utils.logger_utils import setup_logger
21
+ from .utils.parsing_utils import read_directory_files_recursively, get_language_from_filename
22
+ from .utils.path_utils import prepare_input_path, build_entity_alias_map, resolve_entity_call
23
+ from .EntityChunkMapper import EntityChunkMapper
24
+
25
+ LOGGER_NAME = 'REPO_KNOWLEDGE_GRAPH_LOGGER'
26
+
27
+ MODEL_SERVICE_TYPES = ['openai', 'sentence-transformers']
28
+
29
+
30
+ # A RepoKnowledgeGraph is a weighted DAG based on a tree-structure with added edges
31
+ class RepoKnowledgeGraph:
32
+ """
33
+ RepoKnowledgeGraph builds a knowledge graph of a code repository.
34
+ It parses source files, extracts code entities and relationships, and organizes them
35
+ into a directed acyclic graph (DAG) with additional semantic edges.
36
+
37
+ Use `from_path()` or `load_graph_from_file()` to create instances.
38
+ """
39
+
40
+ def __init__(self):
41
+ """
42
+ Private constructor. Use from_path() or load_graph_from_file() instead.
43
+ """
44
+ raise RuntimeError(
45
+ "Cannot instantiate RepoKnowledgeGraph directly. "
46
+ "Use RepoKnowledgeGraph.from_path() or RepoKnowledgeGraph.load_graph_from_file() instead."
47
+ )
48
+
49
+ def _initialize(self, model_service_kwargs: dict, code_index_kwargs: Optional[dict] = None):
50
+ """Internal initialization method."""
51
+ setup_logger(LOGGER_NAME)
52
+ self.logger = logging.getLogger(LOGGER_NAME)
53
+ self.logger.info('Initializing RepoKnowledgeGraph instance.')
54
+ self.code_parser = CodeParser()
55
+
56
+ # Determine if we should skip loading the embedder based on index_type
57
+ index_type = (code_index_kwargs or {}).get('index_type', 'hybrid')
58
+ skip_embedder = index_type == 'keyword-only'
59
+ if skip_embedder:
60
+ self.logger.info('Using keyword-only index, skipping embedder initialization')
61
+
62
+ self.model_service = create_model_service(skip_embedder=skip_embedder, **model_service_kwargs)
63
+ self.entities = {}
64
+ self.graph = nx.DiGraph()
65
+ self.knowledge_graph = nx.DiGraph()
66
+ self.code_index = None
67
+ self.entity_extractor = HybridEntityExtractor()
68
+
69
+ def __iter__(self):
70
+ # Yield only the 'data' attribute from each node
71
+ return (node_data['data'] for _, node_data in self.graph.nodes(data=True))
72
+
73
+ def __getitem__(self, node_id):
74
+ return self.graph.nodes[node_id]['data']
75
+
76
+
77
+     @classmethod
+     def from_path(cls, path: str, skip_dirs: Optional[list] = None, index_nodes: bool = True, describe_nodes=False,
+                   extract_entities: bool = False, model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
+         """
+         Alternative constructor to build a RepoKnowledgeGraph from a path, with options to skip directories
+         and control entity extraction and node description.
+
+         Args:
+             path (str): Path to the root of the code repository.
+             skip_dirs (list): List of directory names to skip.
+             index_nodes (bool): Whether to build a code index.
+             describe_nodes (bool): Whether to generate descriptions for code chunks.
+             extract_entities (bool): Whether to extract entities from code.
+             model_service_kwargs (dict, optional): Arguments forwarded to the model service factory.
+             code_index_kwargs (dict, optional): Arguments forwarded to the CodeIndex constructor.
+
+         Returns:
+             RepoKnowledgeGraph: The constructed knowledge graph.
+         """
+         if skip_dirs is None:
+             skip_dirs = []
+         if model_service_kwargs is None:
+             model_service_kwargs = {}
+         instance = cls.__new__(cls)  # Create instance without calling __init__
+         instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
+
+         instance.logger.info(f"Preparing to build knowledge graph from path: {path}")
+
+         prepared_path = prepare_input_path(path)
+         instance.logger.debug(f"Prepared input path: {prepared_path}")
+
+         # Handle a running event loop (e.g., in Jupyter)
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             loop = None
+
+         if loop and loop.is_running():
+             instance.logger.debug("Detected running event loop, applying nest_asyncio.")
+             nest_asyncio.apply()
+             task = instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
+                                                       describe_nodes=describe_nodes, extract_entities=extract_entities)
+             loop.run_until_complete(task)
+         else:
+             instance.logger.debug("No running event loop, using asyncio.run.")
+             asyncio.run(instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
+                                                            describe_nodes=describe_nodes,
+                                                            extract_entities=extract_entities))
+
+         instance.logger.info("Initial parse and node creation complete. Building relationships between nodes...")
+         instance._build_relationships()
+
+         if index_nodes:
+             instance.logger.info("Building code index for all nodes in the graph...")
+             instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **(code_index_kwargs or {}))
+
+         instance.logger.info("Knowledge graph construction from path completed successfully.")
+         return instance
+
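+     # Usage sketch (hedged): a minimal, hypothetical call; the local path and the
+     # skip list below are illustrative assumptions, not part of the library API.
+     #
+     #     kg = RepoKnowledgeGraph.from_path('./my_repo', skip_dirs=['tests'],
+     #                                       extract_entities=True)
+     #     kg.print_tree(max_depth=2)
+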
+     @classmethod
+     def from_repo(
+             cls,
+             repo_url: str,
+             skip_dirs: Optional[list] = None,
+             index_nodes: bool = True,
+             describe_nodes: bool = False,
+             extract_entities: bool = False,
+             model_service_kwargs: Optional[dict] = None,
+             code_index_kwargs: Optional[dict] = None,
+             github_token: Optional[str] = None,
+             allow_unauthenticated_clone: bool = True,
+     ):
+         """
+         Alternative constructor to build a RepoKnowledgeGraph from a remote git repository URL.
+
+         Args:
+             repo_url (str): Git repository URL (SSH or HTTPS).
+             skip_dirs (list): List of directory names to skip.
+             index_nodes (bool): Whether to build a code index.
+             describe_nodes (bool): Whether to generate descriptions for code chunks.
+             extract_entities (bool): Whether to extract entities from code.
+             github_token (str, optional): Personal access token to access private GitHub repos.
+                 If not provided, the method will look for the `GITHUB_OAUTH_TOKEN` environment variable.
+             allow_unauthenticated_clone (bool): If True, attempt to clone without a token when none is provided.
+                 If False, raise an error when no token is available.
+
+         Returns:
+             RepoKnowledgeGraph: The constructed knowledge graph.
+         """
+         if skip_dirs is None:
+             skip_dirs = []
+         if model_service_kwargs is None:
+             model_service_kwargs = {}
+
+         instance = cls.__new__(cls)
+         instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
+
+         instance.logger.info(f"Starting knowledge graph build from remote repository: {repo_url}")
+
+         # Determine token
+         token = github_token or os.environ.get('GITHUB_OAUTH_TOKEN')
+
+         with tempfile.TemporaryDirectory() as tmpdirname:
+             clone_url = repo_url
+             try:
+                 if repo_url.startswith('git@'):
+                     # Convert git@github.com:owner/repo.git -> https://github.com/owner/repo.git
+                     clone_url = repo_url.replace(':', '/').split('git@')[-1]
+                     clone_url = f'https://{clone_url}'
+
+                 if token and clone_url.startswith('https://'):
+                     encoded_token = urllib.parse.quote(token, safe='')
+                     clone_url = clone_url.replace('https://', f'https://{encoded_token}@')
+                 elif not token and not allow_unauthenticated_clone:
+                     raise ValueError(
+                         "GitHub token not provided and unauthenticated clone is disabled. "
+                         "Set allow_unauthenticated_clone=True or provide a token."
+                     )
+
+                 instance.logger.debug(f"Running git clone: {clone_url} -> {tmpdirname}")
+                 subprocess.run(['git', 'clone', clone_url, tmpdirname], check=True)
+
+             except Exception as e:
+                 instance.logger.error(f"Failed to clone repository {repo_url} using URL {clone_url}: {e}")
+                 raise
+
+             instance.logger.info(f"Repository successfully cloned to: {tmpdirname}")
+
+             # Build inside the `with` block so the temporary clone still exists
+             return cls.from_path(
+                 tmpdirname,
+                 skip_dirs=skip_dirs,
+                 index_nodes=index_nodes,
+                 describe_nodes=describe_nodes,
+                 extract_entities=extract_entities,
+                 model_service_kwargs=model_service_kwargs,
+                 code_index_kwargs=code_index_kwargs
+             )
+
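+     # Usage sketch (hedged): the repo URL below is a placeholder; the token is read
+     # from GITHUB_OAUTH_TOKEN when github_token is not passed explicitly.
+     #
+     #     kg = RepoKnowledgeGraph.from_repo('https://github.com/user/project.git',
+     #                                       skip_dirs=['docs'])
+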
+     async def _initial_parse_path_async(self, path: str, skip_dirs: list, index_nodes=True, describe_nodes=True,
+                                         extract_entities: bool = True):
+         """
+         Orchestrates the parsing and graph construction process:
+         1. Reads files and splits them into chunks.
+         2. Extracts entities and relationships.
+         3. Builds chunk, file, directory, and root nodes.
+         4. Aggregates entity information.
+
+         Args:
+             path (str): Root path to parse.
+             skip_dirs (list): Directories to skip.
+             index_nodes (bool): Whether to build a code index.
+             describe_nodes (bool): Whether to generate descriptions.
+             extract_entities (bool): Whether to extract entities.
+         """
+         self.logger.info(f"Beginning async parsing of repository at path: {path}")
+
+         # --- Pass 1: Create ChunkNodes ---
+         level1_node_contents = read_directory_files_recursively(
+             path, skip_dirs=skip_dirs,
+             skip_pattern=r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)"
+         )
+         self.logger.debug(f"Found {len(level1_node_contents)} files to process.")
+         self.logger.info("Chunk nodes creation step started.")
+         chunk_info = await self._create_chunk_nodes(
+             level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=path
+         )
+         self.logger.info("Chunk nodes creation step finished.")
+         self.logger.info("File nodes creation step started.")
+         file_info = self._create_file_nodes(
+             chunk_info, level1_node_contents
+         )
+         self.logger.info("File nodes creation step finished.")
+         self.logger.info("Directory nodes creation step started.")
+         dir_agg = self._create_directory_nodes(
+             file_info
+         )
+         self.logger.info("Directory nodes creation step finished.")
+         self.logger.info("Aggregating all nodes to root node.")
+         self._aggregate_to_root(dir_agg)
+         self.logger.info("Async parse and node aggregation fully complete.")
+
+     async def _create_chunk_nodes(self, level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=None):
+         self.logger.info(f"Starting chunk node creation for {len(level1_node_contents)} files.")
+         accepted_extensions = {'.py', '.c', '.cpp', '.h', '.hpp', '.java', '.js', '.ts', '.jsx', '.tsx', '.rs', '.html'}
+         chunk_info = {}
+         entity_mapper = EntityChunkMapper()
+         total_chunks = 0
+
+         # Use tqdm for a progress bar over files
+         for file_path in tqdm.tqdm(level1_node_contents, desc="Processing files for chunk nodes"):
+             self.logger.debug(f"Processing file for chunk nodes: {file_path}")
+             _, ext = os.path.splitext(file_path)
+             is_code_file = ext.lower() in accepted_extensions
+
+             self.logger.debug(f"Parsing file: {file_path}")
+
+             # Parse file into chunks
+             parsed_content = self.code_parser.parse(file_name=file_path, file_content=level1_node_contents[file_path])
+             self.logger.debug(f"Parsed {len(parsed_content)} chunks from file: {file_path}")
+             total_chunks += len(parsed_content)
+
+             # Entity extraction (code files only)
+             if extract_entities and is_code_file:
+                 self.logger.debug(f"Extracting entities from code file: {file_path}")
+                 try:
+                     # Construct full path for entity extraction (needed for C/C++ include resolution)
+                     extraction_file_path = os.path.join(root_path, file_path) if root_path else file_path
+
+                     file_declared_entities, file_called_entities = self.entity_extractor.extract_entities(
+                         code=level1_node_contents[file_path], file_name=extraction_file_path)
+                     self.logger.debug(f"Extracted {len(file_declared_entities)} declared and {len(file_called_entities)} called entities from file: {file_path}")
+
+                     chunk_declared_map, chunk_called_map = entity_mapper.map_entities_to_chunks(
+                         file_declared_entities, file_called_entities, parsed_content, file_name=file_path)
+                     self.logger.debug(f"Mapped entities to {len(parsed_content)} chunks for file: {file_path}")
+                 except Exception as e:
+                     self.logger.error(f"Error extracting entities from {file_path}: {e}")
+                     file_declared_entities, file_called_entities = [], []
+                     chunk_declared_map = {i: [] for i in range(len(parsed_content))}
+                     chunk_called_map = {i: [] for i in range(len(parsed_content))}
+             else:
+                 self.logger.debug(f"Skipping entity extraction for non-code file: {file_path}")
+                 file_declared_entities, file_called_entities = [], []
+                 chunk_declared_map = {i: [] for i in range(len(parsed_content))}
+                 chunk_called_map = {i: [] for i in range(len(parsed_content))}
+
+             chunk_tasks = []
+             for i, chunk in enumerate(parsed_content):
+                 chunk_id = f'{file_path}_{i}'
+                 self.logger.debug(f"Scheduling processing for chunk {chunk_id} of file {file_path}")
+
+                 async def process_chunk(i=i, chunk=chunk, chunk_id=chunk_id):
+                     self.logger.debug(f"Creating chunk node: {chunk_id}")
+                     declared_entities = chunk_declared_map.get(i, [])
+                     called_entities = chunk_called_map.get(i, [])
+
+                     # FIRST PASS: Register all declared entities with aliases
+                     # Build temporary alias map for checking existing entities
+                     temp_alias_map = build_entity_alias_map(self.entities)
+
+                     for entity in declared_entities:
+                         name = entity.get("name")
+                         if not name:
+                             continue
+
+                         # Check if this entity already exists under any of its aliases
+                         entity_aliases = entity.get("aliases", [])
+                         canonical_name = None
+
+                         # First check if the name itself already exists or is an alias
+                         if name in temp_alias_map:
+                             canonical_name = temp_alias_map[name]
+                             self.logger.debug(f"Entity '{name}' already exists as '{canonical_name}'")
+                         else:
+                             # Check if any of the entity's aliases match existing entities
+                             for alias in entity_aliases:
+                                 if alias in temp_alias_map:
+                                     canonical_name = temp_alias_map[alias]
+                                     self.logger.debug(f"Entity '{name}' matches existing entity '{canonical_name}' via alias '{alias}'")
+                                     break
+
+                         # If we found a match, use the canonical name; otherwise use the entity name
+                         if canonical_name:
+                             entity_key = canonical_name
+                         else:
+                             entity_key = name
+                             self.logger.debug(f"Registering new declared entity '{name}' in chunk {chunk_id}")
+                             self.entities[entity_key] = {
+                                 "declaring_chunk_ids": [],
+                                 "calling_chunk_ids": [],
+                                 "type": [],
+                                 "dtype": None,
+                                 "aliases": []
+                             }
+                             # Update temp alias map with new entity
+                             temp_alias_map[entity_key] = entity_key
+
+                         if chunk_id not in self.entities[entity_key]["declaring_chunk_ids"]:
+                             self.entities[entity_key]["declaring_chunk_ids"].append(chunk_id)
+                         entity_type = entity.get("type")
+                         if entity_type and entity_type not in self.entities[entity_key]["type"]:
+                             self.entities[entity_key]["type"].append(entity_type)
+                         dtype = entity.get("dtype")
+                         if dtype:
+                             self.entities[entity_key]["dtype"] = dtype
+                         # Store aliases (add new ones, avoiding duplicates)
+                         for alias in [name] + entity_aliases:
+                             if alias and alias not in self.entities[entity_key]["aliases"]:
+                                 self.entities[entity_key]["aliases"].append(alias)
+                                 temp_alias_map[alias] = entity_key  # Update temp map
+                         self.logger.debug(f"Declared entity '{name}' registered as '{entity_key}' in chunk {chunk_id} with aliases: {self.entities[entity_key]['aliases']}")
+
+                     # Optionally generate a natural-language description for the chunk
+                     if describe_nodes:
+                         self.logger.info(f"Generating description for chunk {chunk_id}")
+                         try:
+                             description = await self.model_service.query_async(
+                                 f'Summarize this {get_language_from_filename(file_path)} code chunk in a few sentences: {chunk}')
+                         except Exception as e:
+                             self.logger.error(f"Error generating description for chunk {chunk_id}: {e}")
+                             description = ''
+                     else:
+                         self.logger.debug(f"No description requested for chunk {chunk_id}")
+                         description = ''
+
+                     chunk_node = ChunkNode(
+                         id=chunk_id,
+                         name=chunk_id,
+                         path=file_path,
+                         content=chunk,
+                         order_in_file=i,
+                         called_entities=called_entities,
+                         declared_entities=declared_entities,
+                         language=get_language_from_filename(file_path),
+                         description=description,
+                     )
+                     self.logger.debug(f"Chunk node created: {chunk_id}")
+
+                     # NOTE: Embeddings are deferred to CodeIndex for efficient batch processing.
+                     # This avoids slow one-at-a-time embedding during chunk creation.
+                     chunk_node.embedding = None
+                     return (chunk_id, chunk_node, declared_entities, called_entities)
+
+                 chunk_tasks.append(process_chunk())
+
+             chunk_results = await asyncio.gather(*chunk_tasks)
+             self.logger.debug(f"Finished processing {len(chunk_results)} chunks for file {file_path}.")
+             chunk_info[file_path] = {
+                 'chunk_results': chunk_results,
+                 'file_declared_entities': file_declared_entities,
+                 'file_called_entities': file_called_entities
+             }
+
+         # Log summary
+         self.logger.info(f"Created {total_chunks} chunk nodes from {len(level1_node_contents)} files")
+
+         # SECOND PASS: Now that all declared entities are registered, resolve called entities
+         self.logger.info("Starting second pass: resolving called entities using alias map...")
+         alias_map = build_entity_alias_map(self.entities)
+         self.logger.info(f"Built alias map with {len(alias_map)} entries for resolution")
+
+         resolved_count = 0
+         for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Resolving called entities"):
+             chunk_results = file_data['chunk_results']
+             for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
+                 for called_name in called_entities:
+                     # Skip empty or whitespace-only names
+                     if not called_name or not called_name.strip():
+                         continue
+
+                     # Try to resolve this called entity to an existing declared entity using aliases
+                     resolved_name = resolve_entity_call(called_name, alias_map)
+
+                     # Use the resolved name if found, otherwise check if called_name is already an alias
+                     if resolved_name:
+                         entity_key = resolved_name
+                     elif called_name in alias_map:
+                         # The called_name itself is an alias of an existing entity
+                         entity_key = alias_map[called_name]
+                     else:
+                         # No match found, use the original called name
+                         entity_key = called_name
+
+                     if entity_key not in self.entities:
+                         self.logger.debug(f"Registering new called entity '{entity_key}' (called as '{called_name}') in chunk {chunk_id}")
+                         self.entities[entity_key] = {
+                             "declaring_chunk_ids": [],
+                             "calling_chunk_ids": [],
+                             "type": [],
+                             "dtype": None,
+                             "aliases": []
+                         }
+                         # Add called_name as an alias if it's different from entity_key
+                         if called_name != entity_key:
+                             self.entities[entity_key]["aliases"].append(called_name)
+                             alias_map[called_name] = entity_key  # Update alias map
+
+                     if chunk_id not in self.entities[entity_key]["calling_chunk_ids"]:
+                         self.entities[entity_key]["calling_chunk_ids"].append(chunk_id)
+
+                     if resolved_name and resolved_name != called_name:
+                         resolved_count += 1
+                         self.logger.debug(f"Called entity '{called_name}' resolved to '{entity_key}' in chunk {chunk_id}")
+
+         self.logger.info(f"Resolved {resolved_count} entity calls to existing declarations via aliases")
+         self.logger.info("All chunk nodes have been created for all files.")
+         return chunk_info
+
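+     # Illustration of the two-pass alias resolution above (hedged; entity names are
+     # hypothetical): after pass 1 registers `module.MyClass` with alias `MyClass`, the
+     # alias map contains {'module.MyClass': 'module.MyClass', 'MyClass': 'module.MyClass'},
+     # so a call recorded as `MyClass` in pass 2 resolves to the canonical `module.MyClass`.
+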
+     def _create_file_nodes(self, chunk_info, level1_node_contents):
+         """
+         For each file, aggregate chunk information and create FileNode objects.
+         """
+         self.logger.info("Starting file node creation.")
+
+         file_info = {}
+         for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Creating file nodes"):
+             self.logger.info(f"Creating file node for: {file_path}")
+             parts = os.path.normpath(file_path).split(os.sep)
+
+             # Extract file-level entities and chunk results from the stored data
+             chunk_results = file_data['chunk_results']
+             file_declared_entities = list(file_data['file_declared_entities'])  # Use file-level entities directly
+             file_called_entities = list(file_data['file_called_entities'])  # Use file-level entities directly
+             chunk_ids = []
+
+             for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
+                 self.logger.debug(f"Adding chunk node {chunk_id} to graph for file {file_path}")
+                 self.graph.add_node(chunk_id, data=chunk_node, level=2)
+                 chunk_ids.append(chunk_id)
+                 # Note: The FileNode uses file-level entities, so there is no need to merge
+                 # entities up from chunks; the chunks already have their entities set correctly.
+
+             file_node = FileNode(
+                 id=file_path,
+                 name=parts[-1],
+                 path=file_path,
+                 node_type='file',
+                 content=level1_node_contents[file_path],
+                 declared_entities=file_declared_entities,
+                 called_entities=file_called_entities,
+                 language=get_language_from_filename(file_path),
+             )
+
+             self.logger.debug(f"Adding file node {file_path} to graph.")
+             self.graph.add_node(file_path, data=file_node, level=1)
+             for chunk_id in chunk_ids:
+                 self.graph.add_edge(file_path, chunk_id, relation='contains')
+
+             file_info[file_path] = {
+                 'declared_entities': file_declared_entities,
+                 'called_entities': file_called_entities,
+                 'chunk_ids': chunk_ids,
+                 'parts': parts,
+             }
+             self.logger.info(f"File node {file_path} added to graph with {len(chunk_ids)} chunks.")
+
+         self.logger.info("All file nodes have been created.")
+         return file_info
+
+     def _create_directory_nodes(self, file_info):
+         """
+         For each directory, aggregate file information and create DirectoryNode objects.
+
+         Args:
+             file_info (dict): Mapping file_path -> file info dict.
+
+         Returns:
+             dict: Mapping dir_path -> aggregated entity info.
+         """
+         self.logger.info("Starting directory node creation.")
+
+         def merge_entities(target, source):
+             # Merge entity lists, avoiding duplicates by (name, type)
+             existing = set((e.get('name'), e.get('type')) for e in target)
+             for e in source:
+                 k = (e.get('name'), e.get('type'))
+                 if k not in existing:
+                     target.append(e)
+                     existing.add(k)
+
+         def merge_called_entities(target, source):
+             # Merge called entity lists, avoiding duplicates
+             existing = set(target)
+             for e in source:
+                 if e not in existing:
+                     target.append(e)
+                     existing.add(e)
+
+         dir_agg = {}
+         for file_path, info in tqdm.tqdm(file_info.items(), desc="Creating directory nodes"):
+             self.logger.info(f"Processing directory nodes for file: {file_path}")
+             parts = os.path.normpath(file_path).split(os.sep)
+             file_declared_entities = info['declared_entities']
+             file_called_entities = info['called_entities']
+             current_parent = 'root'
+             path_accum = ''
+             for part in parts[:-1]:  # Skip the file itself
+                 path_accum = os.path.join(path_accum, part) if path_accum else part
+                 if path_accum not in self.graph:
+                     self.logger.info(f"Adding new directory node: {path_accum}")
+                     dir_node = DirectoryNode(id=path_accum, name=part, path=path_accum)
+                     self.graph.add_node(path_accum, data=dir_node, level=1)
+                     self.graph.add_edge(current_parent, path_accum, relation='contains')
+                 if path_accum not in dir_agg:
+                     dir_agg[path_accum] = {'declared_entities': [], 'called_entities': []}
+                 merge_entities(dir_agg[path_accum]['declared_entities'], file_declared_entities)
+                 merge_called_entities(dir_agg[path_accum]['called_entities'], file_called_entities)
+                 current_parent = path_accum
+             # Connect file to its parent directory
+             self.graph.add_edge(current_parent, file_path, relation='contains')
+         self.logger.info("All directory nodes created.")
+         return dir_agg
+
+     def _aggregate_to_root(self, dir_agg):
+         """
+         Aggregate all directory entity information to the root node.
+
+         Args:
+             dir_agg (dict): Mapping dir_path -> aggregated entity info.
+         """
+         self.logger.info("Aggregating directory information to root node.")
+
+         def merge_entities(target, source):
+             # Merge entity lists, avoiding duplicates by (name, type)
+             existing = set((e.get('name'), e.get('type')) for e in target)
+             for e in source:
+                 k = (e.get('name'), e.get('type'))
+                 if k not in existing:
+                     target.append(e)
+                     existing.add(k)
+
+         def merge_called_entities(target, source):
+             # Merge called entity lists, avoiding duplicates
+             existing = set(target)
+             for e in source:
+                 if e not in existing:
+                     target.append(e)
+                     existing.add(e)
+
+         root_node = Node(id='root', name='root', node_type='root')
+         self.graph.add_node('root', data=root_node, level=0)
+         root_declared_entities = []
+         root_called_entities = []
+         for dir_path, agg in tqdm.tqdm(dir_agg.items(), desc="Aggregating to root"):
+             node = self.graph.nodes[dir_path]['data']
+             if not hasattr(node, 'declared_entities'):
+                 node.declared_entities = []
+             if not hasattr(node, 'called_entities'):
+                 node.called_entities = []
+             merge_entities(node.declared_entities, agg['declared_entities'])
+             merge_called_entities(node.called_entities, agg['called_entities'])
+             merge_entities(root_declared_entities, agg['declared_entities'])
+             merge_called_entities(root_called_entities, agg['called_entities'])
+         if not hasattr(root_node, 'declared_entities'):
+             root_node.declared_entities = []
+         if not hasattr(root_node, 'called_entities'):
+             root_node.called_entities = []
+         merge_entities(root_node.declared_entities, root_declared_entities)
+         merge_called_entities(root_node.called_entities, root_called_entities)
+         self.logger.info("Aggregation to root node complete.")
+
+     def _build_relationships(self):
+         """
+         Build relationships between chunk nodes and entity nodes based on self.entities.
+         For each entity in self.entities:
+         1. Create an EntityNode with entity_name as the id.
+         2. Create edges from declaring chunks to the entity node ('declares' relationship).
+         3. Create edges from the entity node to calling chunks ('called_by' relationship).
+         4. Resolve called entity names using aliases for better matching.
+         """
+         self.logger.info("Building relationships between chunk nodes based on entities.")
+         edges_created = 0
+         entity_nodes_created = 0
+
+         # Build alias map for quick lookups
+         self.logger.info("Building entity alias map for call resolution...")
+         alias_map = build_entity_alias_map(self.entities)
+         self.logger.info(f"Built alias map with {len(alias_map)} entries")
+
+         # First pass: Create all entity nodes
+         for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Creating entity nodes"):
+             # Entity type is stored as a list under 'type'; take the first type or empty string
+             entity_types = info.get('type', [])
+             entity_type = entity_types[0] if entity_types else ''
+             declaring_chunks = info.get('declaring_chunk_ids', [])
+             calling_chunks = info.get('calling_chunk_ids', [])
+             aliases = info.get('aliases', [])
+
+             # Create EntityNode with entity_name as id
+             entity_node = EntityNode(
+                 id=entity_name,
+                 name=entity_name,
+                 entity_type=entity_type,
+                 declaring_chunk_ids=declaring_chunks,
+                 calling_chunk_ids=calling_chunks,
+                 aliases=aliases
+             )
+
+             # Add entity node to graph
+             self.graph.add_node(entity_name, data=entity_node, level=3)
+             entity_nodes_created += 1
+
+             # Log aliases for debugging
+             if aliases:
+                 self.logger.debug(f"Created EntityNode '{entity_name}' with aliases: {aliases}")
+
+             # Create edges from declaring chunks to entity node
+             for declarer_id in declaring_chunks:
+                 if declarer_id in self.graph:
+                     self.graph.add_edge(declarer_id, entity_name, relation='declares')
+                     edges_created += 1
+
+             # Create edges from entity node to calling chunks
+             for caller_id in calling_chunks:
+                 if caller_id in self.graph and caller_id not in declaring_chunks:
+                     self.graph.add_edge(entity_name, caller_id, relation='called_by')
+                     edges_created += 1
+
+         # Second pass: Resolve unmatched entity calls using alias matching
+         self.logger.info("Resolving entity calls using alias matching...")
+         resolved_calls = 0
+
+         for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Resolving entity calls"):
+             # Skip entities that already have declarations (they were matched directly)
+             if info.get('declaring_chunk_ids'):
+                 continue
+
+             # Try to resolve this called entity to a declared entity using aliases
+             resolved_name = resolve_entity_call(entity_name, alias_map)
+
+             if resolved_name and resolved_name != entity_name:
+                 # Found a match! Update the calling_chunk_ids of the resolved entity
+                 calling_chunks = info.get('calling_chunk_ids', [])
+
+                 if resolved_name in self.entities:
+                     for caller_id in calling_chunks:
+                         if caller_id in self.graph:
+                             # Add edge from resolved entity to calling chunk
+                             if not self.graph.has_edge(resolved_name, caller_id):
+                                 self.graph.add_edge(resolved_name, caller_id, relation='called_by')
+                                 edges_created += 1
+                                 resolved_calls += 1
+                                 self.logger.debug(f"Resolved call: '{entity_name}' -> '{resolved_name}' in chunk {caller_id}")
+
+         self.logger.info(f"_build_relationships: Created {entity_nodes_created} entity nodes, "
+                          f"{edges_created} edges, and resolved {resolved_calls} entity calls using aliases.")
+
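+     # Resulting edge pattern (illustrative, with hypothetical ids):
+     #     'utils.py_0' --declares--> 'parse_config' --called_by--> 'main.py_2'
+     # i.e. declaring chunks point at the entity node, which points at its callers.
+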
+     def get_entity_by_alias(self, alias: str) -> Optional[str]:
+         """
+         Get the canonical entity name for a given alias.
+
+         Args:
+             alias: An alias of an entity (e.g., 'MyClass' or 'module.MyClass')
+
+         Returns:
+             Canonical entity name if found, None otherwise
+         """
+         alias_map = build_entity_alias_map(self.entities)
+         return alias_map.get(alias)
+
+     def resolve_entity_references(self) -> Dict[str, str]:
+         """
+         Resolve all entity references in the knowledge graph using aliases.
+
+         Returns:
+             Dictionary mapping each unresolved called entity name to its resolved
+             canonical entity name.
+         """
+         alias_map = build_entity_alias_map(self.entities)
+         resolutions = {}
+
+         for entity_name, info in self.entities.items():
+             # Only look at entities that are called but not declared
+             if not info.get('declaring_chunk_ids') and info.get('calling_chunk_ids'):
+                 resolved = resolve_entity_call(entity_name, alias_map)
+                 if resolved:
+                     resolutions[entity_name] = resolved
+
+         return resolutions
+
+     def print_tree(self, max_depth=None, start_node_id='root', level=0, prefix=""):
+         """
+         Print the repository tree structure using the graph with 'contains' edges.
+
+         Args:
+             max_depth (int, optional): Maximum depth to print. None = unlimited.
+             start_node_id (str): ID of the node to start from. Default is 'root'.
+             level (int): Internal use only (used for recursion).
+             prefix (str): Internal use only (used for formatting output).
+         """
+         if max_depth is not None and level > max_depth:
+             self.logger.debug(f"Max depth {max_depth} reached at node {start_node_id}.")
+             return
+
+         if start_node_id not in self.graph:
+             self.logger.warning(f"Start node '{start_node_id}' not found in graph.")
+             return
+
+         try:
+             node_data = self[start_node_id]
+         except KeyError as e:
+             self.logger.error(f"KeyError when accessing node {start_node_id}: {e}")
+             self.logger.error(f"Available node attributes: {list(self.graph.nodes[start_node_id].keys())}")
+             # Fall back to a synthesized node when the 'data' attribute is missing
+             if 'data' not in self.graph.nodes[start_node_id]:
+                 self.logger.warning(f"Node {start_node_id} has no 'data' attribute, creating a fallback node")
+                 if start_node_id == 'root':
+                     # Create a default root node
+                     node_data = Node(id='root', name='root', node_type='root')
+                     # Update the graph node with the fallback data
+                     self.graph.nodes[start_node_id]['data'] = node_data
+                 else:
+                     # Try to infer the node type from the ID or structure
+                     name = start_node_id.split('/')[-1] if '/' in start_node_id else start_node_id
+                     if '_' in start_node_id and start_node_id.split('_')[-1].isdigit():
+                         # Looks like a chunk ID
+                         node_data = ChunkNode(id=start_node_id, name=name, node_type='chunk')
+                     elif '.' in name:
+                         # Looks like a file
+                         node_data = FileNode(id=start_node_id, name=name, node_type='file', path=start_node_id)
+                     else:
+                         # Fall back to a directory node
+                         node_data = DirectoryNode(id=start_node_id, name=name, node_type='directory',
+                                                   path=start_node_id)
+                     # Update the graph node with the fallback data
+                     self.graph.nodes[start_node_id]['data'] = node_data
+                     return
+
+         # Choose icon based on node type
+         if node_data.node_type == 'file':
+             node_symbol = "πŸ“„"
+         elif node_data.node_type == 'chunk':
+             node_symbol = "πŸ“"
+         elif node_data.node_type == 'root':
+             node_symbol = "πŸ“"
+         elif node_data.node_type == 'directory':
+             node_symbol = "πŸ“‚"
+         else:
+             node_symbol = "πŸ“¦"
+
+         if level == 0:
+             print(f"{node_symbol} {node_data.name} ({node_data.node_type})")
+         else:
+             print(f"{prefix}└── {node_symbol} {node_data.name} ({node_data.node_type})")
+
+         # Get children via 'contains' edges
+         children = [
+             child for child in self.graph.successors(start_node_id)
+             if self.graph.edges[start_node_id, child].get('relation') == 'contains'
+         ]
+
+         child_count = len(children)
+         for i, child_id in enumerate(children):
+             is_last = i == child_count - 1
+             new_prefix = prefix + ("    " if is_last else "β”‚   ")
+             self.print_tree(max_depth, start_node_id=child_id, level=level + 1, prefix=new_prefix)
+
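+     # Example output shape (illustrative):
+     #     πŸ“ root (root)
+     #     └── πŸ“‚ src (directory)
+     #         └── πŸ“„ main.py (file)
+     #             └── πŸ“ main.py_0 (chunk)
+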
+     def to_dict(self):
+         self.logger.info("Serializing graph to dictionary.")
+         graph_data = {
+             'nodes': [],
+             'edges': []
+         }
+
+         for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes"):
+             if 'data' not in node_attrs:
+                 self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping in serialization")
+                 continue
+
+             node = node_attrs['data']
+             node_dict = {
+                 'id': node.id or node_id,
+                 'class': node.__class__.__name__,
+                 'data': {
+                     'id': node.id or node_id,
+                     'name': node.name,
+                     'node_type': node.node_type,
+                     'description': getattr(node, 'description', ''),
+                     'declared_entities': list(getattr(node, 'declared_entities', [])),
+                     'called_entities': list(getattr(node, 'called_entities', [])),
+                 }
+             }
+
+             # FileNode-specific
+             if isinstance(node, FileNode):
+                 node_dict['data']['path'] = node.path
+                 node_dict['data']['content'] = node.content
+                 node_dict['data']['language'] = getattr(node, 'language', '')
+
+             # ChunkNode-specific
+             if isinstance(node, ChunkNode):
+                 node_dict['data']['order_in_file'] = getattr(node, 'order_in_file', 0)
+                 node_dict['data']['embedding'] = getattr(node, 'embedding', None)
+
+             # EntityNode-specific
+             if isinstance(node, EntityNode):
+                 node_dict['data']['entity_type'] = getattr(node, 'entity_type', '')
+                 node_dict['data']['declaring_chunk_ids'] = list(getattr(node, 'declaring_chunk_ids', []))
+                 node_dict['data']['calling_chunk_ids'] = list(getattr(node, 'calling_chunk_ids', []))
+                 node_dict['data']['aliases'] = list(getattr(node, 'aliases', []))
+
+             graph_data['nodes'].append(node_dict)
+
+         for u, v, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges"):
+             edge_data = {
+                 'source': u,
+                 'target': v,
+                 'relation': attrs.get('relation', '')
+             }
+             if 'entities' in attrs:
+                 edge_data['entities'] = list(attrs['entities'])
+             graph_data['edges'].append(edge_data)
+
+         self.logger.info("Serialization complete.")
+         return graph_data
+
+     @classmethod
+     def from_dict(cls, data_dict, index_nodes: bool = True, use_embed: bool = True,
+                   model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
+         """
+         Reconstruct a RepoKnowledgeGraph from a dictionary produced by `to_dict()`.
+         """
+         if model_service_kwargs is None:
+             model_service_kwargs = {}
+         instance = cls.__new__(cls)  # bypass __init__
+         instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
+
+         instance.logger.info("Deserializing graph from dictionary.")
+
+         node_classes = {
+             'Node': Node,
+             'FileNode': FileNode,
+             'ChunkNode': ChunkNode,
+             'DirectoryNode': DirectoryNode,
+             'EntityNode': EntityNode,
+         }
+
+         # Create a root node if not present in the data
+         root_found = any(node_data['id'] == 'root' for node_data in data_dict['nodes'])
+         if not root_found:
+             instance.logger.warning("Root node not found in the data, creating one")
+             root_node = Node(id='root', name='root', node_type='root')
+             instance.graph.add_node('root', data=root_node, level=0)
+
+         # --- Rebuild nodes ---
+         for node_data in tqdm.tqdm(data_dict['nodes'], desc="Rebuilding nodes"):
+             cls_name = node_data['class']
+             node_cls = node_classes.get(cls_name, Node)
+             kwargs = node_data['data']
+
+             # Ensure ID is properly set
+             if not kwargs.get('id'):
+                 kwargs['id'] = node_data['id']
+
+             # Always use lists for declared_entities and called_entities
+             kwargs['declared_entities'] = list(kwargs.get('declared_entities', []))
+             kwargs['called_entities'] = list(kwargs.get('called_entities', []))
+
+             # FileNode/ChunkNode-specific
+             if node_cls in (FileNode, ChunkNode):
+                 kwargs.setdefault('path', '')
+                 kwargs.setdefault('content', '')
+                 kwargs.setdefault('language', '')
+             if node_cls == ChunkNode:
+                 kwargs.setdefault('order_in_file', 0)
+                 kwargs.setdefault('embedding', [])
+             # EntityNode-specific
+             if node_cls == EntityNode:
+                 kwargs.setdefault('entity_type', '')
+                 kwargs.setdefault('declaring_chunk_ids', [])
+                 kwargs.setdefault('calling_chunk_ids', [])
+                 kwargs.setdefault('aliases', [])
+
+             node_instance = node_cls(**kwargs)
+             instance.graph.add_node(node_data['id'], data=node_instance, level=instance._infer_level(node_instance))
+
+         # --- Rebuild edges ---
+         for edge in tqdm.tqdm(data_dict['edges'], desc="Rebuilding edges"):
+             source = edge['source']
+             target = edge['target']
+             if source in instance.graph and target in instance.graph:
+                 edge_kwargs = {'relation': edge.get('relation', '')}
+                 if 'entities' in edge:
+                     edge_kwargs['entities'] = list(edge['entities'])
+                 instance.graph.add_edge(source, target, **edge_kwargs)
+             else:
+                 instance.logger.warning(f"Cannot add edge {source} -> {target}, nodes don't exist")
+
+         # --- Rebuild instance.entities ---
+         instance.entities = {}
+         for node_id, node_attrs in tqdm.tqdm(instance.graph.nodes(data=True), desc="Rebuilding entities"):
+             node = node_attrs['data']
+             declared_entities = getattr(node, 'declared_entities', [])
+             called_entities = getattr(node, 'called_entities', [])
+             for entity in declared_entities:
+                 if isinstance(entity, dict):
+                     name = entity.get('name')
+                 else:
+                     name = entity
+                 if not name:
+                     continue
+                 if name not in instance.entities:
+                     instance.entities[name] = {
+                         "declaring_chunk_ids": [],
+                         "calling_chunk_ids": [],
+                         "type": [],
+                         "dtype": None
+                     }
+                 # Only add node_id if it is a ChunkNode
+                 if node_id not in instance.entities[name]["declaring_chunk_ids"]:
+                     if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
+                         instance.entities[name]["declaring_chunk_ids"].append(node_id)
+                 if isinstance(entity, dict):
+                     entity_type = entity.get("type")
+                     if entity_type and entity_type not in instance.entities[name]["type"]:
+                         instance.entities[name]["type"].append(entity_type)
+                     dtype = entity.get("dtype")
+                     if dtype:
+                         instance.entities[name]["dtype"] = dtype
+             for called_name in called_entities:
+                 if not called_name:
+                     continue
+                 if called_name not in instance.entities:
+                     instance.entities[called_name] = {
+                         "declaring_chunk_ids": [],
+                         "calling_chunk_ids": [],
+                         "type": [],
+                         "dtype": None
+                     }
+                 if node_id not in instance.entities[called_name]["calling_chunk_ids"]:
+                     if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
+                         instance.entities[called_name]["calling_chunk_ids"].append(node_id)
+
+         if index_nodes:
+             instance.logger.info("Building code index after deserialization.")
+             # Merge use_embed with code_index_kwargs, avoiding duplicate keyword arguments
+             code_idx_kwargs = code_index_kwargs or {}
+             if 'use_embed' not in code_idx_kwargs:
+                 code_idx_kwargs['use_embed'] = use_embed
+             instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **code_idx_kwargs)
+
+         instance.logger.info("Deserialization complete.")
+         return instance
+
+     def _infer_level(self, node):
+         """Infer the level of a node based on its type."""
+         if node.node_type == 'root':
+             return 0
+         elif node.node_type in ('file', 'directory'):
+             return 1
+         elif node.node_type == 'chunk':
+             return 2
+         elif isinstance(node, EntityNode):
+             return 3  # Entity nodes are added at level 3 in _build_relationships
+         return 1  # Default level
+
+     def save_graph_to_file(self, filepath: str):
+         self.logger.info(f"Saving graph to file: {filepath}")
+         with open(filepath, 'w') as f:
+             json.dump(self.to_dict(), f, indent=2)
+         self.logger.info("Graph saved successfully.")
+
+     @classmethod
+     def load_graph_from_file(cls, filepath: str, index_nodes=True, use_embed: bool = True,
+                              model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
+         if model_service_kwargs is None:
+             model_service_kwargs = {}
+         with open(filepath, 'r') as f:
+             data = json.load(f)
+         logging.getLogger(LOGGER_NAME).info(f"Loaded graph data from file: {filepath}")
+         return cls.from_dict(data, use_embed=use_embed, index_nodes=index_nodes,
+                              model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
+
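+     # Round-trip sketch (hedged; the filename is illustrative):
+     #
+     #     kg.save_graph_to_file('graph.json')
+     #     kg2 = RepoKnowledgeGraph.load_graph_from_file('graph.json', index_nodes=False)
+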
+     def to_hf_dataset(
+             self,
+             repo_id: str,
+             save_embeddings: bool = True,
+             private: bool = False,
+             token: Optional[str] = None,
+             commit_message: Optional[str] = None,
+     ):
+         """
+         Save the knowledge graph to a HuggingFace dataset on the Hub.
+
+         The graph is serialized into two configs:
+         - 'nodes': Contains all node data
+         - 'edges': Contains all edge relationships
+
+         Args:
+             repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
+             save_embeddings (bool): If True, saves embedding vectors for chunk nodes.
+                 If False, embeddings are excluded to reduce dataset size.
+             private (bool): Whether the dataset should be private. Defaults to False.
+             token (str, optional): HuggingFace API token. If not provided, uses the token
+                 from huggingface_hub login or the HF_TOKEN environment variable.
+             commit_message (str, optional): Custom commit message for the upload.
+
+         Returns:
+             str: URL of the uploaded dataset
+         """
+         try:
+             from datasets import Dataset, DatasetDict
+             from huggingface_hub import HfApi
+         except ImportError:
+             raise ImportError(
+                 "huggingface_hub and datasets are required for HuggingFace integration. "
+                 "Install them with: pip install huggingface_hub datasets"
+             )
+
+         self.logger.info(f"Preparing to save knowledge graph to HuggingFace dataset: {repo_id}")
+         self.logger.info(f"save_embeddings={save_embeddings}")
+
+         # Serialize nodes
+         nodes_data = []
+         for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes for HF dataset"):
+             if 'data' not in node_attrs:
+                 self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping")
+                 continue
+
+             node = node_attrs['data']
+             node_record = {
+                 'node_id': node.id or node_id,
+                 'node_class': node.__class__.__name__,
+                 'name': node.name,
+                 'node_type': node.node_type,
+                 'description': getattr(node, 'description', '') or '',
+                 'declared_entities': json.dumps(list(getattr(node, 'declared_entities', []))),
+                 'called_entities': json.dumps(list(getattr(node, 'called_entities', []))),
+             }
+
+             # FileNode-specific fields
+             if isinstance(node, FileNode):
+                 node_record['path'] = node.path
+                 node_record['content'] = node.content
+                 node_record['language'] = getattr(node, 'language', '')
+             else:
+                 node_record['path'] = ''
+                 node_record['content'] = ''
+                 node_record['language'] = ''
+
+             # ChunkNode-specific fields
+             if isinstance(node, ChunkNode):
+                 node_record['order_in_file'] = getattr(node, 'order_in_file', 0)
+                 if save_embeddings:
+                     embedding = getattr(node, 'embedding', None)
+                     node_record['embedding'] = json.dumps(embedding if embedding is not None else [])
+                 else:
+                     node_record['embedding'] = json.dumps([])
+             else:
+                 node_record['order_in_file'] = -1
+                 node_record['embedding'] = json.dumps([])
+
+             # EntityNode-specific fields
+             if isinstance(node, EntityNode):
+                 node_record['entity_type'] = getattr(node, 'entity_type', '')
+                 node_record['declaring_chunk_ids'] = json.dumps(list(getattr(node, 'declaring_chunk_ids', [])))
+                 node_record['calling_chunk_ids'] = json.dumps(list(getattr(node, 'calling_chunk_ids', [])))
+                 node_record['aliases'] = json.dumps(list(getattr(node, 'aliases', [])))
+             else:
+                 node_record['entity_type'] = ''
+                 node_record['declaring_chunk_ids'] = json.dumps([])
+                 node_record['calling_chunk_ids'] = json.dumps([])
+                 node_record['aliases'] = json.dumps([])
+
+             nodes_data.append(node_record)
+
+         # Serialize edges
+         edges_data = []
+         for source, target, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges for HF dataset"):
+             edge_record = {
+                 'source': source,
+                 'target': target,
+                 'relation': attrs.get('relation', ''),
+                 'entities': json.dumps(list(attrs.get('entities', []))) if 'entities' in attrs else json.dumps([])
+             }
+             edges_data.append(edge_record)
+
+         # Create datasets
+         nodes_dataset = Dataset.from_list(nodes_data)
+         edges_dataset = Dataset.from_list(edges_data)
+
+         self.logger.info(f"Created dataset with {len(nodes_data)} nodes and {len(edges_data)} edges")
+
+         # Push to Hub - nodes and edges are pushed separately as different configs
+         # because they have different schemas
+         if commit_message is None:
+             base_commit_message = f"Upload knowledge graph ({len(nodes_data)} nodes, {len(edges_data)} edges)"
+             if not save_embeddings:
+                 base_commit_message += " [embeddings excluded]"
+         else:
+             base_commit_message = commit_message
+
+         self.logger.info(f"Pushing nodes dataset to HuggingFace Hub: {repo_id}")
+         nodes_dataset.push_to_hub(
+             repo_id=repo_id,
+             config_name="nodes",
+             private=private,
+             token=token,
+             commit_message=f"{base_commit_message} - nodes"
+         )
+
+         self.logger.info(f"Pushing edges dataset to HuggingFace Hub: {repo_id}")
+         edges_dataset.push_to_hub(
+             repo_id=repo_id,
+             config_name="edges",
+             private=private,
+             token=token,
+             commit_message=f"{base_commit_message} - edges"
+         )
+
+         url = f"https://huggingface.co/datasets/{repo_id}"
+         self.logger.info(f"Dataset successfully uploaded to: {url}")
+         return url
+
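+     # Usage sketch (hedged; the repo id is a placeholder):
+     #
+     #     url = kg.to_hf_dataset('username/my-repo-graph', save_embeddings=False,
+     #                            private=True)
+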
+     @classmethod
+     def from_hf_dataset(
+             cls,
+             repo_id: str,
+             index_nodes: bool = True,
+             use_embed: bool = True,
+             model_service_kwargs: Optional[dict] = None,
+             code_index_kwargs: Optional[dict] = None,
+             token: Optional[str] = None,
+             revision: Optional[str] = None,
+     ):
+         """
+         Load a knowledge graph from a HuggingFace dataset on the Hub.
+
+         Args:
+             repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
+             index_nodes (bool): Whether to build a code index after loading. Defaults to True.
+             use_embed (bool): Whether to use existing embeddings from the dataset. Defaults to True.
+             model_service_kwargs (dict, optional): Arguments for the model service.
+             code_index_kwargs (dict, optional): Arguments for the code index.
+             token (str, optional): HuggingFace API token for private datasets.
+             revision (str, optional): Git revision (branch, tag, or commit) to load from.
+
+         Returns:
+             RepoKnowledgeGraph: The loaded knowledge graph instance.
+         """
+         try:
+             from datasets import load_dataset
+         except ImportError:
+             raise ImportError(
+                 "The datasets library is required for HuggingFace integration. "
+                 "Install it with: pip install datasets"
+             )
+
+         if model_service_kwargs is None:
+             model_service_kwargs = {}
+
+         logger = logging.getLogger(LOGGER_NAME)
+         logger.info(f"Loading knowledge graph from HuggingFace dataset: {repo_id}")
+
+         # Load dataset from Hub - nodes and edges are stored as separate configs
+         logger.info("Loading nodes config...")
+         nodes_dataset = load_dataset(repo_id, name="nodes", token=token, revision=revision)
+         logger.info("Loading edges config...")
+         edges_dataset = load_dataset(repo_id, name="edges", token=token, revision=revision)
+
+         # Get the train split (default split when pushing with config_name)
+         nodes_data = nodes_dataset['train']
+         edges_data = edges_dataset['train']
+
+         logger.info(f"Loaded {len(nodes_data)} nodes and {len(edges_data)} edges from dataset")
+
+         # Convert to the dict format expected by from_dict
+         graph_data = {
+             'nodes': [],
+             'edges': []
+         }
+
+         # Reconstruct nodes
+         for record in tqdm.tqdm(nodes_data, desc="Reconstructing nodes from HF dataset"):
+             node_dict = {
+                 'id': record['node_id'],
+                 'class': record['node_class'],
+                 'data': {
+                     'id': record['node_id'],
+                     'name': record['name'],
+                     'node_type': record['node_type'],
+                     'description': record['description'],
+                     'declared_entities': json.loads(record['declared_entities']),
+                     'called_entities': json.loads(record['called_entities']),
+                 }
+             }
+
+             # FileNode/ChunkNode-specific fields
+             if record['node_class'] in ('FileNode', 'ChunkNode'):
+                 node_dict['data']['path'] = record['path']
+                 node_dict['data']['content'] = record['content']
+                 node_dict['data']['language'] = record['language']
+
+             # ChunkNode-specific fields
+             if record['node_class'] == 'ChunkNode':
+                 node_dict['data']['order_in_file'] = record['order_in_file']
+                 embedding = json.loads(record['embedding'])
+                 # Only use the embedding if use_embed is True and the embedding is non-empty
+                 if use_embed and embedding:
+                     node_dict['data']['embedding'] = embedding
+                 else:
+                     node_dict['data']['embedding'] = []
+
+             # EntityNode-specific fields
+             if record['node_class'] == 'EntityNode':
+                 node_dict['data']['entity_type'] = record['entity_type']
+                 node_dict['data']['declaring_chunk_ids'] = json.loads(record['declaring_chunk_ids'])
+                 node_dict['data']['calling_chunk_ids'] = json.loads(record['calling_chunk_ids'])
+                 node_dict['data']['aliases'] = json.loads(record['aliases'])
+
+             graph_data['nodes'].append(node_dict)
+
+         # Reconstruct edges
+         for record in tqdm.tqdm(edges_data, desc="Reconstructing edges from HF dataset"):
+             edge_dict = {
+                 'source': record['source'],
+                 'target': record['target'],
+                 'relation': record['relation'],
+             }
+             entities = json.loads(record['entities'])
+             if entities:
+                 edge_dict['entities'] = entities
+
+             graph_data['edges'].append(edge_dict)
+
+         logger.info("Dataset reconstruction complete, building graph...")
+
+         # Use from_dict to build the graph
+         return cls.from_dict(
+             graph_data,
+             index_nodes=index_nodes,
+             use_embed=use_embed,
+             model_service_kwargs=model_service_kwargs,
+             code_index_kwargs=code_index_kwargs
+         )
+
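+     # Usage sketch (hedged; the repo id is a placeholder):
+     #
+     #     kg = RepoKnowledgeGraph.from_hf_dataset('username/my-repo-graph',
+     #                                             index_nodes=False)
+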
+     def get_neighbors(self, node_id):
+         self.logger.debug(f"Getting neighbors for node: {node_id}")
+         # Return all nodes directly connected to node_id, for any edge type.
+         # In a DiGraph, successors and predecessors together cover every incident edge.
+         neighbors = set(self.graph.successors(node_id)) | set(self.graph.predecessors(node_id))
+         return [self.graph.nodes[n]['data'] for n in neighbors if 'data' in self.graph.nodes[n]]
+
+     def get_previous_chunk(self, node_id: str) -> Optional[ChunkNode]:
+         self.logger.debug(f"Getting previous chunk for node: {node_id}")
+         node = self[node_id]
+         # Only ChunkNodes have an order within a file
+         if not isinstance(node, ChunkNode):
+             raise Exception(f'Cannot get previous chunk on node of type {type(node)}')
+
+         if node.order_in_file == 0:
+             self.logger.warning(f'Cannot get previous chunk for first chunk {node_id}')
+             return None
+
+         file_path = node.path
+         previous_chunk_id = f'{file_path}_{node.order_in_file - 1}'
+
+         if previous_chunk_id not in self.graph:
+             raise Exception(f'Previous chunk {previous_chunk_id} not found in graph')
+
+         return self[previous_chunk_id]
+
+     def get_next_chunk(self, node_id: str) -> Optional[ChunkNode]:
+         self.logger.debug(f"Getting next chunk for node: {node_id}")
+         node = self[node_id]
+         # Only ChunkNodes have an order within a file
+         if not isinstance(node, ChunkNode):
+             raise Exception(f'Cannot get next chunk on node of type {type(node)}')
+
+         file_path = node.path
+         next_chunk_id = f'{file_path}_{node.order_in_file + 1}'
+
+         if next_chunk_id not in self.graph:
+             self.logger.warning(f'Next chunk {next_chunk_id} not found in graph, it might be the last chunk')
+             return None
+         return self[next_chunk_id]
+
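+     # Navigation sketch (hedged; the chunk ids are hypothetical):
+     #
+     #     nxt = kg.get_next_chunk('src/main.py_0')       # ChunkNode, or None at end of file
+     #     prev = kg.get_previous_chunk('src/main.py_1')  # ChunkNode, or None for the first chunk
+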
+     def get_all_chunks(self) -> List[ChunkNode]:
+         self.logger.debug("Getting all chunk nodes.")
+         chunk_nodes = []
+         for node in self:
+             if isinstance(node, ChunkNode):
+                 chunk_nodes.append(node)
+         return chunk_nodes
+
+     def get_all_files(self) -> List[FileNode]:
+         """
+         Get all FileNodes in the knowledge graph.
+
+         Returns:
+             List[FileNode]: A list of FileNodes in the graph.
+         """
+         self.logger.debug("Getting all file nodes.")
+         file_nodes = []
+         for _, node_attrs in self.graph.nodes(data=True):
+             node_data = node_attrs.get('data')
+             # Exclude ChunkNodes (which inherit from FileNode) by also checking node_type
+             if isinstance(node_data, FileNode) and node_data.node_type == 'file':
+                 file_nodes.append(node_data)
+         return file_nodes
+
+     def get_chunks_of_file(self, file_node_id: str) -> List[ChunkNode]:
+         """
+         Get all ChunkNodes associated with a specific FileNode.
+
+         Args:
+             file_node_id (str): The ID of the file node to get chunks for.
+
+         Returns:
+             List[ChunkNode]: A list of ChunkNodes associated with the file.
+         """
+         self.logger.debug(f"Getting chunks for file node: {file_node_id}")
+         chunk_nodes = []
+         for node in self.graph.neighbors(file_node_id):
+             # Only include ChunkNodes that are connected by a 'contains' edge
+             edge_data = self.graph.get_edge_data(file_node_id, node)
+             node_data = self.graph.nodes[node]['data']
+             if (
+                     isinstance(node_data, ChunkNode)
+                     and node_data.node_type == 'chunk'
+                     and edge_data is not None
+                     and edge_data.get('relation') == 'contains'
+             ):
+                 chunk_nodes.append(node_data)
+         return chunk_nodes
+
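+     # Usage sketch (hedged; the file id is hypothetical):
+     #
+     #     chunks = kg.get_chunks_of_file('src/main.py')
+     #     chunks.sort(key=lambda c: c.order_in_file)
+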
+     def find_path(self, source_id: str, target_id: str, max_depth: int = 5) -> dict:
+         """
+         Find the shortest path between two nodes in the knowledge graph.
+
+         Args:
+             source_id (str): The ID of the source node.
+             target_id (str): The ID of the target node.
+             max_depth (int): Maximum depth to search for a path. Defaults to 5.
+
+         Returns:
+             dict: A dictionary containing path information or an error message.
+         """
+         self.logger.debug(f"Finding path from {source_id} to {target_id} with max_depth={max_depth}")
+         g = self.graph
+
+         if source_id not in g:
+             return {"error": f"Source node '{source_id}' not found."}
+         if target_id not in g:
+             return {"error": f"Target node '{target_id}' not found."}
+
+         try:
+             path = nx.shortest_path(g, source=source_id, target=target_id)
+
+             if len(path) - 1 > max_depth:
+                 return {
+                     "source_id": source_id,
+                     "target_id": target_id,
+                     "path": [],
+                     "length": len(path) - 1,
+                     "text": f"Path exists but exceeds max_depth of {max_depth} (actual length: {len(path) - 1})"
+                 }
+
+             # Build detailed path information
+             path_details = []
+             for i, node_id in enumerate(path):
+                 node = g.nodes[node_id]['data']
+                 node_info = {
+                     "node_id": node_id,
+                     "name": getattr(node, 'name', 'Unknown'),
+                     "type": getattr(node, 'node_type', 'Unknown'),
+                     "step": i
+                 }
+
+                 # Add edge information for all but the last node
+                 if i < len(path) - 1:
+                     next_node_id = path[i + 1]
+                     edge_data = g.get_edge_data(node_id, next_node_id)
+                     node_info["edge_to_next"] = edge_data.get('relation', 'Unknown') if edge_data else 'Unknown'
+
+                 path_details.append(node_info)
+
+             # Format text output
+             text = f"Path from '{source_id}' to '{target_id}' (length: {len(path) - 1}):\n\n"
+             for i, node_info in enumerate(path_details):
+                 text += f"{i}. {node_info['name']} ({node_info['type']})\n"
+                 text += f"   Node ID: {node_info['node_id']}\n"
+                 if 'edge_to_next' in node_info:
+                     text += f"   --[{node_info['edge_to_next']}]--> \n"
+
+             return {
+                 "source_id": source_id,
+                 "target_id": target_id,
+                 "path": path_details,
+                 "length": len(path) - 1,
+                 "text": text
+             }
+
+         except nx.NetworkXNoPath:
+             return {
+                 "source_id": source_id,
+                 "target_id": target_id,
+                 "path": [],
+                 "length": -1,
+                 "text": f"No path found between '{source_id}' and '{target_id}'"
+             }
+         except Exception as e:
+             self.logger.error(f"Error finding path: {str(e)}")
+             return {"error": f"Error finding path: {str(e)}"}
+
+ def get_subgraph(self, node_id: str, depth: int = 2, edge_types: Optional[List[str]] = None) -> dict:
1497
+ """
1498
+ Extract a subgraph around a node up to a specified depth.
1499
+
1500
+ Args:
1501
+ node_id (str): The ID of the central node.
1502
+ depth (int): The depth/radius of the subgraph to extract. Defaults to 2.
1503
+ edge_types (Optional[List[str]]): Optional list of edge types to include (e.g., ['calls', 'contains']).
1504
+
1505
+ Returns:
1506
+ dict: A dictionary containing subgraph information or error message.
1507
+ """
1508
+ self.logger.debug(f"Getting subgraph for node {node_id} with depth={depth}, edge_types={edge_types}")
1509
+ g = self.graph
1510
+
1511
+ if node_id not in g:
1512
+ return {"error": f"Node '{node_id}' not found."}
1513
+
1514
+ # Collect nodes within specified depth
1515
+ nodes_at_depth = {node_id}
1516
+ all_nodes = {node_id}
1517
+
1518
+ for d in range(depth):
1519
+ next_level = set()
1520
+ for n in nodes_at_depth:
1521
+ # Get all neighbors (both incoming and outgoing)
1522
+ for neighbor in g.successors(n):
1523
+ if edge_types is None:
1524
+ next_level.add(neighbor)
1525
+ else:
1526
+ edge_data = g.get_edge_data(n, neighbor)
1527
+ if edge_data and edge_data.get('relation') in edge_types:
1528
+ next_level.add(neighbor)
1529
+
1530
+ for neighbor in g.predecessors(n):
1531
+ if edge_types is None:
1532
+ next_level.add(neighbor)
1533
+ else:
1534
+ edge_data = g.get_edge_data(neighbor, n)
1535
+ if edge_data and edge_data.get('relation') in edge_types:
1536
+ next_level.add(neighbor)
1537
+
1538
+ nodes_at_depth = next_level - all_nodes
1539
+ all_nodes.update(next_level)
1540
+
1541
+ # Extract subgraph
1542
+ subgraph = g.subgraph(all_nodes).copy()
1543
+
1544
+ # Build node details
1545
+ nodes = []
1546
+ for n in subgraph.nodes():
1547
+ node = subgraph.nodes[n]['data']
1548
+ nodes.append({
1549
+ "node_id": n,
1550
+ "name": getattr(node, 'name', 'Unknown'),
1551
+ "type": getattr(node, 'node_type', 'Unknown')
1552
+ })
1553
+
1554
+ # Build edge details
1555
+ edges = []
1556
+ for source, target, data in subgraph.edges(data=True):
1557
+ edges.append({
1558
+ "source": source,
1559
+ "target": target,
1560
+ "relation": data.get('relation', 'Unknown')
1561
+ })
1562
+
1563
+ # Format text output
1564
+ text = f"Subgraph around '{node_id}' (depth: {depth}):\n"
1565
+ if edge_types:
1566
+ text += f"Edge types filter: {', '.join(edge_types)}\n"
1567
+ text += f"\nNodes: {len(nodes)}\n"
1568
+ text += f"Edges: {len(edges)}\n\n"
1569
+
1570
+ # Group nodes by type
1571
+ nodes_by_type = {}
1572
+ for node in nodes:
1573
+ node_type = node['type']
1574
+ if node_type not in nodes_by_type:
1575
+ nodes_by_type[node_type] = []
1576
+ nodes_by_type[node_type].append(node)
1577
+
1578
+ for node_type, type_nodes in nodes_by_type.items():
1579
+ text += f"{node_type} ({len(type_nodes)}):\n"
1580
+ for node in type_nodes[:5]:
1581
+ text += f" - {node['name']} ({node['node_id']})\n"
1582
+ if len(type_nodes) > 5:
1583
+ text += f" ... and {len(type_nodes) - 5} more\n"
1584
+ text += "\n"
1585
+
1586
+ # Show edge statistics
1587
+ edge_by_relation = {}
1588
+ for edge in edges:
1589
+ relation = edge['relation']
1590
+ edge_by_relation[relation] = edge_by_relation.get(relation, 0) + 1
1591
+
1592
+ if edge_by_relation:
1593
+ text += "Edge types:\n"
1594
+ for relation, count in edge_by_relation.items():
1595
+ text += f" - {relation}: {count}\n"
1596
+
1597
+ return {
1598
+ "center_node_id": node_id,
1599
+ "depth": depth,
1600
+ "edge_types_filter": edge_types,
1601
+ "node_count": len(nodes),
1602
+ "edge_count": len(edges),
1603
+ "nodes": nodes,
1604
+ "edges": edges,
1605
+ "nodes_by_type": nodes_by_type,
1606
+ "edge_by_relation": edge_by_relation,
1607
+ "text": text
1608
+ }
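A minimal usage sketch for the traversal methods above. The `RepoKnowledgeGraph` class name, its constructor, and the node ids are assumptions for illustration (chunk ids follow the `{file_path}_{order}` scheme used in the code above); only `find_path` and `get_subgraph` mirror the methods defined in this file.

    # Hypothetical sketch: class name, constructor, and node ids are assumed
    kg = RepoKnowledgeGraph('./my_repo')

    # Shortest path between two nodes, capped at 4 hops
    result = kg.find_path('src/app.py', 'src/utils.py_3', max_depth=4)
    if 'error' not in result:
        print(result['text'])  # human-readable hop-by-hop path

    # 2-hop neighborhood around a node, restricted to 'contains' edges
    sub = kg.get_subgraph('src/app.py', depth=2, edge_types=['contains'])
    print(sub['node_count'], sub['edge_count'])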
RepoKnowledgeGraphLib/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """
+ RepoKnowledgeGraphLib - Knowledge Graph Library for Code Repositories
+
+ This library provides tools for creating and querying knowledge graphs from code repositories.
+ """
RepoKnowledgeGraphLib/utils/__init__.py ADDED
File without changes
RepoKnowledgeGraphLib/utils/chunk_utils.py ADDED
@@ -0,0 +1,88 @@
+ from ..Node import ChunkNode
+ from typing import List, Dict
+
+ def dict_to_chunknode(d: dict) -> ChunkNode:
+     """
+     Converts a dictionary to a ChunkNode instance.
+     """
+     return ChunkNode(**d)
+
+ def extract_filename_from_chunk(chunk: ChunkNode) -> str:
+     """
+     Extracts the file name from a chunk.
+
+     Args:
+         chunk (ChunkNode): The chunk from which to extract the file name (a dict is converted first).
+
+     Returns:
+         str: The extracted file name.
+     """
+     if isinstance(chunk, dict):
+         chunk = dict_to_chunknode(chunk)
+     return '_'.join(chunk.id.split('_')[:-1])
+
+
+ def order_chunks_by_order_in_file(chunks: List[ChunkNode]) -> list:
+     """
+     Orders a list of chunks by their order in the file.
+
+     Args:
+         chunks (list): The list of chunks to order.
+
+     Returns:
+         list: The ordered list of chunks.
+     """
+     # Convert dicts to ChunkNode if needed
+     chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in chunks]
+     return sorted(chunks, key=lambda x: int(x.order_in_file))
+
+ def organize_chunks_by_file_name(chunks: List[ChunkNode]) -> Dict[str, List[ChunkNode]]:
+     """
+     Organizes a list of chunks by their file names.
+
+     Args:
+         chunks (list): The list of chunks to organize.
+
+     Returns:
+         dict: A dictionary mapping file names to lists of chunks, ordered by position in the file.
+     """
+     # Convert dicts to ChunkNode if needed
+     chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in chunks]
+     organized_chunks = {}
+     for chunk in chunks:
+         file_name = extract_filename_from_chunk(chunk)
+         if file_name not in organized_chunks:
+             organized_chunks[file_name] = []
+         organized_chunks[file_name].append(chunk)
+     for file_name in organized_chunks:
+         organized_chunks[file_name] = order_chunks_by_order_in_file(organized_chunks[file_name])
+     return organized_chunks
+
+ def join_organized_chunks(organized_chunks: Dict[str, List[ChunkNode]]) -> str:
+     """
+     Joins organized chunks into a single string, inserting "[...]" markers where chunks are missing.
+
+     Args:
+         organized_chunks (dict): The dictionary of organized chunks.
+
+     Returns:
+         str: The joined string of organized chunks.
+     """
+     joined_chunks_list = []
+     separator = "=" * 48
+     for filename in organized_chunks:
+         # Convert dicts to ChunkNode if needed; skip empty files before emitting their header
+         chunks = [dict_to_chunknode(c) if isinstance(c, dict) else c for c in organized_chunks[filename]]
+         if len(chunks) == 0:
+             continue
+         joined_chunks_list.append(separator)
+         joined_chunks_list.append(f"File: {filename}")
+         joined_chunks_list.append(separator)
+         if int(chunks[0].order_in_file) > 0:
+             joined_chunks_list.append("\n[...]")
+         for i, chunk in enumerate(chunks):
+             joined_chunks_list.append(chunk.content)
+             if i < len(chunks) - 1:
+                 if int(chunks[i+1].order_in_file) - int(chunk.order_in_file) > 1:
+                     joined_chunks_list.append("\n[...]")
+     return "\n".join(joined_chunks_list)
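A short sketch of the chunk helpers above. It assumes `ChunkNode` accepts at least `id`, `order_in_file`, and `content` as keyword arguments; the real constructor lives in RepoKnowledgeGraphLib/Node.py, so treat these fields as illustrative.

    chunks = [
        {'id': 'src/app.py_1', 'order_in_file': 1, 'content': 'def main(): ...'},
        {'id': 'src/app.py_0', 'order_in_file': 0, 'content': 'import os'},
    ]
    organized = organize_chunks_by_file_name(chunks)  # {'src/app.py': [chunk 0, chunk 1]}
    print(join_organized_chunks(organized))
    # ================================================
    # File: src/app.py
    # ================================================
    # import os
    # def main(): ...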
RepoKnowledgeGraphLib/utils/data_utils.py ADDED
@@ -0,0 +1,18 @@
+
+ def flatten_list(my_list: list) -> list:
+     """
+     Args:
+         my_list: list composed of lists (of lists of lists...)
+
+     Returns: flattened list
+     """
+     flattened_list = []
+     for item in my_list:
+         if isinstance(item, list) and len(item) > 0:
+             flattened_list += flatten_list(item)
+         elif not isinstance(item, list):
+             flattened_list.append(item)
+
+     return flattened_list
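Note that empty sublists fall through both branches above and are silently dropped. For example:

    flatten_list([1, [2, [3, 4]], [], 'a'])  # -> [1, 2, 3, 4, 'a']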
RepoKnowledgeGraphLib/utils/logger_utils.py ADDED
@@ -0,0 +1,74 @@
+ import logging
+ import os
+ import sys
+ import atexit
+ from typing import Optional
+
+ # Global registry to track initialized loggers
+ _initialized_loggers = set()
+
+ # Get log level from environment variable (default to INFO for visibility in docker logs)
+ DEFAULT_LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
+ LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'false').lower() == 'true'
+
+ def setup_logger(logger_name: str, log_file: str = '',
+                  level: Optional[int] = None) -> None:
+     """
+     :param logger_name: name to give to the logger
+     :param log_file: file to save the log to
+     :param level: base importance level to set the logger to (defaults to the LOG_LEVEL env var)
+     :return: *None*
+     """
+     # Skip loggers that have already been set up
+     if logger_name in _initialized_loggers:
+         return
+
+     log = logging.getLogger(logger_name)
+
+     # Determine log level from parameter, env var, or default
+     if level is None:
+         level = getattr(logging, DEFAULT_LOG_LEVEL, logging.INFO)
+
+     formatter = logging.Formatter(
+         fmt="%(name)s - %(levelname)s: %(asctime)-15s %(message)s")
+
+     # Always add a stream handler for stdout (docker logs visibility)
+     stream_handler = logging.StreamHandler(sys.stdout)
+     stream_handler.setFormatter(formatter)
+     stream_handler.setLevel(level)
+
+     log.setLevel(level)
+     if not log.hasHandlers():
+         log.addHandler(stream_handler)
+
+     # Optionally add a file handler if LOG_TO_FILE is enabled
+     if LOG_TO_FILE:
+         os.makedirs('logs', exist_ok=True)
+         if log_file == '':
+             log_file = f"{logger_name}.log"
+         log_file_path = os.path.join('logs', log_file)
+         file_handler = logging.FileHandler(log_file_path, mode='w')
+         file_handler.setFormatter(formatter)
+         file_handler.setLevel(level)
+         log.addHandler(file_handler)
+
+     # Prevent log propagation to avoid duplicate logs
+     log.propagate = False
+
+     # Mark this logger as initialized
+     _initialized_loggers.add(logger_name)
+
+     # Register a cleanup function to close handlers on exit
+     atexit.register(_cleanup_logger, logger_name)
+
+ def _cleanup_logger(logger_name: str) -> None:
+     """
+     Clean up logger handlers on program exit.
+
+     :param logger_name: name of the logger to clean up
+     """
+     log = logging.getLogger(logger_name)
+     handlers = log.handlers[:]
+     for handler in handlers:
+         handler.close()
+         log.removeHandler(handler)
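Typical usage of the logger setup above; the logger name is arbitrary, and the environment settings mentioned are the ones read by this module:

    import logging
    from RepoKnowledgeGraphLib.utils.logger_utils import setup_logger

    setup_logger('MY_LOGGER')  # stdout handler; level taken from LOG_LEVEL
    logging.getLogger('MY_LOGGER').info('knowledge graph build started')
    # With LOG_TO_FILE=true in the environment, the same call also writes to logs/MY_LOGGER.log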
RepoKnowledgeGraphLib/utils/parsing_utils.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import re
+ from typing import Optional
+
+ def read_directory_files_recursively(directory_path: str, skip_dirs: list, skip_pattern: Optional[str] = None) -> dict:
+     """
+     Recursively reads all files in a directory and its subdirectories.
+     Skips files and directories that match the given regex pattern or are in skip_dirs.
+
+     Args:
+         directory_path (str): The path to start reading files from.
+         skip_dirs (list): List of directory names to skip.
+         skip_pattern (str, optional): Regex pattern for files/directories to skip.
+
+     Returns:
+         dict: A dictionary where keys are relative file paths and values are file contents.
+     """
+     file_contents = {}
+     compiled_pattern = re.compile(skip_pattern) if skip_pattern else None
+
+     for root, dirs, files in os.walk(directory_path):
+         # Skip directories listed in skip_dirs or matching the skip pattern
+         dirs[:] = [d for d in dirs if d not in skip_dirs and not (compiled_pattern and compiled_pattern.search(os.path.join(root, d)))]
+
+         for file in files:
+             full_path = os.path.join(root, file)
+             relative_path = os.path.relpath(full_path, directory_path)
+
+             # Skip matching files
+             if compiled_pattern and compiled_pattern.search(relative_path):
+                 continue
+
+             try:
+                 with open(full_path, 'r', encoding='utf-8') as f:
+                     file_contents[relative_path] = f.read()
+             except (UnicodeDecodeError, OSError) as e:
+                 # Binary or unreadable files are skipped rather than stored
+                 print(f'Failed to read {relative_path}: {e}')
+                 continue
+
+     return file_contents
+
+
+ def get_language_from_filename(file_name: str) -> str:
+     file_extension = file_name.split('.')[-1]
+     extension_mapping = {
+         'c': 'c',
+         'h': 'c',
+         'cpp': 'c++',
+         'cc': 'c++',
+         'cxx': 'c++',
+         'hpp': 'c++',
+         'hh': 'c++',
+         'hxx': 'c++',
+         'go': 'go',
+         'java': 'java',
+         'py': 'python',
+         'pyc': 'python',
+         'pyw': 'python',
+         'js': 'javascript',
+         'mjs': 'javascript',
+         'cjs': 'javascript',
+     }
+     # Falls back to the raw extension when the language is not in the mapping
+     return extension_mapping.get(file_extension, file_extension)
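A minimal sketch of the directory reader and language lookup above; the repository path and skip values are hypothetical:

    files = read_directory_files_recursively(
        './my_repo',
        skip_dirs=['.git', 'node_modules'],
        skip_pattern=r'\.(png|jpg|lock)$',
    )
    for rel_path in files:
        print(rel_path, get_language_from_filename(rel_path))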
RepoKnowledgeGraphLib/utils/path_utils.py ADDED
@@ -0,0 +1,308 @@
+ import os
+ import tempfile
+ import shutil
+ import zipfile
+ import tarfile
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+
+ def _extract_zip(path: Path) -> str:
+     temp_dir = tempfile.mkdtemp()
+     with zipfile.ZipFile(path, 'r') as zip_ref:
+         zip_ref.extractall(temp_dir)
+     return temp_dir
+
+
+ def _extract_tgz(path: Path) -> str:
+     temp_dir = tempfile.mkdtemp()
+     with tarfile.open(path, 'r:gz') as tar_ref:
+         tar_ref.extractall(temp_dir)
+     return temp_dir
+
+
+ def prepare_input_path(path: str) -> str:
+     """Handles different input types: directories, files, zip or tgz archives."""
+     path_obj = Path(path)
+     if path_obj.is_dir():
+         return str(path_obj)
+
+     if path_obj.suffix == '.zip':
+         return _extract_zip(path_obj)
+     # Path.suffix only holds the final extension ('.gz' for 'x.tar.gz'),
+     # so '.tar.gz' archives must be detected from the full file name
+     elif path_obj.suffix == '.tgz' or path_obj.name.endswith('.tar.gz'):
+         return _extract_tgz(path_obj)
+     elif path_obj.is_file():
+         # Copy single file to a temporary directory
+         temp_dir = tempfile.mkdtemp()
+         shutil.copy(path_obj, temp_dir)
+         return temp_dir
+     else:
+         raise ValueError(f"Unsupported path type or extension: {path}")
+
+
+ def file_path_to_module_path(file_path: str) -> str:
+     """
+     Convert a file path to a module path by replacing path separators with dots
+     and removing the file extension.
+
+     Examples:
+         path/to/repo/python_script.py -> path.to.repo.python_script
+         src/utils/helper.py -> src.utils.helper
+         module.py -> module
+
+     Args:
+         file_path: File path string
+
+     Returns:
+         Module path with dots instead of slashes
+     """
+     # Normalize path separators
+     normalized = file_path.replace('\\', '/').replace(os.sep, '/')
+
+     # Remove file extension
+     without_ext = os.path.splitext(normalized)[0]
+
+     # Replace / with .
+     module_path = without_ext.replace('/', '.')
+
+     return module_path
+
+
+ def generate_entity_aliases(entity_name: str, file_path: str) -> list:
+     """
+     Generate all possible aliases for an entity based on its name and file path.
+
+     For example, if a file 'path/to/repo/python_script.py' defines 'Class_1',
+     the aliases would be:
+     - Class_1 (simple name)
+     - path.to.repo.python_script.Class_1 (fully qualified from file path)
+
+     For C++ namespaced entities like 'math::Calculator':
+     - math::Calculator (fully qualified name)
+     - Calculator (unqualified name, for use with 'using namespace')
+     - math.calculator.math::Calculator (module-based fully qualified)
+
+     For temporary paths like '.tmp.tmptqky4yk4..pyinstaller.run_astropy_tests.pos':
+     - pos (simple name)
+     - .run_astropy_tests.pos (progressive path removal)
+     - pyinstaller.run_astropy_tests.pos (further removal)
+     - .tmp.tmptqky4yk4..pyinstaller.run_astropy_tests.pos (full path)
+
+     Args:
+         entity_name: The name of the entity (e.g., 'Class_1', 'my_function', 'math::Calculator')
+         file_path: The file path where the entity is defined
+
+     Returns:
+         List of alias strings
+     """
+     aliases = []
+
+     # Always include the simple entity name
+     aliases.append(entity_name)
+
+     # For C++/C-style namespaced entities (using ::), add the unqualified name
+     if '::' in entity_name:
+         # Extract the unqualified name (last part after ::)
+         unqualified_name = entity_name.split('::')[-1]
+         if unqualified_name != entity_name:
+             aliases.append(unqualified_name)
+
+     # Generate module-based alias
+     module_path = file_path_to_module_path(file_path)
+
+     # If entity_name already contains scope separators (., ::),
+     # it might be a nested entity (e.g., 'MyClass.my_method').
+     # In this case, add the module path before the entire qualified name
+     fully_qualified = f"{module_path}.{entity_name}"
+
+     # Generate progressive path aliases by removing temporary/noise components.
+     # Split the module path into components
+     components = module_path.split('.')
+
+     # Filter out components that look like temporary directories or UUIDs
+     def is_temp_component(component: str) -> bool:
+         """Check if a path component looks like a temporary directory."""
+         if not component:
+             return True
+         # Check for common temp directory patterns
+         if component.startswith('tmp') and len(component) > 3:
+             return True
+         if component.startswith('.tmp'):
+             return True
+         # Check for UUID-like patterns (long alphanumeric strings)
+         if len(component) > 8 and component.replace('_', '').replace('-', '').isalnum():
+             # If it's mostly lowercase with a mix of letters and numbers, it is likely a temp ID
+             if sum(c.islower() for c in component) > len(component) / 2:
+                 if sum(c.isdigit() for c in component) > 2:
+                     return True
+         return False
+
+     # Generate aliases by progressively including more path components,
+     # starting from the rightmost meaningful components and working backwards
+     clean_components = []
+     for component in components:
+         if not is_temp_component(component):
+             clean_components.append(component)
+
+     # Generate aliases with increasing path depth from meaningful components
+     if clean_components:
+         for i in range(1, len(clean_components) + 1):
+             # Take the last i components
+             partial_path = '.'.join(clean_components[-i:])
+             partial_alias = f".{partial_path}.{entity_name}"
+             if partial_alias != entity_name and partial_alias not in aliases:
+                 aliases.append(partial_alias)
+
+             # Also add the full clean path without a leading dot
+             if i == len(clean_components):
+                 no_dot_alias = f"{partial_path}.{entity_name}"
+                 if no_dot_alias != entity_name and no_dot_alias not in aliases:
+                     aliases.append(no_dot_alias)
+
+     # Always add the fully qualified path at the end (even if it contains temp components)
+     if fully_qualified != entity_name and fully_qualified not in aliases:
+         aliases.append(fully_qualified)
+
+     return aliases
+
+
+ def normalize_include_path(include_path: str) -> str:
+     """
+     Normalize an include path from an #include directive to a module-like path.
+
+     Examples:
+         <vector> -> vector
+         <iostream> -> iostream
+         "myheader.h" -> myheader
+         "utils/helper.h" -> utils.helper
+         <boost/algorithm/string.hpp> -> boost.algorithm.string
+
+     Args:
+         include_path: The include path from the #include directive
+
+     Returns:
+         Normalized module-like path
+     """
+     # Remove angle brackets and quotes
+     path = include_path.strip('<>"')
+
+     # Convert to module path
+     module_path = file_path_to_module_path(path)
+
+     return module_path
+
+
+ def build_entity_alias_map(entities: Dict[str, Dict]) -> Dict[str, str]:
+     """
+     Build a mapping from all entity aliases to their canonical entity names.
+     This allows quick lookup when matching called entities to their definitions.
+
+     Args:
+         entities: Dictionary of entity info keyed by canonical entity name
+
+     Returns:
+         Dictionary mapping alias -> canonical entity name
+     """
+     alias_map = {}
+
+     for entity_name, info in entities.items():
+         # Map the canonical name to itself
+         alias_map[entity_name] = entity_name
+
+         # Map all aliases to the canonical name
+         aliases = info.get('aliases', [])
+         for alias in aliases:
+             if alias and alias not in alias_map:
+                 alias_map[alias] = entity_name
+
+     return alias_map
+
+
+ def resolve_entity_call(called_name: str, alias_map: Dict[str, str],
+                         imports: Optional[List[str]] = None) -> Optional[str]:
+     """
+     Resolve a called entity name to its canonical definition using aliases.
+
+     This handles cases like:
+     - Direct call: 'MyClass' -> 'MyClass'
+     - Qualified call: 'module.MyClass' -> 'MyClass' (if alias exists)
+     - Imported call: 'helper' -> 'utils.helper' (if imported)
+     - Simple name to qualified: 'Calculator' -> 'utils::Calculator'
+
+     Args:
+         called_name: The name of the called entity
+         alias_map: Mapping from aliases to canonical entity names
+         imports: List of import paths (optional, for context)
+
+     Returns:
+         Canonical entity name if found, None otherwise
+     """
+     # Don't try to resolve empty strings
+     if not called_name or not called_name.strip():
+         return None
+
+     # Direct match
+     if called_name in alias_map:
+         return alias_map[called_name]
+
+     # Try partial matches if imports are provided
+     if imports:
+         for import_path in imports:
+             # Try combining the import path with the called name
+             qualified = f"{import_path}.{called_name}"
+             if qualified in alias_map:
+                 return alias_map[qualified]
+
+             # Try with the :: separator (C++/Rust style)
+             qualified_cpp = f"{import_path}::{called_name}"
+             if qualified_cpp in alias_map:
+                 return alias_map[qualified_cpp]
+
+     # Try fuzzy matching - look for canonical names that end with the called name.
+     # This helps match 'Calculator' to 'utils::Calculator' or 'MyClass' to 'module.MyClass'
+     simple_name = extract_simple_name(called_name)
+     candidates = []
+
+     for alias, canonical in alias_map.items():
+         alias_simple = extract_simple_name(alias)
+         # If the simple names match, this could be a match
+         if alias_simple == simple_name:
+             candidates.append(canonical)
+
+     # If we found exactly one candidate, return it
+     if len(candidates) == 1:
+         return candidates[0]
+
+     # If we have multiple candidates, prefer the shortest qualified name
+     # (most likely to be the direct definition rather than an alias)
+     if len(candidates) > 1:
+         return min(candidates, key=len)
+
+     return None
+
+
+ def extract_simple_name(qualified_name: str) -> str:
+     """
+     Extract the simple name from a qualified name.
+
+     Examples:
+         'namespace::MyClass' -> 'MyClass'
+         'module.MyClass' -> 'MyClass'
+         'MyClass' -> 'MyClass'
+
+     Args:
+         qualified_name: Fully or partially qualified name
+
+     Returns:
+         Simple name without namespace/module prefix
+     """
+     # Handle C++ style namespace separator
+     if '::' in qualified_name:
+         return qualified_name.split('::')[-1]
+
+     # Handle Python/JS style module separator
+     if '.' in qualified_name:
+         return qualified_name.split('.')[-1]
+
+     return qualified_name
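To tie the alias machinery above together, here is a hedged end-to-end sketch. The entity registry shape only assumes the 'aliases' key that build_entity_alias_map reads; the entity name and file path are illustrative.

    aliases = generate_entity_aliases('Calculator', 'src/utils/calculator.py')
    # ['Calculator', '.calculator.Calculator', '.utils.calculator.Calculator', ...]

    entities = {'Calculator': {'aliases': aliases}}  # hypothetical registry
    alias_map = build_entity_alias_map(entities)

    resolve_entity_call('Calculator', alias_map)           # direct match -> 'Calculator'
    resolve_entity_call('pkg.Calculator', alias_map)       # fuzzy match on the simple name -> 'Calculator'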