import logging
import re
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger(__name__)

def _split_into_sentences(text: str) -> List[str]:
    """
    Improved sentence splitting that handles common edge cases such as
    abbreviations, which would otherwise trigger spurious breaks.
    """
    abbreviations = [
        'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Sr', 'Jr', 'vs', 'etc', 'Inc', 'Ltd', 'Corp',
        'U.S', 'U.K', 'U.N', 'E.U', 'NASA', 'FBI', 'CIA', 'GDP', 'CEO', 'CFO', 'CTO'
    ]

    # Temporarily replace each abbreviation's trailing period with a placeholder
    # so it does not look like a sentence boundary. Matching is case-sensitive:
    # a case-insensitive match would restore the canonical casing on the way
    # back and silently alter the input text.
    protected_text = text
    replacements = {}
    for i, abbr in enumerate(abbreviations):
        placeholder = f"__ABBR_{i}__"
        protected_text = re.sub(rf'\b{re.escape(abbr)}\.', placeholder, protected_text)
        replacements[placeholder] = f"{abbr}."

    # Split on whitespace that follows terminal punctuation. The lookbehind
    # keeps the punctuation attached to each sentence instead of discarding it.
    sentences = re.split(r'(?<=[.!?])\s+', protected_text)

    # Restore the protected abbreviations and drop empty fragments.
    cleaned_sentences = []
    for sentence in sentences:
        if sentence.strip():
            for placeholder, original in replacements.items():
                sentence = sentence.replace(placeholder, original)
            cleaned_sentences.append(sentence.strip())

    return cleaned_sentences

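# Illustrative behavior of _split_into_sentences (hypothetical input): the
# protected "Dr." survives the split, and terminal punctuation stays attached.
#
#   _split_into_sentences("Dr. Smith arrived. He was late!")
#   -> ["Dr. Smith arrived.", "He was late!"]
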
def _calculate_rolling_similarity(embeddings: np.ndarray, window_size: int = 3) -> List[float]:
    """
    Calculate a rolling average of adjacent-sentence similarities to smooth
    out noise and capture broader semantic shifts.
    """
    # Cosine similarity between each sentence embedding and its predecessor.
    similarities = []
    for i in range(1, len(embeddings)):
        current_sim = cosine_similarity(
            embeddings[i].reshape(1, -1),
            embeddings[i - 1].reshape(1, -1)
        )[0, 0]
        similarities.append(current_sim)

    # Too few points to smooth meaningfully; return the raw similarities.
    if len(similarities) <= window_size:
        return similarities

    # Centered moving average, with the window clamped at the boundaries.
    smoothed = []
    for i in range(len(similarities)):
        start_idx = max(0, i - window_size // 2)
        end_idx = min(len(similarities), i + window_size // 2 + 1)
        window_similarities = similarities[start_idx:end_idx]
        smoothed.append(float(np.mean(window_similarities)))

    return smoothed

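# Worked example (hypothetical values) for _calculate_rolling_similarity: with
# raw similarities [0.9, 0.2, 0.8, 0.85] and window_size=3, each point is
# averaged over its clamped neighborhood:
#
#   i=0: mean(0.9, 0.2)       = 0.550
#   i=1: mean(0.9, 0.2, 0.8)  ~ 0.633
#   i=2: mean(0.2, 0.8, 0.85) ~ 0.617
#   i=3: mean(0.8, 0.85)      = 0.825
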
def _adaptive_threshold(similarities: List[float], base_threshold: float = 0.55) -> float:
    """
    Dynamically adjust the threshold based on the distribution of similarities
    in the text.
    """
    if not similarities:
        return base_threshold

    mean_sim = np.mean(similarities)
    std_sim = np.std(similarities)

    # For highly cohesive text (high mean similarity), raise the threshold so
    # that only pronounced dips count as topic breaks; never drop below the
    # caller-supplied base threshold.
    adjusted_threshold = max(
        base_threshold,
        mean_sim - (0.5 * std_sim)
    )

    # Cap the threshold so extremely uniform text still produces some breaks.
    return min(adjusted_threshold, 0.8)

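# Worked example (hypothetical values) for _adaptive_threshold: with mean
# similarity 0.70 and standard deviation 0.10, the candidate threshold is
# 0.70 - 0.5 * 0.10 = 0.65; that exceeds the 0.55 base and sits below the
# 0.8 cap, so 0.65 is used.
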
def semantic_chunker(
    text: str,
    model: SentenceTransformer,
    similarity_threshold: float = 0.55,
    min_chunk_size: int = 50,
    max_chunk_size: int = 1000,
    adaptive_threshold_enabled: bool = True
) -> List[str]:
    """
    Enhanced semantic chunking with improved sentence splitting, adaptive
    thresholding, and chunk size controls.

    Args:
        text: Input text to chunk
        model: SentenceTransformer model for embeddings
        similarity_threshold: Base threshold for semantic breaks
        min_chunk_size: Minimum characters per chunk
        max_chunk_size: Maximum characters per chunk
        adaptive_threshold_enabled: Whether to use adaptive thresholding

    Returns:
        List of text chunks
    """
    logger.info("Starting enhanced semantic chunking...")

    if not text or not text.strip():
        logger.warning("Empty or whitespace-only text provided")
        return []

    sentences = _split_into_sentences(text)

    if len(sentences) <= 1:
        logger.info("Text contains only one sentence, returning as single chunk")
        return [text.strip()]

    logger.info(f"Split text into {len(sentences)} sentences")

    try:
        embeddings = model.encode(sentences, convert_to_numpy=True, show_progress_bar=False)
        logger.info("Generated sentence embeddings")
    except Exception as e:
        logger.error(f"Failed to generate embeddings: {e}")
        # Fall back to returning the whole text as a single chunk rather than
        # propagating the failure to the caller.
        return [text.strip()]

    similarities = _calculate_rolling_similarity(embeddings)

    if not similarities:
        return [text.strip()]

    if adaptive_threshold_enabled:
        threshold = _adaptive_threshold(similarities, similarity_threshold)
        logger.info(f"Adjusted threshold from {similarity_threshold:.3f} to {threshold:.3f}")
    else:
        threshold = similarity_threshold

    chunks = []
    current_chunk_sentences = [sentences[0]]
    current_chunk_length = len(sentences[0])

    for i, similarity in enumerate(similarities):
        # similarities[i] compares sentences[i + 1] with its predecessor.
        sentence_idx = i + 1
        sentence = sentences[sentence_idx]
        sentence_length = len(sentence)

        # Break on a semantic shift, or when adding this sentence would exceed
        # the size cap.
        should_break = similarity < threshold or (
            current_chunk_length + sentence_length > max_chunk_size
        )

        if should_break and current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            # Emit the chunk only if it meets the minimum size (the first chunk
            # is always emitted); otherwise keep accumulating sentences.
            if len(chunk_text) >= min_chunk_size or not chunks:
                chunks.append(chunk_text)
                current_chunk_sentences = []
                current_chunk_length = 0

        current_chunk_sentences.append(sentence)
        current_chunk_length += sentence_length + 1  # +1 for the joining space

    if current_chunk_sentences:
        final_chunk = " ".join(current_chunk_sentences)
        # Merge an undersized trailing chunk into its predecessor.
        if len(final_chunk) < min_chunk_size and chunks:
            chunks[-1] = chunks[-1] + " " + final_chunk
        else:
            chunks.append(final_chunk)

    # Post-pass: re-split any chunk that still exceeds max_chunk_size (e.g. one
    # grown by the undersized-trailing-chunk merge) at sentence boundaries.
    final_chunks = []
    for chunk in chunks:
        if len(chunk) <= max_chunk_size:
            final_chunks.append(chunk)
        else:
            chunk_sentences = _split_into_sentences(chunk)
            temp_chunk = ""
            for sent in chunk_sentences:
                if len(temp_chunk) + len(sent) <= max_chunk_size:
                    temp_chunk += (" " + sent) if temp_chunk else sent
                else:
                    if temp_chunk:
                        final_chunks.append(temp_chunk)
                    # Note: a single sentence longer than max_chunk_size is
                    # kept intact rather than split mid-sentence.
                    temp_chunk = sent
            if temp_chunk:
                final_chunks.append(temp_chunk)

logger.info(f"Enhanced semantic chunking resulted in {len(final_chunks)} chunks") |
|
|
|
|
|
|
|
|
if final_chunks: |
|
|
chunk_lengths = [len(chunk) for chunk in final_chunks] |
|
|
logger.debug(f"Chunk length stats - Min: {min(chunk_lengths)}, " |
|
|
f"Max: {max(chunk_lengths)}, Mean: {np.mean(chunk_lengths):.1f}") |
|
|
|
|
|
return final_chunks |
|
|
|
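
# Minimal usage sketch when the module is run directly. The checkpoint name
# below is an illustrative assumption; any SentenceTransformer model works.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed example checkpoint
    sample = (
        "Transformers changed NLP. Attention lets models weigh context directly. "
        "Meanwhile, the U.S. economy grew last quarter. GDP figures beat forecasts."
    )
    for n, chunk in enumerate(semantic_chunker(sample, model), start=1):
        print(f"--- chunk {n} ---")
        print(chunk)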