iLOVE2D's picture
Upload 2846 files
5374a2d verified
import json
import hashlib
from uuid import uuid4
from typing import List, Dict, Optional, Union, Any
from pydantic import Field
from llama_index.core.schema import ImageNode
from llama_index.core.schema import QueryBundle
from llama_index.core import Document as LlamaIndexDocument
from llama_index.core.schema import BaseNode, TextNode, RelatedNodeInfo
from llama_index.core.graph_stores.types import (
Relation,
EntityNode,
ChunkNode,
)
from evoagentx.core.base_config import BaseModule
from evoagentx.core.logging import logger
# Metadata keys that are always added to the ``excluded_embed_metadata_keys`` and
# ``excluded_llm_metadata_keys`` lists of the Document / chunk classes in this file.
# NOTE: the identifier keeps its existing spelling because the rest of the file refers to it.
DEAFULT_EXCLUDED = [
    "file_name",
    "file_type",
    "file_size",
    "page_count",
    "creation_date",
    "last_modified_date",
    "language",
    "word_count",
    "custom_fields",
    "hash_doc",
    "graph_node",  # for graph store
]
class DocumentMetadata(BaseModule):
    """Typed metadata record attached to a document.

    Every field is optional; it covers file information, dates, language,
    word count, free-form custom fields and a deduplication hash.
    """

    # --- file information ---
    file_name: Optional[str] = Field(
        default=None,
        description="The name of the document file, excluding the path.",
    )
    file_path: Optional[str] = Field(
        default=None,
        description="The file path or URL where the document is stored.",
    )
    file_type: Optional[str] = Field(
        default=None,
        description="The type of the document (e.g., '.pdf', '.docx', '.md', '.txt').",
    )
    file_size: Optional[int] = Field(
        default=None,
        description="The size of the document.",
    )
    page_count: Optional[int] = Field(
        default=None,
        description="The number of pages in the document, if applicable (e.g., for PDFs).",
    )

    # --- dates ---
    creation_date: Optional[str] = Field(
        default=None,
        description="The creation date and time of the document.",
    )
    last_modified_date: Optional[str] = Field(
        default=None,
        description="The last modified date and time of the document.",
    )

    # --- content-derived information ---
    language: Optional[str] = Field(
        default=None,
        description="The primary language of the document (e.g., 'en', 'zh').",
    )
    word_count: Optional[int] = Field(
        default=None,
        description="The number of words in the document, calculated during initialization.",
    )

    # --- user extensions and deduplication ---
    custom_fields: Dict[str, Any] = Field(
        default_factory=dict,
        description="A dictionary for storing additional user-defined metadata.",
    )
    hash_doc: Optional[str] = Field(
        default=None,
        description="The hash code of this Document for deduplication",
    )
class GraphNodeData(BaseModule):
    """Flat container for the data of a llama_index graph node.

    ``TextChunk.from_llama_node`` fills it from a ``Relation``, ``EntityNode`` or
    ``ChunkNode``; ``node_class_name`` records which of those the data came from.
    """

    # Shared: label of the llama_index node.
    label: Optional[str] = Field(
        default="entity",
        description="The label name of the 'LabelNode', 'EntityNode', 'Relation' in llama_index node.",
    )

    # Shared: discriminator ("relation" / "entity" / "chunk") and the node's properties.
    node_class_name: Optional[str] = Field(
        default=None,
        description="The class name of the source llama_index node.",
    )
    properties: Optional[Dict] = Field(
        default_factory=dict,
        description="Represents all information from the Node.",
    )

    # EntityNode: the entity name.
    node_name: Optional[str] = Field(
        default=None,
        description="Entity name of each node.",
    )

    # Relation: the two endpoints.
    source_id: Optional[str] = Field(
        default=None,
        description="Source node ID.",
    )
    target_id: Optional[str] = Field(
        default=None,
        description="Target node ID.",
    )

    # ChunkNode: text and identifier.
    text: Optional[str] = Field(
        default=None,
        description="The text stored in the ChunkNode.",
    )
    id_: Optional[str] = Field(
        default=None,
        description="ChunkNode id.",
    )
