|
|
import os |
|
|
import json |
|
|
import asyncio |
|
|
from uuid import uuid4 |
|
|
from datetime import datetime |
|
|
from typing import List, Union, Optional, Sequence, Dict, Any, Tuple |
|
|
from llama_index.core.schema import NodeWithScore, TextNode, ImageNode, RelatedNodeInfo |
|
|
|
|
|
from .rag_config import RAGConfig |
|
|
from .readers import LLamaIndexReader, MultimodalReader |
|
|
from .indexings import IndexFactory, BaseIndexWrapper |
|
|
from .chunkers import ChunkFactory |
|
|
from .embeddings import EmbeddingFactory, EmbeddingProvider |
|
|
from .retrievers import RetrieverFactory, BaseRetrieverWrapper |
|
|
from .postprocessors import PostprocessorFactory |
|
|
from .indexings.base import IndexType |
|
|
from .retrievers.base import RetrieverType |
|
|
from .schema import Corpus, ChunkMetadata, IndexMetadata, Query, RagResult, ImageChunk, TextChunk |
|
|
from evoagentx.storages.base import StorageHandler |
|
|
from evoagentx.storages.schema import IndexStore |
|
|
from evoagentx.models.base_model import BaseLLM |
|
|
from evoagentx.core.logging import logger |
|
|
|
|
|
|
|
|
class RAGEngine: |
|
|
def __init__(self, config: RAGConfig, storage_handler: StorageHandler, llm: Optional[BaseLLM] = None):
    """Initialize the RAG engine from a configuration.

    Wires together the reader, chunker, embedding model, and the index/retriever
    factories. Per-corpus indices and retrievers are created lazily later (by
    ``add`` / ``_load_index``), not here.

    Args:
        config (RAGConfig): RAG pipeline configuration (reader, chunker, embedding, retrieval, ...).
        storage_handler (StorageHandler): Backend storage for vector/graph stores and the database.
        llm (Optional[BaseLLM]): LLM used by indices/retrievers that require one (e.g. graph indices).
    """
    self.config = config
    self.storage_handler = storage_handler
    self.embedding_factory = EmbeddingFactory()
    self.index_factory = IndexFactory()
    self.chunk_factory = ChunkFactory()
    self.retriever_factory = RetrieverFactory()
    self.postprocessor_factory = PostprocessorFactory()
    self.llm = llm

    logger.info(f"RAGEngine modality config: {self.config.modality}")

    # Modality decides both the chunk schema and the reader implementation.
    if self.config.modality == "multimodal":
        self.chunk_class = ImageChunk
        self.reader = MultimodalReader(
            recursive=self.config.reader.recursive,
            exclude_hidden=self.config.reader.exclude_hidden,
            num_files_limits=self.config.reader.num_files_limit,
            errors=self.config.reader.errors,
        )
    else:
        self.chunk_class = TextChunk
        self.reader = LLamaIndexReader(
            recursive=self.config.reader.recursive,
            exclude_hidden=self.config.reader.exclude_hidden,
            num_workers=self.config.num_workers,
            num_files_limits=self.config.reader.num_files_limit,
            custom_metadata_function=self.config.reader.custom_metadata_function,
            extern_file_extractor=self.config.reader.extern_file_extractor,
            errors=self.config.reader.errors,
            encoding=self.config.reader.encoding,
        )

    self.embed_model = self.embedding_factory.create(
        provider=self.config.embedding.provider,
        model_config=self.config.embedding.model_dump(exclude_unset=True),
    )

    # Keep the vector store dimension in sync with the embedding model: a
    # mismatch would break inserts/queries, so re-initialize the store.
    if (self.storage_handler.vector_store is not None) and (self.embed_model.dimensions is not None):
        if self.storage_handler.storageConfig.vectorConfig.dimensions != self.embed_model.dimensions:
            logger.warning("The dimensions of the vector store do not match the embedding model. Reinitializing vector store.")
            self.storage_handler.storageConfig.vectorConfig.dimensions = self.embed_model.dimensions
            self.storage_handler._init_vector_store()

    # Multimodal documents are indexed as whole images, so no chunker is needed.
    if self.config.modality == "multimodal":
        self.chunker = None
    else:
        self.chunker = self.chunk_factory.create(
            strategy=self.config.chunker.strategy,
            embed_model=self.embed_model.get_embedding_model(),
            chunker_config={
                "chunk_size": self.config.chunker.chunk_size,
                "chunk_overlap": self.config.chunker.chunk_overlap,
                "max_chunks": self.config.chunker.max_chunks,
            },
        )

    # Lazily populated maps: corpus_id -> index_type -> wrapper.
    self.indices: Dict[str, Dict[str, BaseIndexWrapper]] = {}
    self.retrievers: Dict[str, Dict[str, BaseRetrieverWrapper]] = {}
