# core/chunking.py
import logging
import re
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger(__name__)


def _split_into_sentences(text: str) -> List[str]:
    """
    Improved sentence splitting that handles common edge cases.
    """
    # Handle common abbreviations that shouldn't cause splits
    abbreviations = [
        'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Sr', 'Jr', 'vs', 'etc', 'Inc', 'Ltd',
        'Corp', 'U.S', 'U.K', 'U.N', 'E.U', 'NASA', 'FBI', 'CIA', 'GDP', 'CEO',
        'CFO', 'CTO',
    ]

    # Temporarily replace abbreviations to protect them from splitting
    protected_text = text
    replacements = {}
    for i, abbr in enumerate(abbreviations):
        placeholder = f"__ABBR_{i}__"
        protected_text = re.sub(
            rf'\b{re.escape(abbr)}\.', placeholder, protected_text, flags=re.IGNORECASE
        )
        replacements[placeholder] = f"{abbr}."

    # Split after sentence-ending punctuation followed by whitespace, keeping the
    # punctuation attached to its sentence so reassembled chunks read like the input
    sentences = re.split(r'(?<=[.!?])\s+', protected_text)

    # Restore abbreviations and clean up
    cleaned_sentences = []
    for sentence in sentences:
        if sentence.strip():
            # Restore abbreviations
            for placeholder, original in replacements.items():
                sentence = sentence.replace(placeholder, original)
            cleaned_sentences.append(sentence.strip())

    return cleaned_sentences


def _calculate_rolling_similarity(embeddings: np.ndarray, window_size: int = 3) -> List[float]:
    """
    Calculate rolling average similarity to smooth out noise and capture
    broader semantic shifts.
    """
    similarities = []
    for i in range(1, len(embeddings)):
        # Calculate similarity between the current and previous sentence
        current_sim = cosine_similarity(
            embeddings[i].reshape(1, -1),
            embeddings[i - 1].reshape(1, -1)
        )[0, 0]
        similarities.append(current_sim)

    # Apply a rolling average to smooth the similarities
    if len(similarities) <= window_size:
        return similarities

    smoothed = []
    for i in range(len(similarities)):
        start_idx = max(0, i - window_size // 2)
        end_idx = min(len(similarities), i + window_size // 2 + 1)
        window_similarities = similarities[start_idx:end_idx]
        smoothed.append(np.mean(window_similarities))

    return smoothed


def _adaptive_threshold(similarities: List[float], base_threshold: float = 0.55) -> float:
    """
    Dynamically adjust the threshold based on the distribution of
    similarities in the text.
    """
    if not similarities:
        return base_threshold

    mean_sim = np.mean(similarities)
    std_sim = np.std(similarities)

    # Adjust the threshold based on text characteristics:
    # if similarities are generally high, use a higher threshold;
    # if similarities vary a lot, be more conservative.
    adjusted_threshold = max(base_threshold, mean_sim - (0.5 * std_sim))

    return min(adjusted_threshold, 0.8)  # Cap at 0.8 to avoid over-splitting


def semantic_chunker(
    text: str,
    model: SentenceTransformer,
    similarity_threshold: float = 0.55,
    min_chunk_size: int = 50,
    max_chunk_size: int = 1000,
    adaptive_threshold_enabled: bool = True,
) -> List[str]:
    """
    Enhanced semantic chunking with improved sentence splitting,
    adaptive thresholding, and chunk size controls.

    Args:
        text: Input text to chunk
        model: SentenceTransformer model for embeddings
        similarity_threshold: Base threshold for semantic breaks
        min_chunk_size: Minimum characters per chunk
        max_chunk_size: Maximum characters per chunk
        adaptive_threshold_enabled: Whether to use adaptive thresholding

    Returns:
        List of text chunks
    """
    logger.info("Starting enhanced semantic chunking...")

    if not text or not text.strip():
        logger.warning("Empty or whitespace-only text provided")
        return []

    # Improved sentence splitting
    sentences = _split_into_sentences(text)
    if len(sentences) <= 1:
        logger.info("Text contains only one sentence, returning as single chunk")
        return [text.strip()]

    logger.info(f"Split text into {len(sentences)} sentences")

    try:
        # Generate embeddings with error handling
        embeddings = model.encode(sentences, convert_to_numpy=True, show_progress_bar=False)
        logger.info("Generated sentence embeddings")
    except Exception as e:
        logger.error(f"Failed to generate embeddings: {e}")
        # Fall back to returning the whole text if embedding fails
        return [text.strip()]

    # Calculate smoothed similarities
    similarities = _calculate_rolling_similarity(embeddings)
    if not similarities:
        return [text.strip()]

    # Adaptive threshold adjustment
    if adaptive_threshold_enabled:
        threshold = _adaptive_threshold(similarities, similarity_threshold)
        logger.info(f"Adjusted threshold from {similarity_threshold:.3f} to {threshold:.3f}")
    else:
        threshold = similarity_threshold

    # Enhanced chunking with size constraints
    chunks = []
    current_chunk_sentences = [sentences[0]]
    current_chunk_length = len(sentences[0])

    for i, similarity in enumerate(similarities):
        sentence_idx = i + 1  # similarities[i] compares sentences[i+1] with sentences[i]
        sentence = sentences[sentence_idx]
        sentence_length = len(sentence)

        # Start a new chunk on a semantic break, or when adding this sentence
        # would push the current chunk past the maximum size
        should_break = (
            similarity < threshold
            or current_chunk_length + sentence_length > max_chunk_size
        )

        # If we decide to break, finalize the current chunk
        if should_break and current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            # Only emit the chunk if it meets the minimum size (or is the first one);
            # otherwise keep its sentences so they merge with the next chunk
            if len(chunk_text) >= min_chunk_size or not chunks:
                chunks.append(chunk_text)
                current_chunk_sentences = []
                current_chunk_length = 0

        # Add the current sentence to the chunk under construction
        current_chunk_sentences.append(sentence)
        current_chunk_length += sentence_length + 1  # +1 for the joining space

    # Handle the final chunk
    if current_chunk_sentences:
        final_chunk = " ".join(current_chunk_sentences)
        # If the final chunk is too small, merge it with the previous chunk
        if len(final_chunk) < min_chunk_size and chunks:
            chunks[-1] = chunks[-1] + " " + final_chunk
        else:
            chunks.append(final_chunk)

    # Post-processing: ensure no chunks are too large
    final_chunks = []
    for chunk in chunks:
        if len(chunk) <= max_chunk_size:
            final_chunks.append(chunk)
        else:
            # Split oversized chunks at sentence boundaries
            chunk_sentences = _split_into_sentences(chunk)
            temp_chunk = ""
            for sent in chunk_sentences:
                if len(temp_chunk) + len(sent) <= max_chunk_size:
                    temp_chunk += (" " + sent) if temp_chunk else sent
                else:
                    if temp_chunk:
                        final_chunks.append(temp_chunk)
                    temp_chunk = sent
            if temp_chunk:
                final_chunks.append(temp_chunk)

    logger.info(f"Enhanced semantic chunking resulted in {len(final_chunks)} chunks")

    # Log chunk statistics for debugging
    if final_chunks:
        chunk_lengths = [len(chunk) for chunk in final_chunks]
        logger.debug(
            f"Chunk length stats - Min: {min(chunk_lengths)}, "
            f"Max: {max(chunk_lengths)}, Mean: {np.mean(chunk_lengths):.1f}"
        )

    return final_chunks
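

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public API: it shows how
    # semantic_chunker might be driven end to end. The "all-MiniLM-L6-v2"
    # checkpoint and the sample text are assumptions for illustration; any
    # SentenceTransformer model can be passed in.
    logging.basicConfig(level=logging.INFO)

    sample_text = (
        "Transformers changed natural language processing. Attention lets models "
        "weigh context directly. Meanwhile, cooking pasta requires salted water. "
        "Fresh basil improves almost any tomato sauce."
    )

    st_model = SentenceTransformer("all-MiniLM-L6-v2")
    for idx, chunk in enumerate(semantic_chunker(sample_text, st_model, min_chunk_size=20)):
        print(f"[chunk {idx}] {chunk}")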