class ChunkMetadata(DocumentMetadata):
    """Metadata for a single chunk.

    Extends ``DocumentMetadata`` with the link to the parent document and corpus,
    the chunking parameters, the retrieval score, optional graph-node data and
    the message fields used for long-term memory entries.
    """

    # --- parent document / corpus ---
    doc_id: Optional[str] = Field(
        default=None,
        description="The unique identifier of the parent document.",
    )
    corpus_id: Optional[str] = Field(
        default=None,
        description="The unique identifier of the Corpus(Indexing).",
    )

    # --- chunking parameters ---
    chunk_size: Optional[int] = Field(
        default=None,
        description="The size of the chunk in characters, if applicable.",
    )
    chunk_overlap: Optional[int] = Field(
        default=None,
        description="The number of overlapping characters between adjacent chunks.",
    )
    chunk_index: Optional[int] = Field(
        default=None,
        description="The index of the chunk within the parent document.",
    )
    chunking_strategy: Optional[str] = Field(
        default=None,
        description="The strategy used to create the chunk (e.g., 'simple', 'semantic', 'tree').",
    )

    # --- retrieval ---
    similarity_score: Optional[float] = Field(
        default=None,
        description="Similarity score from retrieval.",
    )

    # --- graph support ---
    graph_node: Optional[GraphNodeData] = Field(
        default=None,
        description="The properties of all types of graph nodes.",
    )

    # --- LongTermMemory ---
    content: Optional[str] = Field(
        default=None,
        description="the content of the message, will be dumps by 'dumps' from json lib.",
    )
    memory_id: Optional[str] = Field(
        default=None,
        description="Unique identifier for memory entries.",
    )
    agent: Optional[str] = Field(
        default=None,
        description="The sender of the message.",
    )
    msg_type: Optional[str] = Field(
        default=None,
        description="The type of the message (e.g., 'request', 'response').",
    )
    prompt: Optional[Union[str, List[dict]]] = Field(
        default=None,
        description="The prompt used to generate the message.",
    )
    next_actions: Optional[List[str]] = Field(
        default=None,
        description="The following actions after the message.",
    )
    wf_task: Optional[str] = Field(
        default=None,
        description="The name of a task in the workflow.",
    )
    wf_task_desc: Optional[str] = Field(
        default=None,
        description="The description of a task in the workflow.",
    )
    message_id: Optional[str] = Field(
        default=None,
        description="Unique identifier for the message.",
    )
    action: Optional[str] = Field(
        default=None,
        description="the trigger of the message, normally set as the action name.",
    )
    wf_goal: Optional[str] = Field(
        default=None,
        description="the goal of the whole workflow.",
    )
    timestamp: Optional[str] = Field(
        default=None,
        description="the timestame of the message. ",
    )
class IndexMetadata(BaseModule):
    """Describes one index built over a corpus.

    ``corpus_id`` and ``index_type`` are required; the remaining fields name the
    storage backend, the embedding setup and a date, and all have defaults.
    """

    # --- required identification ---
    corpus_id: str = Field(
        ...,
        description="Identifier for the corpus",
    )
    index_type: str = Field(
        ...,
        description="Type of index (e.g., 'vector', 'graph', 'summary', 'tree')",
    )

    # --- vector storage ---
    collection_name: Optional[str] = Field(
        default="default_collection",
        description="Vector store collection name or FAISS file path",
    )
    dimension: Optional[int] = Field(
        default=1536,
        description="Vector dimension",
    )
    vector_db_type: Optional[str] = Field(
        default=None,
        description="Vector database type (e.g., 'faiss', 'qdrant', 'chroma')",
    )

    # --- graph storage ---
    graph_db_type: Optional[str] = Field(
        default=None,
        description="Graph database type (e.g., 'neo4j')",
    )

    # --- embedding model and bookkeeping ---
    embedding_model_name: Optional[str] = Field(
        default=None,
        description="",
    )
    date: Optional[str] = Field(
        default=None,
        description="Creation or last update date",
    )