|
|
|
|
|
def read(self, file_paths: Union[Sequence[str], str],
         exclude_files: Optional[Union[str, List, Tuple, Sequence]] = None,
         filter_file_by_suffix: Optional[Union[str, List, Tuple, Sequence]] = None,
         merge_by_file: bool = False,
         show_progress: bool = False,
         corpus_id: Optional[str] = None) -> Corpus:
    """Load and chunk documents from files.

    Reads files from specified paths, processes them into documents, and chunks them into a Corpus.
    In multimodal mode documents are wrapped whole (one ImageChunk per document) instead of chunked.

    Args:
        file_paths (Union[Sequence[str], str]): Path(s) to files or directories.
        exclude_files (Optional[Union[str, List, Tuple, Sequence]]): Files to exclude.
        filter_file_by_suffix (Optional[Union[str, List, Tuple, Sequence]]): Filter files by suffix (e.g., '.pdf').
        merge_by_file (bool): Merge documents by file.
        show_progress (bool): Show loading progress.
        corpus_id (Optional[str]): Identifier for the corpus. Defaults to a UUID if None.

    Returns:
        Corpus: The chunked corpus containing processed document chunks.

    Raises:
        Exception: If document reading or chunking fails.
    """
    try:
        corpus_id = corpus_id or str(uuid4())
        documents = self.reader.load(
            file_paths=file_paths,
            exclude_files=exclude_files,
            filter_file_by_suffix=filter_file_by_suffix,
            merge_by_file=merge_by_file,
            show_progress=show_progress
        )
        if self.config.modality == "multimodal":
            # Images are not chunked: each loaded document becomes one ImageChunk.
            image_chunks = []
            for doc in documents:
                # Prefer a reader-provided image_path; fall back to file_path metadata.
                image_path = getattr(doc, 'image_path', None) or doc.metadata.get('file_path')
                image_mimetype = getattr(doc, 'image_mimetype', None)
                # chunk_id/doc_id fall back to a positional name when file_name is absent.
                image_chunk = self.chunk_class(
                    image_path=image_path,
                    image_mimetype=image_mimetype,
                    chunk_id=doc.metadata.get('file_name', f'img_{len(image_chunks)}'),
                    metadata=ChunkMetadata(
                        doc_id=doc.metadata.get('file_name', f'doc_{len(image_chunks)}'),
                        corpus_id=corpus_id,
                        **doc.metadata
                    )
                )
                image_chunks.append(image_chunk)
            corpus = Corpus(chunks=image_chunks, corpus_id=corpus_id)
            logger.info(f"Read {len(documents)} multimodal documents (no chunking) for corpus {corpus_id}")
        else:
            corpus = self.chunker.chunk(documents)
            corpus.corpus_id = corpus_id
            logger.info(f"Read {len(documents)} documents and created {len(corpus.chunks)} chunks for corpus {corpus_id}")
        return corpus
    except Exception as e:
        logger.error(f"Failed to read documents for corpus {corpus_id}: {str(e)}")
        raise
