import json
import hashlib
from uuid import uuid4
from typing import List, Dict, Optional, Union, Any

from pydantic import Field

from llama_index.core import Document as LlamaIndexDocument
from llama_index.core.schema import BaseNode, TextNode, ImageNode, QueryBundle, RelatedNodeInfo
from llama_index.core.graph_stores.types import (
    Relation,
    EntityNode,
    ChunkNode,
)

from evoagentx.core.base_config import BaseModule
from evoagentx.core.logging import logger

# Metadata keys hidden from the embedding model and the LLM by default.
DEFAULT_EXCLUDED = [
    'file_name', 'file_type', 'file_size', 'page_count', 'creation_date',
    'last_modified_date', 'language', 'word_count', 'custom_fields', 'hash_doc',
    'graph_node',
]


class DocumentMetadata(BaseModule):
    """
    This class ensures type safety and validation for metadata associated with a document,
    such as file information, creation date, and custom fields.
    """

    file_name: Optional[str] = Field(default=None, description="The name of the document file, excluding the path.")
    file_path: Optional[str] = Field(default=None, description="The file path or URL where the document is stored.")
    file_type: Optional[str] = Field(default=None, description="The type of the document (e.g., '.pdf', '.docx', '.md', '.txt').")
    file_size: Optional[int] = Field(default=None, description="The size of the document in bytes.")
    page_count: Optional[int] = Field(default=None, description="The number of pages in the document, if applicable (e.g., for PDFs).")
    creation_date: Optional[str] = Field(default=None, description="The creation date and time of the document.")
    last_modified_date: Optional[str] = Field(default=None, description="The last modified date and time of the document.")
    language: Optional[str] = Field(default=None, description="The primary language of the document (e.g., 'en', 'zh').")
    word_count: Optional[int] = Field(default=None, description="The number of words in the document, calculated during initialization.")
    custom_fields: Dict[str, Any] = Field(default_factory=dict, description="A dictionary for storing additional user-defined metadata.")
    hash_doc: Optional[str] = Field(default=None, description="The hash of the document text, used for deduplication.")


class GraphNodeData(BaseModule):
    """
    Holds the graph-specific payload of a chunk so that llama_index graph elements
    ('EntityNode', 'ChunkNode', 'Relation') can be round-tripped through ChunkMetadata.
    """

    # Fields shared by all graph elements.
    label: Optional[str] = Field(default="entity", description="The label of the llama_index graph element ('EntityNode', 'ChunkNode', or 'Relation').")
    node_class_name: Optional[str] = Field(default=None, description="The class name of the source llama_index node ('entity', 'chunk', or 'relation').")
    properties: Optional[Dict] = Field(default_factory=dict, description="All properties carried by the node.")

    # Fields for entity nodes and relations.
    node_name: Optional[str] = Field(default=None, description="Entity name of the node.")
    source_id: Optional[str] = Field(default=None, description="Source node ID of a relation.")
    target_id: Optional[str] = Field(default=None, description="Target node ID of a relation.")

    # Fields for chunk nodes.
    text: Optional[str] = Field(default=None, description="The text stored in the ChunkNode.")
    id_: Optional[str] = Field(default=None, description="ChunkNode ID.")


class ChunkMetadata(DocumentMetadata):
    """
    This class holds metadata for a chunk, including its relationship to the parent document,
    chunking parameters, and retrieval-related information.
    """

    doc_id: Optional[str] = Field(default=None, description="The unique identifier of the parent document.")
    corpus_id: Optional[str] = Field(default=None, description="The unique identifier of the corpus (index).")
    chunk_size: Optional[int] = Field(default=None, description="The size of the chunk in characters, if applicable.")
    chunk_overlap: Optional[int] = Field(default=None, description="The number of overlapping characters between adjacent chunks.")
    chunk_index: Optional[int] = Field(default=None, description="The index of the chunk within the parent document.")
    chunking_strategy: Optional[str] = Field(default=None, description="The strategy used to create the chunk (e.g., 'simple', 'semantic', 'tree').")
    similarity_score: Optional[float] = Field(default=None, description="Similarity score from retrieval.")

    graph_node: Optional[GraphNodeData] = Field(default=None, description="The properties of all types of graph nodes.")

    # Fields for chunks that originate from agent messages or memory entries.
    content: Optional[str] = Field(default=None, description="The content of the message, serialized with json.dumps.")
    memory_id: Optional[str] = Field(default=None, description="Unique identifier for memory entries.")
    agent: Optional[str] = Field(default=None, description="The sender of the message.")
    msg_type: Optional[str] = Field(default=None, description="The type of the message (e.g., 'request', 'response').")
    prompt: Optional[Union[str, List[dict]]] = Field(default=None, description="The prompt used to generate the message.")
    next_actions: Optional[List[str]] = Field(default=None, description="The actions that follow the message.")
    wf_task: Optional[str] = Field(default=None, description="The name of a task in the workflow.")
    wf_task_desc: Optional[str] = Field(default=None, description="The description of a task in the workflow.")
    message_id: Optional[str] = Field(default=None, description="Unique identifier for the message.")
    action: Optional[str] = Field(default=None, description="The trigger of the message, normally set to the action name.")
    wf_goal: Optional[str] = Field(default=None, description="The goal of the whole workflow.")
    timestamp: Optional[str] = Field(default=None, description="The timestamp of the message.")


class IndexMetadata(BaseModule):
    corpus_id: str = Field(..., description="Identifier for the corpus.")
    index_type: str = Field(..., description="Type of index (e.g., 'vector', 'graph', 'summary', 'tree').")
    collection_name: Optional[str] = Field(default="default_collection", description="Vector store collection name or FAISS file path.")
    dimension: Optional[int] = Field(default=1536, description="Vector dimension.")
    vector_db_type: Optional[str] = Field(default=None, description="Vector database type (e.g., 'faiss', 'qdrant', 'chroma').")
    graph_db_type: Optional[str] = Field(default=None, description="Graph database type (e.g., 'neo4j').")
    embedding_model_name: Optional[str] = Field(default=None, description="Name of the embedding model used to build the index.")
    date: Optional[str] = Field(default=None, description="Creation or last update date.")


class Document(BaseModule):
    """A custom document class for managing documents in the RAG pipeline.

    Attributes:
        text (str): The full content of the document.
        doc_id (str): Unique identifier for the document.
        metadata (DocumentMetadata): Metadata including file info, creation date, etc.
        source (str): Source of the document (e.g., file path or URL).
        llama_doc (LlamaIndexDocument): Underlying LlamaIndex Document object.
    """

    def __init__(
        self,
        text: str,
        metadata: Optional[Union[Dict, DocumentMetadata]] = None,
        embedding: Optional[List[float]] = None,
        doc_id: Optional[str] = None,
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata_template: str = '{key}: {value}',
        metadata_separator: str = '\n',
        text_template: str = '{metadata_str}\n\n{content}'
    ):
        metadata = (
            DocumentMetadata.model_validate(metadata) if isinstance(metadata, dict) else metadata or DocumentMetadata()
        )

        super().__init__(
            text=text.strip(),
            doc_id=doc_id or str(uuid4()),
            metadata=metadata,
            embedding=embedding,
            excluded_embed_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_embed_metadata_keys or []))),
            excluded_llm_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_llm_metadata_keys or []))),
            relationships=relationships or {},
            metadata_template=metadata_template,
            metadata_separator=metadata_separator,
            text_template=text_template,
        )
        self.metadata.word_count = len(self.text.split())

    def to_llama_document(self) -> LlamaIndexDocument:
        """Convert to LlamaIndex Document."""
        return LlamaIndexDocument(
            text=self.text,
            metadata=self.metadata.model_dump(),
            id_=self.doc_id,
            embedding=self.embedding,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            relationships=self.relationships,
            metadata_template=self.metadata_template,
            metadata_separator=self.metadata_separator,
            text_template=self.text_template,
        )

    @classmethod
    def from_llama_document(cls, llama_doc: LlamaIndexDocument) -> "Document":
        """Create Document from LlamaIndex Document."""
        metadata = DocumentMetadata.model_validate(llama_doc.metadata)
        return cls(
            text=llama_doc.text,
            metadata=metadata,
            doc_id=llama_doc.id_,
            embedding=llama_doc.embedding,
            excluded_llm_metadata_keys=llama_doc.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=llama_doc.excluded_embed_metadata_keys,
            relationships=llama_doc.relationships,
            metadata_template=llama_doc.metadata_template,
            metadata_separator=llama_doc.metadata_separator,
            text_template=llama_doc.text_template
        )

    def set_embedding(self, embedding: List[float]):
        """Set the embedding vector for the Document."""
        self.embedding = embedding

    def compute_hash(self) -> str:
        """Compute a SHA-256 hash of the document text for deduplication."""
        return hashlib.sha256(self.text.encode()).hexdigest()

    def get_fragment(self, max_length: int = 100) -> str:
        """Return a fragment of the document text."""
        return (self.text[:max_length] + "...") if len(self.text) > max_length else self.text

    def to_dict(self) -> Dict:
        """Convert document to dictionary for serialization."""
        return {
            "doc_id": self.doc_id,
            "text": self.text,
            "metadata": self.metadata.model_dump(),
            "embedding": self.embedding,
            "excluded_embed_metadata_keys": self.excluded_embed_metadata_keys,
            "excluded_llm_metadata_keys": self.excluded_llm_metadata_keys,
            "relationships": {str(k): v for k, v in self.relationships.items()},
            "metadata_template": self.metadata_template,
            "metadata_separator": self.metadata_separator,
            "text_template": self.text_template,
        }

    def to_json(self, indent: int = 2) -> str:
        """Convert document to JSON string."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def __str__(self) -> str:
        return (
            f"Document(doc_id={self.doc_id}, embedding={self.embedding}, metadata={self.metadata.model_dump()}, "
            f"fragment={self.get_fragment(max_length=300)})"
        )

    def __repr__(self) -> str:
        return self.__str__()


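# Illustrative sketch (not called anywhere in this module): round-trip a Document
# through LlamaIndex. The text and metadata values are made up for illustration.
def _example_document_roundtrip() -> None:
    doc = Document(text="EvoAgentX is an agent framework.", metadata={"file_type": ".txt"})
    llama_doc = doc.to_llama_document()      # LlamaIndexDocument with the same id and metadata
    restored = Document.from_llama_document(llama_doc)
    assert restored.doc_id == doc.doc_id
    assert restored.text == doc.text

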
class TextChunk(BaseModule):
    """A single chunk of a document for RAG processing.

    Attributes:
        text (str): The content of the chunk.
        doc_id (str): ID of the parent document.
        chunk_id (str): Unique identifier for the chunk.
        metadata (ChunkMetadata): Metadata including chunk size, embedding, etc.
        llama_node (BaseNode): Underlying LlamaIndex Node object.
    """

    def __init__(
        self,
        text: str = "",
        chunk_id: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        start_char_idx: Optional[int] = None,
        end_char_idx: Optional[int] = None,
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        text_template: str = '{metadata_str}\n\n{content}',
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata: Optional[Union[Dict, ChunkMetadata]] = None,
    ):
        metadata = (
            ChunkMetadata.model_validate(metadata) if isinstance(metadata, dict) else metadata or ChunkMetadata()
        )
        super().__init__(
            text=text.strip(),
            chunk_id=chunk_id or str(uuid4()),
            embedding=embedding,
            start_char_idx=start_char_idx,
            end_char_idx=end_char_idx,
            excluded_embed_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_embed_metadata_keys or []))),
            excluded_llm_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_llm_metadata_keys or []))),
            text_template=text_template,
            relationships=relationships or {},
            metadata=metadata,
        )
        self.metadata.word_count = len(self.text.split())

    def to_llama_node(self) -> Union[TextNode, Relation, EntityNode, ChunkNode]:
        """Convert to a LlamaIndex node (TextNode, or a graph element if graph metadata is set)."""
        relationships = dict()
        for k, v in self.relationships.items():
            relationships[k] = v if isinstance(v, RelatedNodeInfo) else RelatedNodeInfo.from_dict(v)

        if self.metadata.graph_node is not None:
            class_name = self.metadata.graph_node.node_class_name.lower()
            if class_name == "relation":
                return Relation(
                    label=self.metadata.graph_node.label,
                    source_id=self.metadata.graph_node.source_id,
                    target_id=self.metadata.graph_node.target_id,
                    properties={"metadata": json.dumps(self.metadata.graph_node.properties["metadata"])},
                )
            elif class_name == "entity":
                return EntityNode(
                    label=self.metadata.graph_node.label,
                    embedding=self.embedding,
                    name=self.metadata.graph_node.node_name,
                    properties={"triplet_source_id": self.metadata.graph_node.properties["triplet_source_id"]}
                )
            raise NotImplementedError(f"Unsupported graph node class: {class_name}")

        metadata = self.metadata.model_dump()
        metadata.pop("class_name", None)

        return TextNode(
            text=self.text,
            metadata=metadata,
            id_=self.chunk_id,
            embedding=self.embedding,
            start_char_idx=self.start_char_idx,
            end_char_idx=self.end_char_idx,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            text_template=self.text_template,
            relationships=relationships
        )

    @classmethod
    def from_llama_node(cls, node: Union[TextNode, Relation, EntityNode, ChunkNode]) -> "TextChunk":
        """Create a TextChunk from a LlamaIndex node."""
        if isinstance(node, TextNode):
            return cls(
                chunk_id=node.id_,
                text=node.text,
                metadata=ChunkMetadata.model_validate(node.metadata),
                embedding=node.embedding,
                start_char_idx=getattr(node, "start_char_idx", None),
                end_char_idx=getattr(node, "end_char_idx", None),
                excluded_embed_metadata_keys=node.excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=node.excluded_llm_metadata_keys,
                text_template=node.text_template,
                relationships=node.relationships
            )
        elif isinstance(node, Relation):
            if 'class_name' in node.properties:
                node.properties.pop('class_name')
            properties = node.properties if isinstance(node.properties, dict) else node.properties.model_dump()
            graph_node = GraphNodeData(
                node_class_name="relation",
                label=node.label,
                source_id=node.source_id,
                target_id=node.target_id,
                properties={"metadata": properties}
            )
            metadata = {"graph_node": graph_node}
            return cls(
                metadata=ChunkMetadata.model_validate(metadata)
            )
        elif isinstance(node, EntityNode):
            graph_node = GraphNodeData(
                node_class_name="entity",
                label=node.label,
                node_name=node.name,
                properties={"triplet_source_id": node.properties["triplet_source_id"]}
            )
            metadata = {"graph_node": graph_node}
            return cls(
                embedding=node.embedding,
                metadata=ChunkMetadata.model_validate(metadata)
            )
        elif isinstance(node, ChunkNode):
            graph_node = GraphNodeData(
                node_class_name="chunk",
                text=node.text,
                properties=node.properties,
                id_=node.id_,
            )
            metadata = {"graph_node": graph_node}
            return cls(
                embedding=node.embedding,
                metadata=ChunkMetadata.model_validate(metadata)
            )
        raise NotImplementedError(f"Unsupported node type: {type(node).__name__}")

    def get_fragment(self, max_length: int = 100) -> str:
        """Return a fragment of the chunk text."""
        return (self.text[:max_length] + "...") if len(self.text) > max_length else self.text

    def to_dict(self) -> Dict:
        """Convert chunk to dictionary for serialization."""
        # Normalize relationships in place so RelatedNodeInfo objects serialize as plain dicts.
        relationships = dict()
        for k, v in self.relationships.items():
            relationships[k] = v.to_dict() if isinstance(v, RelatedNodeInfo) else v
        self.relationships = relationships

        return self.model_dump()

    def to_json(self, indent: Optional[int] = 2) -> str:
        """Convert chunk to JSON string."""
        return self.model_dump_json(indent=indent).strip()

    def __str__(self) -> str:
        return (
            f"Chunk(id={self.chunk_id}, text={self.text}, "
            f"chunking_strategy={self.metadata.chunking_strategy}, "
            f"embedding={self.embedding}, "
            f"start_char_idx={self.start_char_idx}, "
            f"end_char_idx={self.end_char_idx}, "
            f"excluded_embed_metadata_keys={self.excluded_embed_metadata_keys}, "
            f"excluded_llm_metadata_keys={self.excluded_llm_metadata_keys}, "
            f"text_template={self.text_template}, "
            f"metadata={self.metadata.model_dump()})"
        )

    def __repr__(self) -> str:
        return self.__str__()


# Backwards-compatible alias: `Chunk` and `TextChunk` refer to the same class.
Chunk = TextChunk


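# Illustrative sketch (not called anywhere in this module): a plain TextChunk with no
# graph_node metadata converts to a llama_index TextNode, preserving its id. The text
# and metadata values below are made up for illustration.
def _example_chunk_to_node() -> None:
    chunk = TextChunk(text="Paris is the capital of France.", metadata={"doc_id": "doc-1", "chunk_index": 0})
    node = chunk.to_llama_node()
    assert node.id_ == chunk.chunk_id
    restored = TextChunk.from_llama_node(node)
    assert restored.text == chunk.text

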
class ImageChunk(BaseModule):
    """An image-based chunk with lazy loading.

    Attributes:
        image_path (str): Path to the image file.
        image_mimetype (Optional[str]): MIME type of the image.
        chunk_id (str): Unique identifier for the chunk.
        metadata (ChunkMetadata): Metadata including embedding, similarity scores, etc.
    """

    def __init__(
        self,
        image_path: str,
        image_mimetype: Optional[str] = None,
        chunk_id: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        text_template: str = '{metadata_str}\n\n{content}',
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata: Optional[Union[Dict, ChunkMetadata]] = None,
    ):
        metadata = (
            ChunkMetadata.model_validate(metadata) if isinstance(metadata, dict) else metadata or ChunkMetadata()
        )
        super().__init__(
            image_path=image_path,
            image_mimetype=image_mimetype,
            chunk_id=chunk_id or str(uuid4()),
            embedding=embedding,
            excluded_embed_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_embed_metadata_keys or []))),
            excluded_llm_metadata_keys=list(set(DEFAULT_EXCLUDED + (excluded_llm_metadata_keys or []))),
            text_template=text_template,
            relationships=relationships or {},
            metadata=metadata,
        )
        self._cached_image = None

    def get_image(self):
        """Load the PIL Image on demand and cache it; return None on failure."""
        if self._cached_image is None:
            from PIL import Image
            try:
                logger.debug(f"Loading image from path: {self.image_path}")
                if not self.image_path:
                    logger.error("Image path is None or empty!")
                    return None
                self._cached_image = Image.open(self.image_path)
                logger.debug(f"Successfully loaded image from {self.image_path}")
            except Exception as e:
                logger.error(f"Failed to load image from {self.image_path}: {str(e)}")
                return None
        return self._cached_image

    def get_image_bytes(self, format: str = "PNG") -> Optional[bytes]:
        """Get the image as encoded bytes for embedding or processing."""
        import io
        image = self.get_image()
        if image is None:
            return None

        img_bytes = io.BytesIO()
        image.save(img_bytes, format=format)
        return img_bytes.getvalue()

    def to_llama_node(self) -> ImageNode:
        """Convert to LlamaIndex ImageNode with on-demand image loading."""
        relationships = dict()
        for k, v in self.relationships.items():
            relationships[k] = v if isinstance(v, RelatedNodeInfo) else RelatedNodeInfo.from_dict(v)

        return ImageNode(
            image=None,
            image_path=self.image_path,
            image_mimetype=self.image_mimetype,
            metadata=self.metadata.model_dump(),
            id_=self.chunk_id,
            embedding=self.embedding,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            text_template=self.text_template,
            relationships=relationships
        )

    @classmethod
    def from_llama_node(cls, node: ImageNode) -> "ImageChunk":
        """Create ImageChunk from LlamaIndex ImageNode."""
        metadata = ChunkMetadata.model_validate(node.metadata)

        logger.debug(f"Creating ImageChunk from ImageNode - image_path: {node.image_path}")

        return cls(
            chunk_id=node.id_,
            image_path=node.image_path,
            image_mimetype=node.image_mimetype,
            metadata=metadata,
            embedding=node.embedding,
            excluded_embed_metadata_keys=node.excluded_embed_metadata_keys,
            excluded_llm_metadata_keys=node.excluded_llm_metadata_keys,
            text_template=node.text_template,
            relationships=node.relationships
        )


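# Illustrative sketch (not called anywhere in this module; the path is hypothetical):
# ImageChunk loads its PIL image lazily, so construction is cheap and the file is only
# read when get_image() or get_image_bytes() is called.
def _example_image_chunk() -> None:
    chunk = ImageChunk(image_path="figures/diagram.png", image_mimetype="image/png")
    image = chunk.get_image()        # PIL.Image.Image, or None if loading failed
    data = chunk.get_image_bytes()   # PNG-encoded bytes, or None if loading failed
    if image is not None and data is not None:
        logger.debug(f"Loaded {len(data)} bytes, size={image.size}")

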
class Corpus(BaseModule):
    """A generic collection of document chunks for RAG processing.

    Attributes:
        corpus_id (str): The unique id for the corpus.
        chunks (List[Union[TextChunk, ImageChunk]]): List of chunks in the corpus.
        chunk_index (Dict[str, Union[TextChunk, ImageChunk]]): Index of chunks by chunk_id for fast lookup.
        metadata (Optional[IndexMetadata]): The metadata for this corpus.
    """

    def __init__(self, chunks: Optional[List[Union[TextChunk, ImageChunk]]] = None, corpus_id: Optional[str] = None,
                 metadata: Optional[Union[IndexMetadata, Dict]] = None):
        corpus_id = str(uuid4()) if corpus_id is None else corpus_id
        chunks = [] if chunks is None else chunks
        chunk_index = {chunk.chunk_id: chunk for chunk in chunks}

        if metadata is None:
            metadata = {}
        elif isinstance(metadata, IndexMetadata):
            metadata = metadata.model_dump()
        super().__init__(
            corpus_id=corpus_id,
            chunks=chunks,
            chunk_index=chunk_index,
            metadata=metadata
        )

    def to_llama_nodes(self) -> List[BaseNode]:
        """Convert to a list of LlamaIndex nodes."""
        return [chunk.to_llama_node() for chunk in (self.chunks or [])]

    @classmethod
    def from_llama_nodes(cls, nodes: List[BaseNode]) -> "Corpus":
        """Create a Corpus from a list of LlamaIndex Nodes.

        Args:
            nodes (List[BaseNode]): The LlamaIndex Nodes to convert.

        Returns:
            Corpus: A new Corpus instance.
        """
        chunks = []
        for node in nodes:
            if isinstance(node, ImageNode):
                chunks.append(ImageChunk.from_llama_node(node))
            else:
                chunks.append(TextChunk.from_llama_node(node))
        return cls(chunks)

    def add_chunk(self, batch_chunk: Union[TextChunk, ImageChunk, List[Union[TextChunk, ImageChunk]]]):
        """Add a chunk (or a batch of chunks) to the corpus and update the index."""
        if not isinstance(batch_chunk, list):
            batch_chunk = [batch_chunk]

        for chunk in batch_chunk:
            self.chunks.append(chunk)
            self.chunk_index[chunk.chunk_id] = chunk

    def get_chunk(self, chunk_id: str) -> Optional[Union[TextChunk, ImageChunk]]:
        """Retrieve a chunk by its ID."""
        return self.chunk_index.get(chunk_id)

    def remove_chunk(self, chunk_id: str):
        """Remove a chunk by its ID."""
        self.chunks = [chunk for chunk in self.chunks if chunk.chunk_id != chunk_id]
        self.chunk_index.pop(chunk_id, None)

    def filter_by_doc_id(self, doc_id: str) -> List[Union[TextChunk, ImageChunk]]:
        """Filter chunks by parent document ID."""
        return [chunk for chunk in self.chunks if hasattr(chunk.metadata, 'doc_id') and chunk.metadata.doc_id == doc_id]

    def filter_by_similarity(self, threshold: float) -> List[Union[TextChunk, ImageChunk]]:
        """Filter chunks whose similarity score is at least the given threshold."""
        return [chunk for chunk in self.chunks if chunk.metadata.similarity_score is not None and chunk.metadata.similarity_score >= threshold]

    def sort_by_similarity(self, reverse: bool = True) -> List[Union[TextChunk, ImageChunk]]:
        """Sort chunks by similarity score (descending by default)."""
        return sorted(
            [chunk for chunk in self.chunks if chunk.metadata.similarity_score is not None],
            key=lambda x: x.metadata.similarity_score,
            reverse=reverse
        )

    def to_dict(self, round_trip=False) -> Dict:
        """Convert corpus to dictionary for serialization."""
        return self.model_dump(round_trip=round_trip)

    def to_json(self, indent: int = 2, round_trip=True) -> str:
        """Convert corpus to JSON string."""
        return json.dumps(self.to_dict(round_trip), indent=indent, ensure_ascii=False)

    def to_jsonl(self, output_path: str):
        """Write the corpus to a JSONL file, one chunk per line."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for chunk in self.chunks:
                json_str = chunk.to_json(indent=None)
                if '\n' in json_str:
                    logger.warning(f"Chunk {chunk.chunk_id} contains newlines in JSON, which may break JSONL format.")
                f.write(json_str + '\n')

    @classmethod
    def from_jsonl(cls, input_path: str, corpus_id: Optional[str] = None) -> "Corpus":
        """Load a corpus from a JSONL file, one chunk per line."""
        chunks = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                chunk_dict = json.loads(line.strip())
                metadata = ChunkMetadata.model_validate(chunk_dict["metadata"])
                chunk = Chunk(
                    chunk_id=chunk_dict["chunk_id"],
                    text=chunk_dict["text"],
                    metadata=metadata,
                    embedding=chunk_dict["embedding"],
                    start_char_idx=chunk_dict["start_char_idx"],
                    end_char_idx=chunk_dict["end_char_idx"],
                    excluded_embed_metadata_keys=chunk_dict["excluded_embed_metadata_keys"],
                    excluded_llm_metadata_keys=chunk_dict["excluded_llm_metadata_keys"],
                    relationships={
                        k: RelatedNodeInfo(**v) for k, v in chunk_dict["relationships"].items()
                    }
                )
                chunks.append(chunk)
        return cls(chunks=chunks, corpus_id=corpus_id)

    def __str__(self) -> str:
        stats = self.get_stats()
        return (
            f"Corpus(chunks={stats['chunk_count']}, unique_docs={stats['unique_docs']}, "
            f"avg_word_count={stats['avg_word_count']:.1f}, strategies={stats['strategies']})"
        )

    def __repr__(self) -> str:
        return f"Corpus(chunks={len(self.chunks)}, chunk_index_keys={list(self.chunk_index.keys())})"

    def __len__(self) -> int:
        return len(self.chunks)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the corpus."""
        if not self.chunks:
            return {
                'chunk_count': 0,
                'unique_docs': 0,
                'avg_word_count': 0.0,
                'strategies': set()
            }

        unique_docs = set()
        total_word_count = 0
        strategies = set()

        for chunk in self.chunks:
            if hasattr(chunk.metadata, 'doc_id') and chunk.metadata.doc_id:
                unique_docs.add(chunk.metadata.doc_id)
            if hasattr(chunk.metadata, 'word_count') and chunk.metadata.word_count:
                total_word_count += chunk.metadata.word_count
            if hasattr(chunk.metadata, 'chunking_strategy') and chunk.metadata.chunking_strategy:
                strategies.add(chunk.metadata.chunking_strategy)

        avg_word_count = total_word_count / len(self.chunks)

        return {
            'chunk_count': len(self.chunks),
            'unique_docs': len(unique_docs),
            'avg_word_count': avg_word_count,
            'strategies': strategies
        }


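# Illustrative sketch (not called anywhere in this module; "corpus.jsonl" is a
# hypothetical path): build a small corpus, persist it as JSONL, and reload it.
def _example_corpus_roundtrip() -> None:
    corpus = Corpus(chunks=[TextChunk(text="chunk one"), TextChunk(text="chunk two")])
    corpus.to_jsonl("corpus.jsonl")  # one JSON object per line
    restored = Corpus.from_jsonl("corpus.jsonl", corpus_id=corpus.corpus_id)
    assert len(restored) == len(corpus)

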
class Query(BaseModule):
    """Represents a retrieval query."""

    query_str: str = Field(description="The query string.")
    top_k: Optional[int] = Field(default=None, description="Number of top results to retrieve.")
    custom_embedding_strs: Optional[List[str]] = Field(default=None, description="Additional strings to embed alongside the query.")
    similarity_cutoff: Optional[float] = Field(default=None, description="Minimum similarity score.")
    keyword_filters: Optional[List[str]] = Field(default=None, description="Keywords to filter results.")
    metadata_filters: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata filters.")

    @property
    def embedding_strs(self) -> List[str]:
        """Use custom embedding strings if specified, otherwise use the query string."""
        if self.custom_embedding_strs is None:
            if len(self.query_str) == 0:
                return []
            return [self.query_str]
        else:
            return self.custom_embedding_strs

    def to_QueryBundle(self):
        return QueryBundle(
            query_str=self.query_str,
            custom_embedding_strs=self.custom_embedding_strs
        )


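# Illustrative sketch (not called anywhere in this module): a Query converts to a
# llama_index QueryBundle; embedding_strs falls back to the query string itself.
def _example_query_bundle() -> None:
    query = Query(query_str="What is EvoAgentX?", top_k=5)
    assert query.embedding_strs == ["What is EvoAgentX?"]
    bundle = query.to_QueryBundle()
    assert bundle.query_str == query.query_str

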
class RagResult(BaseModule):
    """Represents a generic retrieval result."""

    corpus: Corpus = Field(description="Retrieved chunks.")
    scores: List[float] = Field(description="Similarity scores for each chunk.")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional result metadata.")

    def get_top_chunks(self, limit: Optional[int] = None) -> List[Union[TextChunk, ImageChunk]]:
        """Get top chunks sorted by similarity score."""
        chunks = self.corpus.sort_by_similarity(reverse=True)
        return chunks[:limit] if limit else chunks

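
# Illustrative sketch (not called anywhere in this module; the score is made up):
# only chunks with a similarity_score set are returned by get_top_chunks.
def _example_top_chunks() -> None:
    chunk = TextChunk(text="relevant passage", metadata={"similarity_score": 0.9})
    result = RagResult(corpus=Corpus(chunks=[chunk]), scores=[0.9])
    top = result.get_top_chunks(limit=1)
    assert top[0].metadata.similarity_score == 0.9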