|
|
import logging |
|
|
from tqdm import tqdm |
|
|
import uuid |
|
|
from typing import Literal |
|
|
from abc import ABC, abstractmethod |
|
|
import lancedb |
|
|
import os |
|
|
import numpy as np |
|
|
import weaviate |
|
|
from weaviate.classes.config import Configure, Property, DataType |
|
|
from weaviate.classes.query import MetadataQuery |
|
|
|
|
|
try: |
|
|
LANCEDB_AVAILABLE = True |
|
|
except ImportError: |
|
|
LANCEDB_AVAILABLE = False |
|
|
|
|
|
from .utils.logger_utils import setup_logger |
|
|
|
|
|
LOGGER_NAME = 'CODE_INDEX_LOGGER' |
|
|
STOP_AFTER_ATTEMPT = int(os.getenv("STOP_AFTER_ATTEMPT", 5)) |
|
|
WAIT_BETWEEN_RETRIES = int(os.getenv("WAIT_BETWEEN_RETRIES", 2)) |
|
|
MODEL_ID = os.getenv("MODEL_ID") |
|
|
MAX_TOKENS = int(os.getenv('MAX_TOKENS', 2048)) |
|
|
TEMPERATURE = float(os.getenv('TEMPERATURE', 0.2)) |
|
|
TOP_P = float(os.getenv('TOP_P', 0.95)) |
|
|
FREQUENCY_PENALTY = 0 |
|
|
PRESENCE_PENALTY = 0 |
|
|
STOP = None |
|
|
EMBEDDING_MODEL_URL = os.getenv('EMBEDDING_MODEL_URL') |
|
|
EMBEDDING_MODEL_API_KEY = os.getenv('EMBEDDING_MODEL_API_KEY', "no_need") |
|
|
EMBEDDING_NUMBER_DIMENSIONS = int(os.getenv('EMBEDDING_NUMBER_DIMENSIONS', 1024)) |
|
|
|
|
|
WEAVIATE_HOST = os.getenv('WEAVIATE_HOST', "localhost") |
|
|
WEAVIATE_PORT = int(os.getenv('WEAVIATE_PORT', 8080)) |
|
|
WEAVIATE_GRPC_PORT = int(os.getenv('WEAVIATE_GRPC_PORT', 50051)) |
|
|
ALPHA_SEARCH_VALUE = float(os.getenv('ALPHA_SEARCH_VALUE', 0.8)) |
|
|
LANCEDB_PATH = os.getenv('LANCEDB_PATH', './local_code_index_db') |
|
|
|
|
|
|
|
|
class BaseCodeIndex(ABC): |
|
|
"""Abstract base class for code indexing implementations""" |
|
|
|
|
|
def __init__(self, nodes: list, model_service, index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid', |
|
|
embedding_batch_size: int = 64, use_embed: bool = True): |
|
|
setup_logger(LOGGER_NAME) |
|
|
self.logger = logging.getLogger(LOGGER_NAME) |
|
|
self.model_service = model_service |
|
|
self.index_type = index_type |
|
|
|
|
|
self.embedding_batch_size = int(os.getenv('EMBEDDING_BATCH_SIZE', embedding_batch_size)) |
|
|
self.use_embed = use_embed |
|
|
self.logger.info(f"CodeIndex initialized with batch_size={self.embedding_batch_size}, index_type={index_type}") |
|
|
|
|
|
@abstractmethod |
|
|
def query(self, query: str, n_results: int=10) -> dict: |
|
|
"""Query the index and return results""" |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def __del__(self): |
|
|
"""Clean up resources""" |
|
|
pass |
|
|
|
|
|
|
|
|
class WeaviateCodeIndex(BaseCodeIndex): |
|
|
"""Weaviate-based code index implementation""" |
|
|
|
|
|
def __init__(self, nodes: list, model_service, index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid', |
|
|
embedding_batch_size: int = 20, use_embed: bool = True, |
|
|
host: str = None, port: int = None, grpc_port: int = None): |
|
|
super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed) |
|
|
|
|
|
|
|
|
weaviate_host = host or WEAVIATE_HOST |
|
|
weaviate_port = port or WEAVIATE_PORT |
|
|
weaviate_grpc_port = grpc_port or WEAVIATE_GRPC_PORT |
|
|
|
|
|
|
|
|
self.weaviate_client = weaviate.connect_to_local( |
|
|
host=weaviate_host, |
|
|
port=weaviate_port, |
|
|
grpc_port=weaviate_grpc_port |
|
|
) |
|
|
|
|
|
|
|
|
self.collection_name = f"CodeChunks_{str(uuid.uuid4()).replace('-', '_')}" |
|
|
|
|
|
|
|
|
|
|
|
self.collection = self.weaviate_client.collections.create( |
|
|
name=self.collection_name, |
|
|
properties=[ |
|
|
Property(name="node_id", data_type=DataType.TEXT), |
|
|
Property(name="name", data_type=DataType.TEXT), |
|
|
Property(name="content", data_type=DataType.TEXT), |
|
|
Property(name="description", data_type=DataType.TEXT), |
|
|
Property(name="path", data_type=DataType.TEXT), |
|
|
Property(name="language", data_type=DataType.TEXT), |
|
|
Property(name="node_type", data_type=DataType.TEXT), |
|
|
Property(name="order_in_file", data_type=DataType.INT), |
|
|
Property(name="declared_entities", data_type=DataType.TEXT), |
|
|
Property(name="called_entities", data_type=DataType.TEXT), |
|
|
], |
|
|
|
|
|
vector_config=Configure.Vectors.self_provided(), |
|
|
) |
|
|
|
|
|
chunk_nodes = [node for node in nodes if node.node_type == 'chunk'] |
|
|
self.logger.info(f"Weaviate indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}") |
|
|
|
|
|
|
|
|
if self.index_type != 'keyword-only': |
|
|
|
|
|
nodes_needing_embeddings = [ |
|
|
node for node in chunk_nodes |
|
|
if node.embedding is None or (isinstance(node.embedding, (list,)) and len(node.embedding) == 0) or not use_embed |
|
|
] |
|
|
|
|
|
if nodes_needing_embeddings: |
|
|
total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size |
|
|
self.logger.info(f'Batch embedding {len(nodes_needing_embeddings)} nodes in {total_batches} batches') |
|
|
|
|
|
|
|
|
for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size), |
|
|
desc="Batch embedding nodes"): |
|
|
batch_nodes = nodes_needing_embeddings[i:i + self.embedding_batch_size] |
|
|
texts_to_embed = [node.get_field_to_embed() for node in batch_nodes] |
|
|
|
|
|
|
|
|
embeddings = self.model_service.embed_chunk_code_batch(texts_to_embed) |
|
|
|
|
|
|
|
|
for node, embedding in zip(batch_nodes, embeddings): |
|
|
node.embedding = embedding |
|
|
|
|
|
|
|
|
batch_num = i // self.embedding_batch_size + 1 |
|
|
if batch_num % 10 == 0: |
|
|
self.logger.info(f"Completed batch {batch_num}/{total_batches}") |
|
|
|
|
|
self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} nodes") |
|
|
else: |
|
|
self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} nodes") |
|
|
|
|
|
|
|
|
with self.collection.batch.dynamic() as batch: |
|
|
for node in tqdm(chunk_nodes, desc="Indexing nodes"): |
|
|
self.logger.debug(f'Indexing node : {node.id}') |
|
|
|
|
|
|
|
|
embedding = None |
|
|
if self.index_type != 'keyword-only': |
|
|
embedding = node.embedding |
|
|
|
|
|
|
|
|
properties = { |
|
|
"node_id": node.id, |
|
|
"name": node.name, |
|
|
"content": node.content, |
|
|
"description": node.description or "", |
|
|
"path": node.path, |
|
|
"language": node.language, |
|
|
"node_type": node.node_type, |
|
|
"order_in_file": node.order_in_file, |
|
|
"declared_entities": str(node.declared_entities), |
|
|
"called_entities": str(node.called_entities), |
|
|
} |
|
|
|
|
|
|
|
|
if self.index_type == 'keyword-only': |
|
|
|
|
|
batch.add_object(properties=properties) |
|
|
else: |
|
|
|
|
|
batch.add_object( |
|
|
properties=properties, |
|
|
vector=embedding |
|
|
) |
|
|
|
|
|
|
|
|
def query(self, query: str, n_results:int=10) -> dict: |
|
|
""" |
|
|
Perform search based on index_type: |
|
|
- 'embedding-only': pure vector search |
|
|
- 'keyword-only': pure keyword search (BM25) |
|
|
- 'hybrid': hybrid search combining both (alpha controls weighting) |
|
|
|
|
|
Weaviate's hybrid search uses: |
|
|
- alpha=0: pure keyword search (BM25) |
|
|
- alpha=1: pure vector search |
|
|
- alpha=0.5-0.8: balanced hybrid search (recommended) |
|
|
""" |
|
|
try: |
|
|
|
|
|
if self.index_type == 'keyword-only': |
|
|
|
|
|
response = self.collection.query.bm25( |
|
|
query=query, |
|
|
limit=n_results, |
|
|
return_metadata=MetadataQuery(score=True) |
|
|
) |
|
|
elif self.index_type == 'embedding-only': |
|
|
|
|
|
embedding = self.model_service.embed_query(query) |
|
|
response = self.collection.query.near_vector( |
|
|
near_vector=embedding, |
|
|
limit=n_results, |
|
|
return_metadata=MetadataQuery(distance=True) |
|
|
) |
|
|
else: |
|
|
|
|
|
embedding = self.model_service.embed_query(query) |
|
|
response = self.collection.query.hybrid( |
|
|
query=query, |
|
|
vector=embedding, |
|
|
limit=n_results, |
|
|
alpha=ALPHA_SEARCH_VALUE, |
|
|
return_metadata=MetadataQuery(distance=True, score=True) |
|
|
) |
|
|
|
|
|
|
|
|
results = { |
|
|
'ids': [[]], |
|
|
'distances': [[]], |
|
|
'metadatas': [[]], |
|
|
'documents': [[]] |
|
|
} |
|
|
|
|
|
for obj in response.objects: |
|
|
results['ids'][0].append(obj.properties['node_id']) |
|
|
results['distances'][0].append(obj.metadata.distance if obj.metadata.distance else 0.0) |
|
|
results['metadatas'][0].append({ |
|
|
'id': obj.properties['node_id'], |
|
|
'name': obj.properties['name'], |
|
|
'content': obj.properties['content'], |
|
|
'description': obj.properties['description'], |
|
|
'path': obj.properties['path'], |
|
|
'language': obj.properties['language'], |
|
|
'node_type': obj.properties['node_type'], |
|
|
'order_in_file': str(obj.properties['order_in_file']), |
|
|
'declared_entities': obj.properties['declared_entities'], |
|
|
'called_entities': obj.properties['called_entities'], |
|
|
}) |
|
|
results['documents'][0].append(obj.properties['content']) |
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f'Failed to query: {e}', exc_info=True) |
|
|
raise e |
|
|
|
|
|
def __del__(self): |
|
|
"""Clean up Weaviate connection""" |
|
|
if hasattr(self, 'weaviate_client'): |
|
|
try: |
|
|
self.weaviate_client.close() |
|
|
except: |
|
|
pass |
|
|
|
|
|
|
|
|
class LanceDBCodeIndex(BaseCodeIndex): |
|
|
"""LanceDB-based code index implementation""" |
|
|
|
|
|
def __init__(self, nodes: list, model_service, index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid', |
|
|
embedding_batch_size: int = 20, use_embed: bool = True, db_path: str = None): |
|
|
super().__init__(nodes, model_service, index_type, embedding_batch_size, use_embed) |
|
|
|
|
|
if not LANCEDB_AVAILABLE: |
|
|
raise ImportError("LanceDB is not available. Please install it with: pip install lancedb") |
|
|
|
|
|
|
|
|
self.db_path = db_path or LANCEDB_PATH |
|
|
self.db = lancedb.connect(self.db_path) |
|
|
self.table_name = f"code_chunks_{uuid.uuid4().hex}" |
|
|
self.table = None |
|
|
|
|
|
chunk_nodes = [node for node in nodes if node.node_type == "chunk"] |
|
|
self.logger.info(f"LanceDB indexing {len(chunk_nodes)} chunk nodes with batch_size={self.embedding_batch_size}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.index_type != "keyword-only": |
|
|
|
|
|
|
|
|
|
|
|
nodes_needing_embeddings = [] |
|
|
for node in chunk_nodes: |
|
|
needs_embedding = False |
|
|
if not use_embed: |
|
|
|
|
|
needs_embedding = True |
|
|
elif node.embedding is None: |
|
|
needs_embedding = True |
|
|
elif isinstance(node.embedding, (list, np.ndarray)) and len(node.embedding) == 0: |
|
|
needs_embedding = True |
|
|
|
|
|
if needs_embedding: |
|
|
nodes_needing_embeddings.append(node) |
|
|
|
|
|
if nodes_needing_embeddings: |
|
|
total_batches = (len(nodes_needing_embeddings) + self.embedding_batch_size - 1) // self.embedding_batch_size |
|
|
self.logger.info(f"Embedding {len(nodes_needing_embeddings)} chunks in {total_batches} batches (batch_size={self.embedding_batch_size})...") |
|
|
|
|
|
for i in tqdm(range(0, len(nodes_needing_embeddings), self.embedding_batch_size), |
|
|
desc="Batch embedding nodes"): |
|
|
batch = nodes_needing_embeddings[i:i + self.embedding_batch_size] |
|
|
texts = [n.get_field_to_embed() for n in batch] |
|
|
batch_embeds = self.model_service.embed_chunk_code_batch(texts) |
|
|
|
|
|
for n, emb in zip(batch, batch_embeds): |
|
|
n.embedding = np.array(emb, dtype=np.float32) |
|
|
|
|
|
|
|
|
batch_num = i // self.embedding_batch_size + 1 |
|
|
if batch_num % 10 == 0: |
|
|
self.logger.info(f"Completed batch {batch_num}/{total_batches}") |
|
|
|
|
|
self.logger.info(f"Embedding complete: processed {len(nodes_needing_embeddings)} chunks") |
|
|
else: |
|
|
self.logger.info(f"Using existing embeddings for all {len(chunk_nodes)} chunks") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rows = [] |
|
|
for node in chunk_nodes: |
|
|
row = { |
|
|
"node_id": node.id, |
|
|
"name": node.name, |
|
|
"content": node.content, |
|
|
"description": node.description or "", |
|
|
"path": node.path, |
|
|
"language": node.language, |
|
|
"node_type": node.node_type, |
|
|
"order_in_file": node.order_in_file, |
|
|
"declared_entities": str(node.declared_entities), |
|
|
"called_entities": str(node.called_entities), |
|
|
} |
|
|
|
|
|
|
|
|
if self.index_type != "keyword-only": |
|
|
row["vector"] = node.embedding |
|
|
|
|
|
rows.append(row) |
|
|
|
|
|
|
|
|
self.table = self.db.create_table(self.table_name, data=rows) |
|
|
self.logger.info(f"Created LanceDB table: {self.table_name}") |
|
|
|
|
|
|
|
|
|
|
|
self._create_fts_indexes() |
|
|
|
|
|
def _create_fts_indexes(self): |
|
|
""" |
|
|
Create full-text search indexes on text columns. |
|
|
|
|
|
LanceDB 0.25.x uses create_fts_index() with use_tantivy=True to support |
|
|
multiple columns. Requires tantivy package: pip install tantivy |
|
|
""" |
|
|
fts_columns = ["content", "name", "description"] |
|
|
|
|
|
try: |
|
|
|
|
|
self.table.create_fts_index(fts_columns, replace=True, use_tantivy=True) |
|
|
self.logger.info(f"Created FTS index (Tantivy) on columns: {fts_columns}") |
|
|
except Exception as e: |
|
|
self.logger.warning(f"Failed to create FTS index: {e}") |
|
|
self.logger.warning( |
|
|
"Full-text search will fall back to scanning. " |
|
|
"Ensure tantivy is installed: pip install tantivy" |
|
|
) |
|
|
|
|
|
def query(self, query: str, n_results: int=10) -> dict: |
|
|
""" |
|
|
Perform search based on index_type: |
|
|
- 'embedding-only': pure vector search |
|
|
- 'keyword-only': full-text search using LanceDB's native FTS |
|
|
- 'hybrid': combines vector similarity and full-text search with reranking |
|
|
""" |
|
|
try: |
|
|
|
|
|
if self.index_type == "keyword-only": |
|
|
|
|
|
try: |
|
|
|
|
|
df = self.table.search(query, query_type="fts").limit(n_results).to_pandas() |
|
|
except Exception as fts_error: |
|
|
self.logger.warning(f"FTS search failed, falling back to scan: {fts_error}") |
|
|
|
|
|
all_df = self.table.to_pandas() |
|
|
query_lower = query.lower() |
|
|
|
|
|
query_words = query_lower.split() |
|
|
|
|
|
def matches_query(row): |
|
|
text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower() |
|
|
|
|
|
return any(word in text for word in query_words) |
|
|
|
|
|
mask = all_df.apply(matches_query, axis=1) |
|
|
df = all_df[mask].head(n_results) |
|
|
|
|
|
df = df.copy() |
|
|
df['_distance'] = 0.0 |
|
|
|
|
|
|
|
|
elif self.index_type == "embedding-only": |
|
|
emb = np.array(self.model_service.embed_query(query), dtype=np.float32) |
|
|
df = self.table.search( |
|
|
emb, |
|
|
vector_column_name="vector" |
|
|
).limit(n_results).to_pandas() |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
|
emb = np.array(self.model_service.embed_query(query), dtype=np.float32) |
|
|
|
|
|
|
|
|
vector_limit = min(n_results * 3, 100) |
|
|
df = self.table.search( |
|
|
emb, |
|
|
vector_column_name="vector" |
|
|
).limit(vector_limit).to_pandas() |
|
|
|
|
|
if not df.empty: |
|
|
|
|
|
query_lower = query.lower() |
|
|
query_words = query_lower.split() |
|
|
|
|
|
def compute_keyword_score(row): |
|
|
"""Compute a keyword match score (higher is better)""" |
|
|
text = f"{row.get('content', '')} {row.get('name', '')} {row.get('description', '')}".lower() |
|
|
score = 0 |
|
|
|
|
|
if query_lower in text: |
|
|
score += 10 |
|
|
|
|
|
for word in query_words: |
|
|
if word in text: |
|
|
score += 1 |
|
|
|
|
|
if word in str(row.get('name', '')).lower(): |
|
|
score += 2 |
|
|
return score |
|
|
|
|
|
|
|
|
df = df.copy() |
|
|
df['_keyword_score'] = df.apply(compute_keyword_score, axis=1) |
|
|
|
|
|
|
|
|
max_dist = df['_distance'].max() if df['_distance'].max() > 0 else 1.0 |
|
|
df['_vector_score'] = 1.0 - (df['_distance'] / max_dist) |
|
|
|
|
|
|
|
|
|
|
|
alpha = 0.7 |
|
|
max_keyword = df['_keyword_score'].max() if df['_keyword_score'].max() > 0 else 1.0 |
|
|
df['_combined_score'] = ( |
|
|
alpha * df['_vector_score'] + |
|
|
(1 - alpha) * (df['_keyword_score'] / max_keyword) |
|
|
) |
|
|
|
|
|
|
|
|
df = df.sort_values('_combined_score', ascending=False).head(n_results) |
|
|
|
|
|
|
|
|
results = { |
|
|
"ids": [[]], |
|
|
"distances": [[]], |
|
|
"metadatas": [[]], |
|
|
"documents": [[]], |
|
|
} |
|
|
|
|
|
for _, row in df.iterrows(): |
|
|
results["ids"][0].append(row["node_id"]) |
|
|
results["documents"][0].append(row["content"]) |
|
|
results["distances"][0].append(float(row.get("_distance", 0))) |
|
|
|
|
|
results["metadatas"][0].append({ |
|
|
"id": row["node_id"], |
|
|
"name": row["name"], |
|
|
"content": row["content"], |
|
|
"description": row["description"], |
|
|
"path": row["path"], |
|
|
"language": row["language"], |
|
|
"node_type": row["node_type"], |
|
|
"order_in_file": str(row["order_in_file"]), |
|
|
"declared_entities": row["declared_entities"], |
|
|
"called_entities": row["called_entities"], |
|
|
}) |
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Query failed: {e}", exc_info=True) |
|
|
raise |
|
|
|
|
|
def __del__(self): |
|
|
"""Clean up resources""" |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def CodeIndex( |
|
|
nodes: list, |
|
|
model_service, |
|
|
index_type: Literal['embedding-only', 'keyword-only', 'hybrid'] = 'hybrid', |
|
|
embedding_batch_size: int = 20, |
|
|
use_embed: bool = True, |
|
|
backend: Literal['weaviate', 'lancedb'] = 'weaviate', |
|
|
db_path: str = None, |
|
|
host: str = None, |
|
|
port: int = None, |
|
|
grpc_port: int = None |
|
|
) -> BaseCodeIndex: |
|
|
""" |
|
|
Factory function to create a CodeIndex instance. |
|
|
|
|
|
Args: |
|
|
nodes: List of nodes to index |
|
|
model_service: Service for embedding generation |
|
|
index_type: Type of search ('embedding-only', 'keyword-only', or 'hybrid') |
|
|
embedding_batch_size: Batch size for embedding generation |
|
|
use_embed: Whether to use pre-computed embeddings |
|
|
backend: Which backend to use ('weaviate' or 'lancedb') |
|
|
db_path: Path for LanceDB (only used with 'lancedb' backend) |
|
|
host: Weaviate host (only used with 'weaviate' backend) |
|
|
port: Weaviate port (only used with 'weaviate' backend) |
|
|
grpc_port: Weaviate gRPC port (only used with 'weaviate' backend) |
|
|
|
|
|
Returns: |
|
|
BaseCodeIndex: Either WeaviateCodeIndex or LanceDBCodeIndex instance |
|
|
""" |
|
|
if backend == 'lancedb': |
|
|
return LanceDBCodeIndex( |
|
|
nodes=nodes, |
|
|
model_service=model_service, |
|
|
index_type=index_type, |
|
|
embedding_batch_size=embedding_batch_size, |
|
|
use_embed=use_embed, |
|
|
db_path=db_path |
|
|
) |
|
|
elif backend == 'weaviate': |
|
|
return WeaviateCodeIndex( |
|
|
nodes=nodes, |
|
|
model_service=model_service, |
|
|
index_type=index_type, |
|
|
embedding_batch_size=embedding_batch_size, |
|
|
use_embed=use_embed, |
|
|
host=host, |
|
|
port=port, |
|
|
grpc_port=grpc_port |
|
|
) |
|
|
else: |
|
|
return WeaviateCodeIndex( |
|
|
nodes=nodes, |
|
|
model_service=model_service, |
|
|
index_type=index_type, |
|
|
embedding_batch_size=embedding_batch_size, |
|
|
use_embed=use_embed, |
|
|
host=host, |
|
|
port=port, |
|
|
grpc_port=grpc_port |
|
|
) |
|
|
|
|
|
|
|
|
|