|
|
|
|
|
def add(self, index_type: str, nodes: Union[Corpus, List[NodeWithScore], List[TextNode], List[ImageNode]],
        corpus_id: Optional[str] = None) -> Sequence[str]:
    """Add nodes to an index for a specific corpus.

    Initializes an index (and its retriever) if it doesn't exist and inserts nodes,
    updating each node's metadata with corpus_id and index_type.

    Args:
        index_type (str): Type of index (e.g., VECTOR, GRAPH).
        nodes (Union[Corpus, List[NodeWithScore], List[TextNode], List[ImageNode]]): Nodes or Corpus to add.
        corpus_id (Optional[str]): Identifier for the corpus. Defaults to a UUID if None.

    Returns:
        Sequence[str]: Ids of the inserted nodes; an empty list if insertion fails
        (errors are logged, not raised).
    """
    try:
        corpus_id = corpus_id or str(uuid4())
        if corpus_id not in self.indices:
            self.indices[corpus_id] = {}
            self.retrievers[corpus_id] = {}

        # Lazily create the index/retriever pair on first use of this index_type.
        if index_type not in self.indices[corpus_id]:
            index = self.index_factory.create(
                index_type=index_type,
                embed_model=self.embed_model.get_embedding_model(),
                storage_handler=self.storage_handler,
                index_config=self.config.index.model_dump(exclude_unset=True) if self.config.index else {},
                llm=self.llm,
            )
            self.indices[corpus_id][index_type] = index
            # "retrivel_type" is the config field's spelling; retriever type comes from config here.
            self.retrievers[corpus_id][index_type] = self.retriever_factory.create(
                retriever_type=self.config.retrieval.retrivel_type,
                llm=self.llm,
                index=index.get_index(),
                graph_store=index.get_index().storage_context.graph_store,
                embed_model=self.embed_model.get_embedding_model(),
                query=Query(query_str="", top_k=self.config.retrieval.top_k if self.config.retrieval else 5),
                storage_handler=self.storage_handler,
                chunk_class=self.chunk_class
            )

        # Tag every node so later metadata filtering can scope by corpus/index.
        nodes_to_insert = nodes.to_llama_nodes() if isinstance(nodes, Corpus) else nodes
        for node in nodes_to_insert:
            node.metadata.update({"corpus_id": corpus_id, "index_type": index_type})
        nodes_ids = self.indices[corpus_id][index_type].insert_nodes(nodes_to_insert)
        logger.info(f"Added {len(nodes_to_insert)} nodes to {index_type} index for corpus {corpus_id}")
        return nodes_ids
    except Exception as e:
        # Best-effort: failures are logged and reported as an empty id list.
        logger.error(f"Failed to add nodes to {index_type} index for corpus {corpus_id}: {str(e)}")
        return []
|
|
|
|
|
def delete(self, corpus_id: str, index_type: Optional[str] = None,
           node_ids: Optional[Union[str, List[str]]] = None,
           metadata_filters: Optional[Dict[str, Any]] = None) -> None:
    """Delete nodes or an entire index from a corpus.

    Removes specific nodes by ID or metadata filters, or deletes the entire index
    (and its retriever) if neither selector is provided.

    Args:
        corpus_id (str): Identifier for the corpus.
        index_type (Optional[str]): Specific index type to delete from. If None, affects all indices.
        node_ids (Optional[Union[str, List[str]]]): Node ID or list of node IDs to delete.
        metadata_filters (Optional[Dict[str, Any]]): Metadata filters to select nodes for deletion.

    Raises:
        Exception: If deletion fails.
    """
    try:
        if corpus_id not in self.indices:
            logger.warning(f"No indices found for corpus {corpus_id}")
            return

        # list() snapshot: entries may be deleted from the dict while iterating.
        target_indices = [index_type] if index_type else self.indices[corpus_id].keys()
        for idx_type in list(target_indices):
            if idx_type not in self.indices[corpus_id]:
                logger.warning(f"Index type {idx_type} not found for corpus {corpus_id}")
                continue

            index = self.indices[corpus_id][idx_type]
            if node_ids or metadata_filters:
                # Targeted deletion: normalize a single id to a list first.
                node_ids_list = [node_ids] if isinstance(node_ids, str) else node_ids
                index.delete_nodes(node_ids=node_ids_list, metadata_filters=metadata_filters)
                logger.info(f"Deleted nodes from {idx_type} index for corpus {corpus_id}")
            else:
                # No selector given: drop the whole index and its retriever.
                index.clear()
                del self.indices[corpus_id][idx_type]
                del self.retrievers[corpus_id][idx_type]
                logger.info(f"Deleted entire {idx_type} index for corpus {corpus_id}")

        # Remove the corpus entry entirely once it holds no indices.
        if not self.indices[corpus_id]:
            del self.indices[corpus_id]
            del self.retrievers[corpus_id]
            logger.info(f"Removed empty corpus {corpus_id}")

    except Exception as e:
        logger.error(f"Failed to delete from corpus {corpus_id}, index {index_type}: {str(e)}")
        raise
