Spaces:
Sleeping
Sleeping
| """ | |
| SemanticChunker.py | |
| A module for semantic-aware text chunking using embeddings and similarity metrics. | |
| """ | |
| import logging | |
| from typing import List, Optional, Any | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from langchain_core.documents import Document | |
| # FIXED IMPORT: Updated for LangChain v0.2+ | |
| from langchain_text_splitters import SpacyTextSplitter | |
| from sentence_transformers import SentenceTransformer | |
| from core.BaseChunker import BaseChunker | |
| logger = logging.getLogger(__name__) | |
class SemanticChunker(BaseChunker):
    """Chunk documents by semantic similarity under a character-size limit.

    Documents are first split with ``SpacyTextSplitter``; neighbouring pieces
    are then merged while they stay similar to the first piece of the current
    group and the merged text still fits in ``chunk_size`` characters.
    """

    # Sentence-transformer checkpoint loaded when no usable model is supplied.
    DEFAULT_SENTENCE_MODEL = "multi-qa-mpnet-base-dot-v1"

    class DummyEmbedder:
        """Placeholder embedder used when no real model can be loaded.

        Produces all-zero vectors so the rest of the pipeline keeps running.
        """

        DIM = 384

        def encode(self, text, **kwargs):
            """Return a zero vector for a string, or one zero vector per item.

            A single string yields a flat list (the shape the dummy-detection
            probe in ``SemanticChunker.__init__`` looks for); any other
            iterable yields one vector per element so callers can index the
            result by chunk position.
            """
            if isinstance(text, str):
                return [0.0] * self.DIM
            return [[0.0] * self.DIM for _ in text]

    def __init__(
        self,
        model_name: Optional[str] = None,
        embedding_model: Optional[Any] = None,
        chunk_size: int = 200,
        chunk_overlap: int = 0,
        similarity_threshold: float = 0.9,
        separator: str = " "
    ):
        """Initialize the semantic chunker with configurable parameters.

        Args:
            model_name: Forwarded unchanged to ``BaseChunker.__init__``.
            embedding_model: Optional object exposing ``encode``; when it is
                missing or looks like an all-zero placeholder, a
                ``SentenceTransformer`` is loaded instead.
            chunk_size: Maximum chunk length in characters; must be positive.
            chunk_overlap: Overlap handed to the initial splitter; must be
                non-negative and smaller than ``chunk_size``.
            similarity_threshold: Minimum cosine similarity (0..1) required to
                merge a piece into the current group.
            separator: Separator handed to ``SpacyTextSplitter``.

        Raises:
            ValueError: If ``chunk_size``, ``chunk_overlap`` or
                ``similarity_threshold`` is out of range.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer.")
        # The splitter is sized with chunk_size - chunk_overlap, so an overlap
        # outside [0, chunk_size) would give it a nonsensical size.
        if not (0 <= chunk_overlap < chunk_size):
            raise ValueError(
                "chunk_overlap must be non-negative and smaller than chunk_size."
            )
        if not (0 <= similarity_threshold <= 1):
            raise ValueError("similarity_threshold must be between 0 and 1.")

        # Initialize BaseChunker first
        super().__init__(model_name, embedding_model)

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.similarity_threshold = similarity_threshold
        self.separator = separator

        # Use the provided embedding model unless it is absent or a placeholder.
        if embedding_model is None or self._is_dummy_embedder(embedding_model):
            try:
                self.sentence_model = SentenceTransformer(self.DEFAULT_SENTENCE_MODEL)
                self.embedding_model = self.sentence_model
                logger.info("Initialized SentenceTransformer for semantic chunking")
            except Exception as e:
                # Best effort: keep the chunker usable without a real model.
                logger.error(f"Error loading SentenceTransformer: {e}")
                self.sentence_model = self.DummyEmbedder()
                self.embedding_model = self.sentence_model
        else:
            self.sentence_model = embedding_model
            logger.info("Using provided embedding model for semantic chunking")

        # Initialize text splitter for initial chunking
        self.text_splitter = SpacyTextSplitter(
            chunk_size=self.chunk_size - self.chunk_overlap,
            chunk_overlap=self.chunk_overlap,
            separator=self.separator
        )

    @staticmethod
    def _is_dummy_embedder(model: Any) -> bool:
        """Return True when ``model.encode("test")`` is a list of 384 zeros.

        A failing probe is treated as "not a dummy" so the supplied model is
        still used, matching the previous best-effort behaviour.
        """
        try:
            probe = model.encode("test")
            return (
                isinstance(probe, list)
                and len(probe) == 384
                and all(x == 0.0 for x in probe)
            )
        except Exception as exc:
            logger.debug("Embedding model probe failed: %s", exc)
            return False

    def _enforce_size_immediately(self, text: str) -> List[str]:
        """Greedily pack whitespace-separated words into size-limited chunks.

        Words are re-joined with single spaces. A word longer than
        ``chunk_size`` is cut into ``chunk_size``-character slices so that no
        returned chunk exceeds the limit.

        Args:
            text: Text to split.

        Returns:
            The chunks in order; an empty list for blank input.
        """
        if not text.strip():
            return []

        limit = self.chunk_size
        chunks: List[str] = []
        current: List[str] = []
        current_len = 0  # always equals len(" ".join(current))

        for word in text.split():
            # Hard-split tokens that could never fit into a chunk on their own.
            # The limit > 0 guard avoids an endless loop on an invalid size.
            while limit > 0 and len(word) > limit:
                if current:
                    chunks.append(" ".join(current))
                    current, current_len = [], 0
                chunks.append(word[:limit])
                word = word[limit:]

            joined_len = current_len + 1 + len(word) if current else len(word)
            if joined_len <= limit:
                current.append(word)
                current_len = joined_len
            else:
                if current:
                    chunks.append(" ".join(current))
                current = [word]
                current_len = len(word)

        if current:
            chunks.append(" ".join(current))
        return chunks

    def get_semantic_chunks(self, documents: List[Document]) -> List[Document]:
        """Split documents and merge adjacent pieces that are similar.

        Each piece is compared, by cosine similarity, with the embedding of
        the first piece of the current group. It joins the group when the
        similarity reaches ``similarity_threshold`` and the space-joined text
        still fits in ``chunk_size``; otherwise the group is finalized and the
        piece starts a new one.

        Args:
            documents: Documents to chunk.

        Returns:
            The semantic chunks; an empty list when there is nothing to chunk.
            If any step fails, the error is logged and ``documents`` is
            returned unchanged.
        """
        if not documents:
            logger.warning("No documents provided for semantic chunking")
            return []

        try:
            base_chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Initial splitting created {len(base_chunks)} base chunks")
            if not base_chunks:
                return []

            chunk_contents = [doc.page_content for doc in base_chunks]
            # Coerce to an array so list-returning embedders work as well.
            chunk_embeddings = np.asarray(self.sentence_model.encode(chunk_contents))
            if chunk_embeddings.ndim != 2 or chunk_embeddings.shape[0] != len(base_chunks):
                raise ValueError(
                    f"Expected one embedding vector per chunk ({len(base_chunks)}), "
                    f"got array of shape {chunk_embeddings.shape}"
                )

            grouped_chunks: List[Document] = []
            current_group = [base_chunks[0]]
            # Embedding of the group's first piece; later pieces are compared to it.
            current_embedding = chunk_embeddings[0].reshape(1, -1)
            # Length of the group's contents joined with single spaces.
            current_length = len(base_chunks[0].page_content)

            for base_chunk, embedding in zip(base_chunks[1:], chunk_embeddings[1:]):
                candidate = embedding.reshape(1, -1)
                similarity = cosine_similarity(current_embedding, candidate)[0][0]
                combined_length = current_length + 1 + len(base_chunk.page_content)

                if similarity >= self.similarity_threshold and combined_length <= self.chunk_size:
                    current_group.append(base_chunk)
                    current_length = combined_length
                else:
                    grouped_chunks.extend(self._finalize_chunk_group(current_group))
                    current_group = [base_chunk]
                    current_embedding = candidate
                    current_length = len(base_chunk.page_content)

            grouped_chunks.extend(self._finalize_chunk_group(current_group))
            logger.info(f"Created {len(grouped_chunks)} semantic chunks")
            return grouped_chunks
        except Exception as e:
            # Deliberate best effort: hand back the input rather than fail.
            logger.error(f"Error in semantic chunking: {e}")
            return documents

    def _finalize_chunk_group(self, group: List[Document]) -> List[Document]:
        """Turn a group of pieces into size-limited documents with metadata.

        The group's text is joined with spaces, re-split to respect
        ``chunk_size``, and every resulting chunk receives a copy of the first
        piece's metadata plus position and text statistics.

        Args:
            group: Adjacent pieces that belong together.

        Returns:
            One ``Document`` per size-limited chunk; empty for an empty group.
        """
        if not group:
            return []

        content = " ".join(doc.page_content for doc in group)
        size_limited_chunks = self._enforce_size_immediately(content)
        base_metadata = group[0].metadata.copy()
        chunk_count = len(size_limited_chunks)

        processed_chunks = []
        for index, chunk in enumerate(size_limited_chunks, start=1):
            # analyze_text is not defined in this class; presumably inherited
            # from BaseChunker and returning the count keys read below.
            stats = self.analyze_text(chunk)
            metadata = base_metadata.copy()
            metadata.update({
                "chunk_index": index,
                "chunk_count": chunk_count,
                "char_count": stats["char_count"],
                "token_count": stats["token_count"],
                "sentence_count": stats["sentence_count"],
                "word_count": stats["word_count"],
                "chunk_type": "semantic"
            })
            processed_chunks.append(Document(page_content=chunk, metadata=metadata))
        return processed_chunks

    def semantic_process_document(self, file_path: str, preprocess: bool = False) -> List[Document]:
        """Load a file, optionally preprocess its text, and chunk it.

        Args:
            file_path: Path handed to ``self.load_document``.
            preprocess: When True, each document's text is passed through
                ``self.preprocess_text`` before chunking.

        Returns:
            The semantic chunks produced by ``get_semantic_chunks``.

        Raises:
            Exception: Any error from loading or preprocessing is logged and
                re-raised.
        """
        try:
            logger.info(f"Processing document with semantic chunking: {file_path}")
            # load_document / preprocess_text are not defined in this class;
            # presumably provided by BaseChunker.
            raw_documents = self.load_document(file_path)

            processed_documents = []
            for doc in raw_documents:
                content = doc.page_content
                if preprocess:
                    content = self.preprocess_text(content)
                processed_documents.append(Document(
                    page_content=content,
                    metadata=doc.metadata
                ))

            documents = self.get_semantic_chunks(processed_documents)
            logger.info(f"Created {len(documents)} semantic chunks")
            return documents
        except Exception as e:
            logger.error(f"Error in semantic_process_document: {e}")
            raise

    def process_document(self, file_path: str, preprocess: bool = True) -> List[Document]:
        """Process a file with semantic chunking; preprocessing is on by default."""
        return self.semantic_process_document(file_path, preprocess)