# AI_Toolkit / src / core / SemanticChunker.py
# Uploaded by NavyDevilDoc ("Upload 10 files", commit c0f31c1, verified)
"""
SemanticChunker.py
A module for semantic-aware text chunking using embeddings and similarity metrics.
"""
import logging
from typing import List, Optional, Any
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from langchain_core.documents import Document
# FIXED IMPORT: Updated for LangChain v0.2+
from langchain_text_splitters import SpacyTextSplitter
from sentence_transformers import SentenceTransformer
from core.BaseChunker import BaseChunker
logger = logging.getLogger(__name__)
class SemanticChunker(BaseChunker):
    """Chunks text based on semantic similarity and size constraints.

    Documents are first split with a ``SpacyTextSplitter``; adjacent base
    chunks are then merged while their embedding cosine similarity stays at
    or above ``similarity_threshold`` and the merged text fits within
    ``chunk_size`` characters. A final word-level pass hard-enforces the
    size limit on every emitted chunk.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        embedding_model: Optional[Any] = None,
        chunk_size: int = 200,
        chunk_overlap: int = 0,
        similarity_threshold: float = 0.9,
        separator: str = " "
    ):
        """
        Initialize the semantic chunker with configurable parameters.

        Args:
            model_name: Optional model identifier forwarded to BaseChunker.
            embedding_model: Optional pre-built embedder exposing ``encode``.
            chunk_size: Hard upper bound, in characters, for emitted chunks.
            chunk_overlap: Character overlap between adjacent base chunks;
                must be smaller than ``chunk_size``.
            similarity_threshold: Cosine-similarity cutoff in [0, 1] for
                merging adjacent base chunks.
            separator: Separator string passed to the text splitter.

        Raises:
            ValueError: If ``chunk_size``, ``chunk_overlap``, or
                ``similarity_threshold`` is out of range.
        """
        # Validate parameters
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer.")
        # FIX: chunk_overlap was previously unvalidated; an overlap >= size
        # would hand SpacyTextSplitter a non-positive chunk_size below.
        if not (0 <= chunk_overlap < chunk_size):
            raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size.")
        if not (0 <= similarity_threshold <= 1):
            raise ValueError("similarity_threshold must be between 0 and 1.")

        # Initialize BaseChunker first
        super().__init__(model_name, embedding_model)

        # Set semantic chunking parameters
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.similarity_threshold = similarity_threshold
        self.separator = separator

        # Probe a provided embedder for the known 384-dim all-zero "dummy"
        # output so we can swap in a real model instead of using it.
        is_dummy = False
        if embedding_model is not None:
            try:
                test_output = embedding_model.encode("test")
                if (isinstance(test_output, list) and len(test_output) == 384
                        and all(x == 0.0 for x in test_output)):
                    is_dummy = True
            # FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; narrow and record the failure.
            except Exception:
                logger.debug("Probe of provided embedding model failed", exc_info=True)

        if embedding_model is None or is_dummy:
            try:
                self.sentence_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
                self.embedding_model = self.sentence_model
                logger.info("Initialized SentenceTransformer for semantic chunking")
            except Exception as e:
                logger.error(f"Error loading SentenceTransformer: {e}")

                # Last-resort stand-in (e.g. offline / model download failed):
                # every text maps to the same 384-dim zero vector, so all
                # chunks look maximally "similar" — degraded but usable.
                class DummyEmbedder:
                    def encode(self, text, **kwargs):
                        return [0.0] * 384

                self.sentence_model = DummyEmbedder()
                self.embedding_model = self.sentence_model
        else:
            self.sentence_model = embedding_model
            logger.info("Using provided embedding model for semantic chunking")

        # Initial (pre-semantic) splitter; the effective size excludes the
        # overlap so overlapped base chunks still respect chunk_size.
        self.text_splitter = SpacyTextSplitter(
            chunk_size=self.chunk_size - self.chunk_overlap,
            chunk_overlap=self.chunk_overlap,
            separator=self.separator
        )

    def _enforce_size_immediately(self, text: str) -> List[str]:
        """Split ``text`` on whitespace into chunks of at most ``chunk_size`` chars.

        Words are packed greedily; the running total counts one joining space
        per word already in the chunk. A single word longer than
        ``chunk_size`` still becomes its own (oversized) chunk.

        Args:
            text: Text to split.

        Returns:
            List of space-joined chunks; empty list for blank input.
        """
        if not text.strip():
            return []
        chunks = []
        current_chunk = []
        words = text.split()
        for word in words:
            # len(current_chunk) accounts for the spaces " ".join will add.
            if sum(len(w) for w in current_chunk) + len(word) + len(current_chunk) <= self.chunk_size:
                current_chunk.append(word)
            else:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                current_chunk = [word]
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def get_semantic_chunks(self, documents: List[Document]) -> List[Document]:
        """Group base chunks into semantically coherent, size-bounded chunks.

        Adjacent base chunks are merged while the cosine similarity between
        the current group's anchor embedding and the next chunk's embedding
        meets ``similarity_threshold`` and the combined text fits in
        ``chunk_size`` characters.

        Args:
            documents: Documents to chunk.

        Returns:
            Semantic chunks with enriched metadata; on any internal error
            the original ``documents`` are returned unchanged (best-effort).
        """
        if not documents:
            logger.warning("No documents provided for semantic chunking")
            return []
        try:
            base_chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Initial splitting created {len(base_chunks)} base chunks")
            if not base_chunks:
                return []

            # Embed all base chunks in one batch call.
            chunk_contents = [doc.page_content for doc in base_chunks]
            chunk_embeddings = self.sentence_model.encode(chunk_contents)

            grouped_chunks = []
            current_group = []
            # Anchor embedding: the first chunk of the current group.
            current_embedding = None
            for i, base_chunk in enumerate(base_chunks):
                if not current_group:
                    current_group.append(base_chunk)
                    current_embedding = chunk_embeddings[i].reshape(1, -1)
                    continue
                similarity = cosine_similarity(current_embedding, chunk_embeddings[i].reshape(1, -1))[0][0]
                combined_content = " ".join([doc.page_content for doc in current_group] + [base_chunk.page_content])
                if similarity >= self.similarity_threshold and len(combined_content) <= self.chunk_size:
                    current_group.append(base_chunk)
                else:
                    # Close out the current group and start a new one anchored
                    # at this chunk's embedding.
                    grouped_chunks.extend(self._finalize_chunk_group(current_group))
                    current_group = [base_chunk]
                    current_embedding = chunk_embeddings[i].reshape(1, -1)
            if current_group:
                grouped_chunks.extend(self._finalize_chunk_group(current_group))
            logger.info(f"Created {len(grouped_chunks)} semantic chunks")
            return grouped_chunks
        except Exception as e:
            # Deliberate best-effort fallback: return the input unchunked
            # rather than losing the documents.
            logger.error(f"Error in semantic chunking: {e}")
            return documents

    def _finalize_chunk_group(self, group: List[Document]) -> List[Document]:
        """Join a group of documents, re-enforce the size limit, and attach metadata.

        Metadata is copied from the group's first document, then augmented
        with 1-based ``chunk_index``, ``chunk_count``, text statistics from
        ``analyze_text``, and ``chunk_type="semantic"``.

        Args:
            group: Contiguous documents deemed semantically similar.

        Returns:
            One ``Document`` per size-limited chunk; empty list for an
            empty group.
        """
        if not group:
            return []
        processed_chunks = []
        content = " ".join([doc.page_content for doc in group])
        size_limited_chunks = self._enforce_size_immediately(content)
        base_metadata = group[0].metadata.copy()
        for i, chunk in enumerate(size_limited_chunks):
            stats = self.analyze_text(chunk)
            metadata = base_metadata.copy()
            metadata.update({
                "chunk_index": i + 1,
                "chunk_count": len(size_limited_chunks),
                "char_count": stats["char_count"],
                "token_count": stats["token_count"],
                "sentence_count": stats["sentence_count"],
                "word_count": stats["word_count"],
                "chunk_type": "semantic"
            })
            processed_chunks.append(Document(page_content=chunk, metadata=metadata))
        return processed_chunks

    def semantic_process_document(self, file_path: str, preprocess: bool = False) -> List[Document]:
        """Load a file, optionally preprocess its text, and semantically chunk it.

        Args:
            file_path: Path handed to ``load_document`` (inherited).
            preprocess: When True, run ``preprocess_text`` on each page
                before chunking.

        Returns:
            Semantic chunks for the document.

        Raises:
            Exception: Re-raises any error from loading/chunking after
                logging it.
        """
        try:
            logger.info(f"Processing document with semantic chunking: {file_path}")
            raw_documents = self.load_document(file_path)
            processed_documents = []
            for doc in raw_documents:
                content = doc.page_content
                if preprocess:
                    content = self.preprocess_text(content)
                processed_documents.append(Document(
                    page_content=content,
                    metadata=doc.metadata
                ))
            documents = self.get_semantic_chunks(processed_documents)
            logger.info(f"Created {len(documents)} semantic chunks")
            return documents
        except Exception as e:
            logger.error(f"Error in semantic_process_document: {e}")
            raise

    def process_document(self, file_path: str, preprocess: bool = True) -> List[Document]:
        """BaseChunker entry point; delegates to ``semantic_process_document``.

        Note the different default: ``preprocess=True`` here versus False
        in ``semantic_process_document`` (kept for backward compatibility).
        """
        return self.semantic_process_document(file_path, preprocess)