|
|
|
|
|
def clear(self, corpus_id: Optional[str] = None) -> None:
    """Drop all indices for one corpus, or for every corpus.

    Args:
        corpus_id (Optional[str]): Corpus to clear. If None, clears all corpora.

    Raises:
        Exception: Propagated if clearing any index fails.
    """
    try:
        if corpus_id:
            targets = [corpus_id]
        else:
            targets = list(self.indices)

        for cid in targets:
            if cid not in self.indices:
                logger.warning(f"No indices found for corpus {cid}")
                continue

            # Snapshot the key list: entries are removed while iterating.
            for idx_type in list(self.indices[cid]):
                self.indices[cid][idx_type].clear()
                self.indices[cid].pop(idx_type)
                self.retrievers[cid].pop(idx_type)
                logger.info(f"Cleared {idx_type} index for corpus {cid}")

            self.indices.pop(cid)
            self.retrievers.pop(cid)
            logger.info(f"Cleared corpus {cid}")

    except Exception as e:
        logger.error(f"Failed to clear indices for corpus {corpus_id or 'all'}: {str(e)}")
        raise
|
|
|
|
|
def save(self, output_path: Optional[str] = None, corpus_id: Optional[str] = None,
         index_type: Optional[str] = None, table: Optional[str] = None,
         graph_exported: bool = False) -> None:
    """Save indices to files or database.

    Serializes corpus chunks to JSONL files and metadata to JSON files if output_path is provided,
    or saves to the SQLite database via StorageHandler if output_path is None.

    Args:
        output_path (Optional[str]): Directory to save JSONL and JSON files. If None, saves to database.
        corpus_id (Optional[str]): Specific corpus to save. If None, saves all corpora.
        index_type (Optional[str]): Specific index type to save. If None, saves all indices.
        table (Optional[str]): Database table name for index data. Defaults to 'indexing' if None.
        graph_exported (bool): If True, export graph nodes and relations for graph indices. Defaults to False.

    Raises:
        Exception: If saving fails or file operations encounter errors.
    """
    try:
        target_corpora = [corpus_id] if corpus_id else list(self.indices.keys())
        table = table or "indexing"

        for cid in target_corpora:
            if cid not in self.indices:
                logger.warning(f"No indices found for corpus {cid}")
                continue

            # Silently falls back to all indices if the requested type is absent.
            target_indices = [index_type] if index_type and index_type in self.indices[cid] else self.indices[cid].keys()
            for idx_type in target_indices:
                index = self.indices[cid][idx_type]

                # Graph indices are only serializable after an explicit export.
                if idx_type == IndexType.GRAPH and not graph_exported:
                    logger.warning(f"Skipping save for graph index {idx_type} in corpus {cid} as graph_exported is False")
                    continue

                # Export graph nodes/relations into the index's KV store first.
                if idx_type == IndexType.GRAPH and graph_exported:
                    index.build_kv_store()

                # Materialize the index's stored nodes back into chunk objects.
                chunks = [
                    self.chunk_class.from_llama_node(node_data)
                    for node_id, node_data in index.id_to_node.items()
                ]
                corpus = Corpus(chunks=chunks, corpus_id=cid)

                # Snapshot storage/embedding settings so a later load can validate compatibility.
                vector_config = self.storage_handler.storageConfig.vectorConfig.model_dump() if self.storage_handler.storageConfig.vectorConfig else {}
                graph_config = self.storage_handler.storageConfig.graphConfig.model_dump() if self.storage_handler.storageConfig.graphConfig else {}
                metadata = IndexMetadata(
                    corpus_id=cid,
                    index_type=idx_type,
                    collection_name=vector_config.get("qdrant_collection_name", "default_collection"),
                    dimension=self.embed_model.dimensions,
                    vector_db_type=vector_config.get("vector_name", None),
                    graph_db_type=graph_config.get("graph_name", None),
                    embedding_model_name=self.config.embedding.model_name,
                    date=str(datetime.now()),
                )

                if output_path:
                    # File export: "<corpus>_<index>_nodes.jsonl" + "<corpus>_<index>_metadata.json".
                    os.makedirs(output_path, exist_ok=True)
                    # Sanitize ids so they are safe to embed in filenames.
                    safe_cid = "".join(c if c.isalnum() or c in ["-", "_"] else "_" for c in cid)
                    safe_idx_type = "".join(c if c.isalnum() or c in ["-", "_"] else "_" for c in idx_type)
                    nodes_file = os.path.join(output_path, f"{safe_cid}_{safe_idx_type}_nodes.jsonl")
                    metadata_file = os.path.join(output_path, f"{safe_cid}_{safe_idx_type}_metadata.json")

                    corpus.to_jsonl(nodes_file, indent=0)
                    logger.info(f"Saved {len(corpus.chunks)} chunks to {nodes_file}")

                    with open(metadata_file, "w", encoding="utf-8") as f:
                        json.dump(metadata.model_dump(), f, indent=2, ensure_ascii=False)
                    logger.info(f"Saved metadata to {metadata_file}")
                else:
                    # Database export: one record per (corpus, index_type).
                    index_data = {
                        "corpus_id": cid,
                        "content": corpus.model_dump(),
                        "date": str(datetime.now()),
                        "metadata": metadata.model_dump()
                    }
                    self.storage_handler.save_index(index_data, table=table)
                    logger.info(f"Saved {idx_type} index with {len(corpus.chunks)} chunks for corpus {cid} to database table {table}")

    except Exception as e:
        logger.error(f"Failed to save indices for corpus {corpus_id or 'all'}: {str(e)}")
        raise