class Document(BaseModule):
    """A custom document class for managing documents in the RAG pipeline.

    All constructor arguments are forwarded to ``BaseModule.__init__`` as keyword
    arguments and are read back as attributes of the same name.

    Attributes:
        text (str): The full content of the document (whitespace-stripped).
        doc_id (str): Unique identifier for the document (a UUID4 string when not given).
        metadata (DocumentMetadata): Metadata including file info, creation date, etc.
        embedding (Optional[List[float]]): Embedding vector, if one has been set.
        excluded_embed_metadata_keys (List[str]): ``DEAFULT_EXCLUDED`` plus caller-supplied keys.
        excluded_llm_metadata_keys (List[str]): ``DEAFULT_EXCLUDED`` plus caller-supplied keys.
        relationships (Dict[str, RelatedNodeInfo]): Relationships to other nodes.
        metadata_template (str), metadata_separator (str), text_template (str):
            Formatting templates passed through to the LlamaIndex document.
    """

    def __init__(
        self,
        text: str,
        metadata: Optional[Union[Dict, DocumentMetadata]] = None,
        embedding: Optional[List[float]] = None,
        doc_id: Optional[str] = None,
        excluded_embed_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        excluded_llm_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata_template: str = '{key}: {value}',
        metadata_separator: str = '\n',
        text_template: str = '{metadata_str}\n\n{content}'
    ):
        """Build a document.

        Args:
            text: Document content; leading/trailing whitespace is stripped.
            metadata: A ``DocumentMetadata`` instance, a dict validated into one,
                or ``None`` for empty metadata.
            embedding: Optional embedding vector.
            doc_id: Optional identifier; a UUID4 string is generated when falsy.
            excluded_embed_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            excluded_llm_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            relationships: Optional relationships; ``None`` yields a fresh empty dict.
            metadata_template: Template for one metadata entry.
            metadata_separator: Separator between metadata entries.
            text_template: Template combining metadata and content.
        """
        if isinstance(metadata, dict):
            metadata = DocumentMetadata.model_validate(metadata)
        elif not metadata:
            metadata = DocumentMetadata()

        super().__init__(
            text=text.strip(),
            doc_id=doc_id or str(uuid4()),
            metadata=metadata,
            embedding=embedding,
            # dict.fromkeys de-duplicates while keeping a stable order
            # (a plain set would reorder the keys from run to run).
            excluded_embed_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_embed_metadata_keys))
            ),
            excluded_llm_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_llm_metadata_keys))
            ),
            # A new dict per instance: a `{}` default in the signature would be
            # one object shared by every Document created without relationships.
            relationships={} if relationships is None else relationships,
            metadata_template=metadata_template,
            metadata_separator=metadata_separator,
            text_template=text_template,
        )
        # Whitespace-separated token count of the stripped text.
        self.metadata.word_count = len(self.text.split())

    def to_llama_document(self) -> LlamaIndexDocument:
        """Convert to LlamaIndex Document."""
        return LlamaIndexDocument(
            text=self.text,
            metadata=self.metadata.model_dump(),
            id_=self.doc_id,
            embedding=self.embedding,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            relationships=self.relationships,
            metadata_template=self.metadata_template,
            metadata_separator=self.metadata_separator,
            text_template=self.text_template,
        )

    @classmethod
    def from_llama_document(cls, llama_doc: LlamaIndexDocument) -> "Document":
        """Create Document from LlamaIndex Document."""
        return cls(
            text=llama_doc.text,
            metadata=DocumentMetadata.model_validate(llama_doc.metadata),
            doc_id=llama_doc.id_,
            embedding=llama_doc.embedding,
            excluded_llm_metadata_keys=llama_doc.excluded_llm_metadata_keys,
            # Previously this was fed from the *llm* key list by mistake.
            excluded_embed_metadata_keys=llama_doc.excluded_embed_metadata_keys,
            relationships=llama_doc.relationships,
            metadata_template=llama_doc.metadata_template,
            metadata_separator=llama_doc.metadata_separator,
            text_template=llama_doc.text_template
        )

    def set_embedding(self, embedding: List[float]):
        """Set the embedding vector for the Document."""
        self.embedding = embedding

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of the document text (for deduplication)."""
        return hashlib.sha256(self.text.encode()).hexdigest()

    def get_fragment(self, max_length: int = 100) -> str:
        """Return the first ``max_length`` characters of the text, with "..." appended if truncated."""
        if len(self.text) > max_length:
            return self.text[:max_length] + "..."
        return self.text

    def to_dict(self) -> Dict:
        """Convert document to dictionary for serialization.

        Relationship keys are converted with ``str``; relationship values are
        returned unchanged.
        """
        return {
            "doc_id": self.doc_id,
            "text": self.text,
            "metadata": self.metadata.model_dump(),
            "embedding": self.embedding,
            "excluded_embed_metadata_keys": self.excluded_embed_metadata_keys,
            "excluded_llm_metadata_keys": self.excluded_llm_metadata_keys,
            "relationships": {str(k): v for k, v in self.relationships.items()},
            "metadata_template": self.metadata_template,
            "metadata_separator": self.metadata_separator,
            "text_template": self.text_template,
        }

    def to_json(self, indent: int = 2) -> str:
        """Convert document to JSON string.

        ``RelatedNodeInfo`` values found in the relationships are encoded through
        their ``to_dict()``; any other non-JSON type still raises ``TypeError``.
        """
        def _encode(obj: Any) -> Any:
            if isinstance(obj, RelatedNodeInfo):
                return obj.to_dict()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False, default=_encode)

    def __str__(self) -> str:
        return (
            f"Document(id={self.doc_id}, embedding={self.embedding}, metadata={self.metadata.model_dump()}, "
            f"fragment={self.get_fragment(max_length=300)})"
        )

    def __repr__(self) -> str:
        return (
            f"Document(doc_id={self.doc_id}, embedding={self.embedding}, metadata={self.metadata.model_dump()},"
            f"fragment={self.get_fragment(max_length=300)})"
        )
class TextChunk(BaseModule):
    """A single chunk of a document for RAG processing.

    All constructor arguments are forwarded to ``BaseModule.__init__`` as keyword
    arguments and are read back as attributes of the same name.

    Attributes:
        text (str): The content of the chunk (whitespace-stripped).
        chunk_id (str): Unique identifier for the chunk (a UUID4 string when not given).
        embedding (Optional[List[float]]): Embedding vector, if any.
        start_char_idx / end_char_idx (Optional[int]): Character offsets, if known.
        excluded_embed_metadata_keys / excluded_llm_metadata_keys (List[str]):
            ``DEAFULT_EXCLUDED`` plus caller-supplied keys.
        text_template (str): Template combining metadata and content.
        relationships (Dict): Relationships to other nodes.
        metadata (ChunkMetadata): Chunk metadata; ``metadata.graph_node`` is set
            when the chunk represents a graph node rather than plain text.
    """

    def __init__(
        self,
        text: str = "",
        chunk_id: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        start_char_idx: Optional[int] = None,
        end_char_idx: Optional[int] = None,
        excluded_embed_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        excluded_llm_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        text_template: str = '{metadata_str}\n\n{content}',
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata: Optional[Union[Dict, ChunkMetadata]] = None,
    ):
        """Build a text chunk.

        Args:
            text: Chunk content; leading/trailing whitespace is stripped.
            chunk_id: Optional identifier; a UUID4 string is generated when falsy.
            embedding: Optional embedding vector.
            start_char_idx: Optional start offset.
            end_char_idx: Optional end offset.
            excluded_embed_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            excluded_llm_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            text_template: Template combining metadata and content.
            relationships: Optional relationships; ``None`` yields a fresh empty dict.
            metadata: A ``ChunkMetadata`` instance, a dict validated into one,
                or ``None`` for empty metadata.
        """
        if isinstance(metadata, dict):
            metadata = ChunkMetadata.model_validate(metadata)
        elif not metadata:
            metadata = ChunkMetadata()

        super().__init__(
            text=text.strip(),
            chunk_id=chunk_id or str(uuid4()),
            embedding=embedding,
            start_char_idx=start_char_idx,
            end_char_idx=end_char_idx,
            # dict.fromkeys de-duplicates while keeping a stable order.
            excluded_embed_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_embed_metadata_keys))
            ),
            excluded_llm_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_llm_metadata_keys))
            ),
            text_template=text_template,
            # A new dict per instance instead of one shared `{}` default.
            relationships={} if relationships is None else relationships,
            metadata=metadata,
        )
        # Whitespace-separated token count of the stripped text.
        self.metadata.word_count = len(self.text.split())

    def to_llama_node(self) -> Union[TextNode, Relation, EntityNode, ChunkNode]:
        """Convert to LlamaIndex Node.

        Returns:
            A ``Relation`` or ``EntityNode`` when ``metadata.graph_node`` is set with
            ``node_class_name`` "relation" / "entity" (case-insensitive); otherwise a
            ``TextNode`` built from the chunk's own fields.

        Raises:
            NotImplementedError: ``metadata.graph_node`` is set but its
                ``node_class_name`` is neither "relation" nor "entity".
        """
        graph_node = self.metadata.graph_node
        if graph_node is not None:
            # node_class_name is Optional; treat a missing value as unsupported.
            class_name = (graph_node.node_class_name or "").lower()
            if class_name == "relation":
                return Relation(
                    label=graph_node.label,
                    source_id=graph_node.source_id,
                    target_id=graph_node.target_id,
                    properties={"metadata": json.dumps(graph_node.properties["metadata"])},
                )
            if class_name == "entity":
                return EntityNode(
                    label=graph_node.label,
                    embedding=self.embedding,
                    name=graph_node.node_name,
                    properties={"triplet_source_id": graph_node.properties["triplet_source_id"]},
                )
            raise NotImplementedError(
                f"Cannot convert graph node of class '{graph_node.node_class_name}' to a LlamaIndex node."
            )

        # Relationship values may have been turned into plain dicts by to_dict().
        relationships = {
            key: value if isinstance(value, RelatedNodeInfo) else RelatedNodeInfo.from_dict(value)
            for key, value in self.relationships.items()
        }
        metadata = self.metadata.model_dump()
        metadata.pop("class_name", None)
        return TextNode(
            text=self.text,
            metadata=metadata,
            id_=self.chunk_id,
            embedding=self.embedding,
            start_char_idx=self.start_char_idx,
            end_char_idx=self.end_char_idx,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            text_template=self.text_template,
            relationships=relationships
        )

    @classmethod
    def from_llama_node(cls, node: Union[TextNode, Relation, EntityNode, ChunkNode]) -> Optional["TextChunk"]:
        """Create Chunk from LlamaIndex Node.

        ``TextNode`` is converted field by field. ``Relation``, ``EntityNode`` and
        ``ChunkNode`` are stored as ``GraphNodeData`` under ``metadata.graph_node``
        with an empty text.

        Returns:
            The new chunk, or ``None`` when ``node`` is none of the supported types.
        """
        if isinstance(node, TextNode):
            return cls(
                chunk_id=node.id_,
                text=node.text,
                metadata=ChunkMetadata.model_validate(node.metadata),
                embedding=node.embedding,
                start_char_idx=getattr(node, "start_char_idx", None),
                end_char_idx=getattr(node, "end_char_idx", None),
                excluded_embed_metadata_keys=node.excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=node.excluded_llm_metadata_keys,
                text_template=node.text_template,
                relationships=node.relationships
            )

        if isinstance(node, Relation):
            raw_properties = node.properties
            # Work on a copy so the caller's node is not modified.
            properties = dict(raw_properties) if isinstance(raw_properties, dict) else raw_properties.model_dump()
            properties.pop("class_name", None)
            graph_node = GraphNodeData(
                node_class_name="relation",
                label=node.label,
                source_id=node.source_id,
                target_id=node.target_id,
                properties={"metadata": properties}
            )
            return cls(metadata=ChunkMetadata.model_validate({"graph_node": graph_node}))

        if isinstance(node, EntityNode):
            graph_node = GraphNodeData(
                node_class_name="entity",
                label=node.label,
                node_name=node.name,
                properties={"triplet_source_id": node.properties["triplet_source_id"]}
            )
            return cls(
                embedding=node.embedding,
                metadata=ChunkMetadata.model_validate({"graph_node": graph_node})
            )

        if isinstance(node, ChunkNode):
            graph_node = GraphNodeData(
                node_class_name="chunk",
                text=node.text,
                properties=node.properties,
                id_=node.id_,
            )
            return cls(
                embedding=node.embedding,
                metadata=ChunkMetadata.model_validate({"graph_node": graph_node})
            )

        # Unsupported node type: same result as before, now explicit.
        return None

    def get_fragment(self, max_length: int = 100) -> str:
        """Return the first ``max_length`` characters of the text, with "..." appended if truncated."""
        if len(self.text) > max_length:
            return self.text[:max_length] + "..."
        return self.text

    def to_dict(self) -> Dict:
        """Convert chunk to dictionary for serialization.

        Side effect: ``self.relationships`` is replaced by a dict whose
        ``RelatedNodeInfo`` values have been converted with ``to_dict()``;
        ``to_llama_node`` converts them back when needed.
        """
        self.relationships = {
            key: value.to_dict() if isinstance(value, RelatedNodeInfo) else value
            for key, value in self.relationships.items()
        }
        return self.model_dump()

    def to_json(self, indent: int = 2) -> str:
        """Convert chunk to JSON string."""
        return self.model_dump_json(indent=indent).strip()

    def _describe(self) -> str:
        """Build the text shared by ``__str__`` and ``__repr__``."""
        return (
            f"Chunk(id={self.chunk_id}, text={self.text}, "
            f"chunking_strategy={self.metadata.chunking_strategy}, "
            f"embedding={self.embedding}, "
            f"start_char_idx={self.start_char_idx}, "
            f"end_char_idx={self.end_char_idx}, "
            f"excluded_embed_metadata_keys={self.excluded_embed_metadata_keys}, "
            f"excluded_llm_metadata_keys={self.excluded_llm_metadata_keys}, "
            f"text_template={self.text_template}, "
            f"metadata={self.metadata.model_dump()})"
        )

    def __str__(self) -> str:
        return self._describe()

    def __repr__(self) -> str:
        return self._describe()
# Backward compatibility alias: `Chunk` and `TextChunk` name the same class object
# (`Corpus.from_jsonl` in this file still constructs chunks through `Chunk`).
Chunk = TextChunk
class ImageChunk(BaseModule):
    """An image-based chunk with lazy loading.

    Only the image path is stored; the PIL image is opened on the first call to
    ``get_image`` and cached on the instance.

    Attributes:
        image_path (str): Path to the image file.
        image_mimetype (Optional[str]): MIME type of the image.
        chunk_id (str): Unique identifier for the chunk (a UUID4 string when not given).
        embedding (Optional[List[float]]): Embedding vector, if any.
        excluded_embed_metadata_keys / excluded_llm_metadata_keys (List[str]):
            ``DEAFULT_EXCLUDED`` plus caller-supplied keys.
        text_template (str): Template combining metadata and content.
        relationships (Dict): Relationships to other nodes.
        metadata (ChunkMetadata): Metadata including embedding, similarity scores, etc.
    """

    def __init__(
        self,
        image_path: str,
        image_mimetype: Optional[str] = None,
        chunk_id: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        excluded_embed_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        excluded_llm_metadata_keys: List[str] = DEAFULT_EXCLUDED,
        text_template: str = '{metadata_str}\n\n{content}',
        relationships: Optional[Dict[str, RelatedNodeInfo]] = None,
        metadata: Optional[Union[Dict, ChunkMetadata]] = None,
    ):
        """Build an image chunk.

        Args:
            image_path: Path to the image file.
            image_mimetype: Optional MIME type.
            chunk_id: Optional identifier; a UUID4 string is generated when falsy.
            embedding: Optional embedding vector.
            excluded_embed_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            excluded_llm_metadata_keys: Extra keys merged with ``DEAFULT_EXCLUDED``.
            text_template: Template combining metadata and content.
            relationships: Optional relationships; ``None`` yields a fresh empty dict.
            metadata: A ``ChunkMetadata`` instance, a dict validated into one,
                or ``None`` for empty metadata.
        """
        if isinstance(metadata, dict):
            metadata = ChunkMetadata.model_validate(metadata)
        elif not metadata:
            metadata = ChunkMetadata()

        super().__init__(
            image_path=image_path,
            image_mimetype=image_mimetype,
            chunk_id=chunk_id or str(uuid4()),
            embedding=embedding,
            # dict.fromkeys de-duplicates while keeping a stable order.
            excluded_embed_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_embed_metadata_keys))
            ),
            excluded_llm_metadata_keys=list(
                dict.fromkeys(DEAFULT_EXCLUDED + list(excluded_llm_metadata_keys))
            ),
            text_template=text_template,
            # A new dict per instance instead of one shared `{}` default.
            relationships={} if relationships is None else relationships,
            metadata=metadata,
        )
        # Private cache for lazy loading
        self._cached_image = None

    def get_image(self):
        """Load PIL Image on-demand with caching.

        Returns:
            The cached/opened PIL image, or ``None`` when the path is empty or
            opening fails (the failure is logged, not raised).
        """
        if self._cached_image is not None:
            return self._cached_image

        from PIL import Image
        try:
            logger.debug(f"Loading image from path: {self.image_path}")
            if not self.image_path:
                logger.error("Image path is None or empty!")
                return None
            self._cached_image = Image.open(self.image_path)
            logger.debug(f"Successfully loaded image from {self.image_path}")
        except Exception as e:
            # Best effort: callers receive None instead of an exception.
            logger.error(f"Failed to load image from {self.image_path}: {str(e)}")
            return None
        return self._cached_image

    def get_image_bytes(self, format: str = "PNG") -> Optional[bytes]:
        """Get image as bytes for embedding or processing.

        Args:
            format: Image format name passed to ``image.save``.

        Returns:
            The encoded bytes, or ``None`` when the image could not be loaded.
        """
        import io

        image = self.get_image()
        if image is None:
            return None
        buffer = io.BytesIO()
        image.save(buffer, format=format)
        return buffer.getvalue()

    def to_llama_node(self) -> ImageNode:
        """Convert to LlamaIndex ImageNode with on-demand image loading.

        The node carries the image path only (``image=None``).
        """
        # Relationship values may be plain dicts; rebuild RelatedNodeInfo from them.
        relationships = {
            key: value if isinstance(value, RelatedNodeInfo) else RelatedNodeInfo.from_dict(value)
            for key, value in self.relationships.items()
        }
        return ImageNode(
            image=None,
            image_path=self.image_path,
            image_mimetype=self.image_mimetype,
            metadata=self.metadata.model_dump(),
            id_=self.chunk_id,
            embedding=self.embedding,
            excluded_llm_metadata_keys=self.excluded_llm_metadata_keys,
            excluded_embed_metadata_keys=self.excluded_embed_metadata_keys,
            text_template=self.text_template,
            relationships=relationships
        )

    @classmethod
    def from_llama_node(cls, node: ImageNode) -> "ImageChunk":
        """Create ImageChunk from LlamaIndex ImageNode."""
        metadata = ChunkMetadata.model_validate(node.metadata)
        logger.debug(f"Creating ImageChunk from ImageNode - image_path: {node.image_path}")
        return cls(
            chunk_id=node.id_,
            image_path=node.image_path,
            image_mimetype=node.image_mimetype,
            metadata=metadata,
            embedding=node.embedding,
            excluded_embed_metadata_keys=node.excluded_embed_metadata_keys,
            excluded_llm_metadata_keys=node.excluded_llm_metadata_keys,
            text_template=node.text_template,
            relationships=node.relationships
        )
class Corpus(BaseModule):
    """A generic collection of document chunks for RAG processing.

    Attributes:
        corpus_id (str): The unique id for corpus (a UUID4 string when not given).
        chunks (List[Union[TextChunk, ImageChunk]]): List of chunks in the corpus.
        chunk_index (Dict[str, Union[TextChunk, ImageChunk]]): Index of chunks by chunk_id for fast lookup.
        metadata (dict): Metadata for this corpus; an ``IndexMetadata`` passed to the
            constructor is stored as its ``model_dump()`` dict.
    """

    def __init__(self, chunks: Optional[List[Union[TextChunk, ImageChunk]]] = None, corpus_id: Optional[str] = None,
                 metadata: Optional[Union[IndexMetadata, Dict]] = None):
        """Build a corpus.

        Args:
            chunks: Initial chunks; ``None`` means an empty corpus.
            corpus_id: Optional identifier; a UUID4 string is generated when ``None``.
            metadata: ``IndexMetadata``, a dict, or ``None`` (stored as ``{}``).
        """
        # str(...) keeps the id a plain string like the other ids in this file.
        corpus_id = str(uuid4()) if corpus_id is None else corpus_id
        chunks = [] if chunks is None else chunks
        chunk_index = {chunk.chunk_id: chunk for chunk in chunks}
        if metadata is None:
            metadata = {}
        elif isinstance(metadata, IndexMetadata):
            metadata = metadata.model_dump()
        super().__init__(
            corpus_id=corpus_id,
            chunks=chunks,
            chunk_index=chunk_index,
            metadata=metadata
        )

    def to_llama_nodes(self) -> List[BaseNode]:
        """Convert to list of LlamaIndex Nodes."""
        if not self.chunks:
            self.chunks = []
        return [chunk.to_llama_node() for chunk in self.chunks]

    @classmethod
    def from_llama_nodes(cls, nodes: List[BaseNode]) -> "Corpus":
        """Create a Corpus from a list of LlamaIndex Nodes.

        Args:
            nodes (List[BaseNode]): The LlamaIndex Nodes to convert.

        Returns:
            Corpus: A new Corpus instance.
        """
        chunks = []
        for node in nodes:
            if isinstance(node, ImageNode):
                chunks.append(ImageChunk.from_llama_node(node))
            else:
                # Default to TextChunk for TextNode and other node types
                chunks.append(TextChunk.from_llama_node(node))
        return cls(chunks)

    def add_chunk(self, batch_chunk: Union[TextChunk, ImageChunk, List[Union[TextChunk, ImageChunk]]]):
        """Add one chunk or a list of chunks to the corpus and update the index."""
        if not isinstance(batch_chunk, list):
            batch_chunk = [batch_chunk]
        for chunk in batch_chunk:
            self.chunks.append(chunk)
            self.chunk_index[chunk.chunk_id] = chunk

    def get_chunk(self, chunk_id: str) -> Optional[Union[TextChunk, ImageChunk]]:
        """Retrieve a chunk by its ID, or ``None`` when it is not indexed."""
        return self.chunk_index.get(chunk_id)

    def remove_chunk(self, chunk_id: str):
        """Remove a chunk by its ID (no error when the ID is unknown)."""
        self.chunks = [chunk for chunk in self.chunks if chunk.chunk_id != chunk_id]
        self.chunk_index.pop(chunk_id, None)

    def filter_by_doc_id(self, doc_id: str) -> List[Union[TextChunk, ImageChunk]]:
        """Filter chunks by parent document ID."""
        return [
            chunk for chunk in self.chunks
            if hasattr(chunk.metadata, 'doc_id') and chunk.metadata.doc_id == doc_id
        ]

    def filter_by_similarity(self, threshold: float) -> List[Union[TextChunk, ImageChunk]]:
        """Return chunks whose similarity score is set and ``>= threshold``.

        A score of exactly ``0.0`` counts as set; only ``None`` is skipped.
        """
        return [
            chunk for chunk in self.chunks
            if chunk.metadata.similarity_score is not None
            and chunk.metadata.similarity_score >= threshold
        ]

    def sort_by_similarity(self, reverse: bool = True) -> List[Union[TextChunk, ImageChunk]]:
        """Sort chunks by similarity score (descending by default).

        Chunks without a score are left out of the result.
        """
        scored = [chunk for chunk in self.chunks if chunk.metadata.similarity_score is not None]
        return sorted(scored, key=lambda chunk: chunk.metadata.similarity_score, reverse=reverse)

    def to_dict(self, round_trip=False) -> List[Dict[str, Any]]:
        """Convert corpus for serialization.

        Returns:
            A one-element list holding ``self.model_dump(round_trip=round_trip)``.
        """
        return [self.model_dump(round_trip=round_trip)]

    def to_json(self, indent: int = 2, round_trip=True) -> str:
        """Convert corpus to JSON string."""
        return json.dumps(self.to_dict(round_trip), indent=indent, ensure_ascii=False)

    def to_jsonl(self, output_path: str, indent: int = 0):
        """Write one JSON document per chunk, one per line, to ``output_path``.

        Args:
            output_path: Destination file (overwritten, UTF-8).
            indent: Unused; each chunk is dumped with ``indent=None``. Kept so
                existing callers passing it keep working.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            for chunk in self.chunks:
                json_str = chunk.to_json(indent=None)
                if '\n' in json_str:
                    # JSON containing raw newlines breaks the one-record-per-line format.
                    logger.warning(f"Chunk {chunk.chunk_id} contains newlines in JSON, which may break JSONL format.")
                f.write(json_str + '\n')

    @classmethod
    def from_jsonl(cls, input_path: str, corpus_id: Optional[str] = None) -> "Corpus":
        """Load a corpus of text chunks from a JSONL file.

        Args:
            input_path: File with one chunk JSON object per line; blank lines are skipped.
            corpus_id: Optional id for the new corpus.

        Returns:
            A new Corpus holding the loaded chunks.
        """
        chunks = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate empty lines (e.g. a trailing newline at end of file).
                    continue
                chunk_dict = json.loads(line)
                chunks.append(
                    TextChunk(
                        chunk_id=chunk_dict["chunk_id"],
                        text=chunk_dict["text"],
                        metadata=ChunkMetadata.model_validate(chunk_dict["metadata"]),
                        embedding=chunk_dict["embedding"],
                        start_char_idx=chunk_dict["start_char_idx"],
                        end_char_idx=chunk_dict["end_char_idx"],
                        excluded_embed_metadata_keys=chunk_dict["excluded_embed_metadata_keys"],
                        excluded_llm_metadata_keys=chunk_dict["excluded_llm_metadata_keys"],
                        relationships={
                            k: RelatedNodeInfo(**v) for k, v in chunk_dict["relationships"].items()
                        }
                    )
                )
        return cls(chunks=chunks, corpus_id=corpus_id)

    def __str__(self) -> str:
        stats = self.get_stats()
        return (
            f"Corpus(chunks={stats['chunk_count']}, unique_docs={stats['unique_docs']}, "
            f"avg_word_count={stats['avg_word_count']:.1f}, strategies={stats['strategies']})"
        )

    def __repr__(self) -> str:
        return f"Corpus(chunks={len(self.chunks)}, chunk_index_keys={list(self.chunk_index.keys())})"

    def __len__(self) -> int:
        return len(self.chunks)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the corpus.

        Returns:
            A dict with ``chunk_count``, ``unique_docs`` (number of distinct
            non-empty ``doc_id`` values), ``avg_word_count`` (sum of the chunks'
            ``word_count`` divided by the number of chunks) and ``strategies``
            (set of non-empty ``chunking_strategy`` values).
        """
        if not self.chunks:
            return {
                'chunk_count': 0,
                'unique_docs': 0,
                'avg_word_count': 0.0,
                'strategies': set()
            }

        unique_docs = set()
        strategies = set()
        total_word_count = 0
        for chunk in self.chunks:
            meta = chunk.metadata
            doc_id = getattr(meta, 'doc_id', None)
            if doc_id:
                unique_docs.add(doc_id)
            word_count = getattr(meta, 'word_count', None)
            if word_count:
                total_word_count += word_count
            strategy = getattr(meta, 'chunking_strategy', None)
            if strategy:
                strategies.add(strategy)

        return {
            'chunk_count': len(self.chunks),
            'unique_docs': len(unique_docs),
            'avg_word_count': total_word_count / len(self.chunks),
            'strategies': strategies
        }
class Query(BaseModule):
    """A retrieval query: the query text plus optional retrieval settings."""

    query_str: str = Field(
        description="The query string.",
    )
    top_k: Optional[int] = Field(
        default=None,
        description="Number of top results to retrieve.",
    )
    custom_embedding_strs: Optional[List[str]] = Field(
        default=None,
        description="The List to store additional strings need to be embed with the query.",
    )
    similarity_cutoff: Optional[float] = Field(
        default=None,
        description="Minimum similarity score.",
    )
    keyword_filters: Optional[List[str]] = Field(
        default=None,
        description="Keywords to filter results.",
    )
    metadata_filters: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metadata filters.",
    )

    @property
    def embedding_strs(self) -> List[str]:
        """Strings to embed for this query.

        ``custom_embedding_strs`` wins when it is set; otherwise the result is
        ``[query_str]``, or an empty list for an empty query string.
        """
        if self.custom_embedding_strs is not None:
            return self.custom_embedding_strs
        return [self.query_str] if len(self.query_str) > 0 else []

    def to_QueryBundle(self):
        """Build a LlamaIndex ``QueryBundle`` from the query string and custom embedding strings."""
        bundle = QueryBundle(
            query_str=self.query_str,
            custom_embedding_strs=self.custom_embedding_strs
        )
        return bundle
class RagResult(BaseModule):
    """A retrieval result: the retrieved corpus, its scores and extra metadata."""

    corpus: Corpus = Field(
        description="Retrieved chunks.",
    )
    scores: List[float] = Field(
        description="Similarity scores for each chunk.",
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional result metadata.",
    )

    def get_top_chunks(self, limit: int = None) -> List[Union[TextChunk, ImageChunk]]:
        """Return the scored chunks, highest similarity first.

        Args:
            limit: Maximum number of chunks to return; a falsy value
                (``None`` or ``0``) returns all of them.
        """
        ranked = self.corpus.sort_by_similarity(reverse=True)
        if not limit:
            return ranked
        return ranked[:limit]