|
|
|
|
|
def load(self, source: Optional[str] = None, corpus_id: Optional[str] = None,
         index_type: Optional[str] = None, table: Optional[str] = None) -> Optional[List[str]]:
    """Load indices from files or database.

    Reconstructs indices and retrievers from JSONL/JSON files or SQLite database records.
    Validates the embedding model name and dimension before reinitializing the embedding model.

    Args:
        source (Optional[str]): Directory containing JSONL/JSON files. If None, loads from database.
        corpus_id (Optional[str]): Specific corpus to load. If None, loads all corpora.
        index_type (Optional[str]): Specific index type to load. If None, loads all indices.
        table (Optional[str]): Database table name for index data. Defaults to 'indexing' if None.

    Returns:
        Optional[List[str]]: Ids of the loaded chunks; None when the database table has no records.

    Raises:
        Exception: If loading fails due to file or database errors, invalid data, or unsupported embedding model/dimension.

    Warning:
        Calling this function more than once for the same nodes may duplicate them,
        because all indices share the same storage backend from the StorageHandler.
        For example, a vector database (e.g. Faiss) can insert the same node again.
    """
    try:
        table = table or "indexing"
        config_dimension = self.storage_handler.storageConfig.vectorConfig.dimensions
        loaded_chunk_ids: List[str] = []

        if source:
            # --- File-based load: pair each *_metadata.json with its *_nodes.jsonl ---
            if not os.path.exists(source):
                logger.error(f"Source directory {source} does not exist")
                raise FileNotFoundError(f"Source directory {source} does not exist")

            for file_name in os.listdir(source):
                if not file_name.endswith("_metadata.json"):
                    continue
                parts = file_name.split("_")
                if len(parts) < 3:
                    logger.warning(f"Skipping invalid metadata file: {file_name}")
                    continue
                # Filename layout: "<corpus_id>_<index_type>_metadata.json"; the corpus id
                # may itself contain underscores, so rejoin all but the last two parts.
                cid = "_".join(parts[:-2])
                idx_type = parts[-2]

                if (corpus_id and corpus_id != cid) or (index_type and index_type != idx_type):
                    continue

                metadata_file = os.path.join(source, file_name)
                nodes_file = os.path.join(source, f"{cid}_{idx_type}_nodes.jsonl")

                with open(metadata_file, "r", encoding="utf-8") as f:
                    metadata = IndexMetadata.model_validate(json.load(f))

                # The saved index must have been built with a model this provider supports.
                if not self.embed_model.validate_model(self.config.embedding.provider, metadata.embedding_model_name):
                    raise ValueError(
                        f"Embedding model '{metadata.embedding_model_name}' is not supported by provider '{self.config.embedding.provider}'. "
                        f"Supported models: {EmbeddingProvider.SUPPORTED_MODELS.get(self.config.embedding.provider, [])}"
                    )

                # Dimension mismatch would corrupt the shared vector store — hard fail.
                if metadata.dimension != config_dimension:
                    raise ValueError(
                        f"Embedding dimension {metadata.dimension} in metadata does not match configured dimension {config_dimension}."
                    )

                if not os.path.exists(nodes_file):
                    logger.warning(f"Nodes file {nodes_file} not found for metadata {metadata_file}")
                    continue
                corpus = Corpus.from_jsonl(nodes_file, corpus_id=cid)

                # NOTE(review): this recreates the embed model from self.config, not from
                # metadata.embedding_model_name as the log message suggests — confirm intent.
                if metadata.embedding_model_name != self.config.embedding.model_name:
                    logger.info(f"Reinitializing embedding model to {metadata.embedding_model_name}")
                    self.embed_model = self.embedding_factory.create(
                        provider=self.config.embedding.provider,
                        model_config=self.config.embedding.model_dump(exclude_unset=True)
                    )

                chunk_ids = self._load_index(corpus, cid, idx_type)
                loaded_chunk_ids.extend(chunk_ids)
                logger.info(f"Loaded {idx_type} index with {len(corpus.chunks)} chunks for corpus {cid} from {nodes_file}")
        else:
            # --- Database-based load ---
            records = self.storage_handler.load(tables=[table]).get(table, [])

            if not records:
                logger.warning(f"No records found in table {table}")
                return

            for record in records:
                parsed = self.storage_handler.parse_result(record, IndexStore)
                cid = parsed["corpus_id"]
                idx_type = parsed["metadata"]["index_type"]
                if (corpus_id and corpus_id != cid) or (index_type and index_type != idx_type):
                    continue

                # Rebuild chunk objects from the serialized corpus content.
                chunks = []
                for chunk_data in parsed["content"]["chunks"]:
                    metadata = ChunkMetadata.model_validate(chunk_data["metadata"])

                    if self.config.modality == "multimodal":
                        chunk = ImageChunk(
                            chunk_id=chunk_data["chunk_id"],
                            image_path=chunk_data["image_path"],
                            image_mimetype=chunk_data.get("image_mimetype"),
                            metadata=metadata,
                            embedding=chunk_data["embedding"],
                            excluded_embed_metadata_keys=chunk_data["excluded_embed_metadata_keys"],
                            excluded_llm_metadata_keys=chunk_data["excluded_llm_metadata_keys"],
                            relationships={k: RelatedNodeInfo(**v) for k, v in chunk_data["relationships"].items()}
                        )
                    else:
                        chunk = TextChunk(
                            chunk_id=chunk_data["chunk_id"],
                            text=chunk_data["text"],
                            metadata=metadata,
                            embedding=chunk_data["embedding"],
                            start_char_idx=chunk_data["start_char_idx"],
                            end_char_idx=chunk_data["end_char_idx"],
                            excluded_embed_metadata_keys=chunk_data["excluded_embed_metadata_keys"],
                            excluded_llm_metadata_keys=chunk_data["excluded_llm_metadata_keys"],
                            relationships={k: RelatedNodeInfo(**v) for k, v in chunk_data["relationships"].items()}
                        )
                    chunks.append(chunk)

                corpus = Corpus(
                    chunks=chunks,
                    corpus_id=cid,
                    metadata=IndexMetadata.model_validate(parsed["metadata"])
                )

                # `metadata` is rebound here: per-chunk metadata above, index metadata below.
                metadata = IndexMetadata.model_validate(parsed["metadata"])
                if not self.embed_model.validate_model(self.config.embedding.provider, metadata.embedding_model_name):
                    raise ValueError(
                        f"Embedding model '{metadata.embedding_model_name}' is not supported by provider '{self.config.embedding.provider}'. "
                        f"Supported models: {EmbeddingProvider.SUPPORTED_MODELS.get(self.config.embedding.provider, [])}"
                    )

                if metadata.dimension != config_dimension:
                    raise ValueError(
                        f"Embedding dimension {metadata.dimension} in metadata does not match configured dimension {config_dimension}."
                    )

                # NOTE(review): as in the file branch, the recreated model still comes from self.config.
                if metadata.embedding_model_name != self.config.embedding.model_name:
                    logger.info(f"Reinitializing embedding model to {metadata.embedding_model_name}")
                    self.embed_model = self.embedding_factory.create(
                        provider=self.config.embedding.provider,
                        model_config=self.config.embedding.model_dump(exclude_unset=True)
                    )

                chunk_ids = self._load_index(corpus, cid, idx_type)
                loaded_chunk_ids.extend(chunk_ids)
                logger.info(f"Loaded {idx_type} index with {len(corpus.chunks)} chunks for corpus {cid} from database table {table}")

        return loaded_chunk_ids
    except Exception as e:
        logger.error(f"Failed to load indices: {str(e)}")
        raise
|
|
|
|
|
def _load_index(self, corpus: Corpus, corpus_id: str, index_type: str) -> Sequence[str]:
    """Helper method to load an index and its retriever.

    Creates the (corpus_id, index_type) index/retriever pair if absent, tags each
    node with corpus/index metadata, and bulk-loads the corpus nodes into the index.

    Args:
        corpus (Corpus): Chunks to load.
        corpus_id (str): Identifier of the corpus being loaded.
        index_type (str): Type of the index (e.g., VECTOR, GRAPH).

    Returns:
        Sequence[str]: Ids of the loaded chunks.

    Raises:
        Exception: If index creation or node loading fails.
    """
    try:
        if corpus_id not in self.indices:
            self.indices[corpus_id] = {}
            self.retrievers[corpus_id] = {}

        if index_type not in self.indices[corpus_id]:
            index = self.index_factory.create(
                index_type=index_type,
                embed_model=self.embed_model.get_embedding_model(),
                storage_handler=self.storage_handler,
                index_config=self.config.index.model_dump(exclude_unset=True) if self.config.index else {},
                llm=self.llm
            )
            self.indices[corpus_id][index_type] = index
            # Graph indices pair with a graph retriever; everything else uses vector retrieval.
            retriever_type = RetrieverType.GRAPH if index_type == IndexType.GRAPH else RetrieverType.VECTOR
            # NOTE(review): unlike add(), no chunk_class is passed to the factory here — confirm intended.
            self.retrievers[corpus_id][index_type] = self.retriever_factory.create(
                retriever_type=retriever_type,
                llm=self.llm,
                index=index.get_index(),
                graph_store=index.get_index().storage_context.graph_store,
                embed_model=self.embed_model.get_embedding_model(),
                query=Query(query_str="", top_k=self.config.retrieval.top_k if self.config.retrieval else 5),
                storage_handler=self.storage_handler
            )

        nodes = corpus.to_llama_nodes()

        # Tag nodes so later metadata filtering can scope by corpus/index.
        for node in nodes:
            node.metadata.update({"corpus_id": corpus_id, "index_type": index_type})

        chunk_ids = self.indices[corpus_id][index_type].load(nodes)
        logger.info(f"Inserted {len(nodes)} nodes into {index_type} index for corpus {corpus_id}")
        return chunk_ids
    except Exception as e:
        logger.error(f"Failed to load index for corpus {corpus_id}, index_type {index_type}: {str(e)}")
        raise
|
|
|
|
|
async def aget(self, corpus_id: str, index_type: str, node_ids: List[str]) -> List[Union[TextChunk, ImageChunk]]:
    """Fetch the chunks stored under ``node_ids`` from one index.

    Returns an empty list (and logs the error) when the lookup fails for any reason,
    including an unknown corpus or index type.
    """
    try:
        wrapper = self.indices[corpus_id][index_type]
        fetched = await wrapper.get(node_ids=node_ids)
        logger.info(f"Retrieved {len(fetched)} chunks for node_ids: {node_ids}")
        return fetched
    except Exception as e:
        logger.error(f"Failed to get chunks: {str(e)}")
        return []
|
|
|
|
|
async def query_async(self, query: Union[str, Query], corpus_id: Optional[str] = None,
                      query_transforms: Optional[List] = None) -> RagResult:
    """Execute a query across indices and return processed results asynchronously.

    Performs query preprocessing, concurrent retrieval over all matching
    (corpus, index_type) retrievers, and post-processing.

    Args:
        query (Union[str, Query]): Query string or Query object.
        corpus_id (Optional[str]): Specific corpus to query. If None, queries all corpora.
        query_transforms (Optional[List]): Callables applied in order to rewrite/augment the query.

    Returns:
        RagResult: Retrieved chunks with scores and metadata.

    Raises:
        Exception: If query processing fails.
    """
    try:
        if isinstance(query, str):
            query = Query(query_str=query, top_k=self.config.retrieval.top_k)

        if not self.indices or (corpus_id and corpus_id not in self.indices):
            logger.warning(f"No indices found for corpus {corpus_id or 'any'}")
            return RagResult(corpus=Corpus(chunks=[]), scores=[], metadata={"query": query.query_str})

        # Pre-processing: each transform receives and returns a Query.
        if query_transforms and query_transforms is not None:
            for t in query_transforms:
                query = t(query)

        results = []
        target_corpora = [corpus_id] if corpus_id else self.indices.keys()

        # Fan out one retrieval task per (corpus, index_type) retriever,
        # skipping retrievers excluded by an "index_type" metadata filter.
        tasks = []
        for cid in target_corpora:
            for idx_type, retriever in self.retrievers[cid].items():
                if query.metadata_filters and query.metadata_filters.get("index_type") and \
                   query.metadata_filters["index_type"] != idx_type:
                    continue

                task = retriever.aretrieve(
                    Query(
                        query_str=query.query_str,
                        top_k=query.top_k or self.config.retrieval.top_k,
                        similarity_cutoff=query.similarity_cutoff,
                        keyword_filters=query.keyword_filters,
                        metadata_filters=query.metadata_filters
                    )
                )
                tasks.append((task, cid, idx_type))

        # Run all retrievals concurrently; exceptions come back as values.
        retrieval_tasks = [task for task, _, _ in tasks]
        retrieval_results = await asyncio.gather(*retrieval_tasks, return_exceptions=True)

        # Failed retrievers are logged and skipped; successful ones are collected.
        for (_, cid, idx_type), result in zip(tasks, retrieval_results):
            if isinstance(result, Exception):
                logger.error(f"Retrieval failed for {idx_type} in corpus {cid}: {str(result)}")
            else:
                results.append(result)
                logger.info(f"Retrieved {len(result.corpus.chunks)} chunks from {idx_type} retriever for corpus {cid}")

        if not results:
            return RagResult(corpus=Corpus(chunks=[]), scores=[], metadata={"query": query.query_str})

        # Fill unset query knobs from the retrieval config before post-processing.
        query.similarity_cutoff = self.config.retrieval.similarity_cutoff if query.similarity_cutoff is None else query.similarity_cutoff
        query.keyword_filters = self.config.retrieval.keyword_filters if query.keyword_filters is None else query.keyword_filters

        postprocessor = self.postprocessor_factory.create(
            self.config.retrieval.postprocessor_type,
            query=query
        )
        final_result = postprocessor.postprocess(query, results)

        # Final exact-match metadata filtering over the merged result set.
        if query.metadata_filters:
            final_result.corpus.chunks = [
                chunk for chunk in final_result.corpus.chunks
                if all(chunk.metadata.model_dump().get(k) == v for k, v in query.metadata_filters.items())
            ]
            final_result.scores = [chunk.metadata.similarity_score for chunk in final_result.corpus.chunks]
            logger.info(f"Applied metadata filters, retained {len(final_result.corpus.chunks)} chunks")

        logger.info(f"Query returned {len(final_result.corpus.chunks)} chunks after post-processing")
        return final_result
    except Exception as e:
        logger.error(f"Query failed: {str(e)}")
        raise
|
|
|
|
|
def query(self, query: Union[str, Query], corpus_id: Optional[str] = None,
          query_transforms: Optional[List] = None) -> RagResult:
    """Run the async query pipeline to completion on a fresh event loop and return its result."""
    coro = self.query_async(query, corpus_id, query_transforms)
    return asyncio.run(coro)