import re
from typing import List, Tuple, AsyncIterator, Optional, Union, Dict, Any
import pysbd
from loguru import logger
from langdetect import detect
from enum import Enum
from dataclasses import dataclass
# Constants for additional checks
# Comma-like clause separators from several scripts. Order matters only as a
# tie-breaker for callers that scan the list front to back.
# Entries that are easily flattened to ASCII by text normalisation are written
# as escapes so they survive copy/paste and re-encoding.
COMMAS = [
    ",",
    "،",
    "\uff0c",  # fullwidth comma, the usual comma in Chinese/Japanese text
    "、",
    "፣",
    "၊",
    ";",
    "΄",
    "‛",
    "।",
    "﹐",
    "꓾",
    "⹁",
    "︐",
    "﹑",
]
# Sentence terminators. The fullwidth forms are written as escapes so they are
# not silently flattened to their ASCII look-alikes.
END_PUNCTUATIONS = [
    ".",
    "!",
    "?",
    "。",
    "\uff01",  # fullwidth exclamation mark
    "\uff1f",  # fullwidth question mark
    "...",
    "。。。",
]
# Dotted abbreviations whose trailing "." must not be read as a sentence end.
ABBREVIATIONS = [
    "Mr.",
    "Mrs.",
    "Dr.",
    "Prof.",
    "Inc.",
    "Ltd.",
    "Jr.",
    "Sr.",
    "e.g.",
    "i.e.",
    "vs.",
    "St.",
    "Rd.",
]
# Language codes that are handed to pysbd; anything else is segmented with the
# regex fallback.
SUPPORTED_LANGUAGES = {
    "am", "ar", "bg", "da", "de", "el", "en", "es",
    "fa", "fr", "hi", "hy", "it", "ja", "kk", "mr",
    "my", "nl", "pl", "ru", "sk", "ur", "zh"
}
def detect_language(text: str) -> Optional[str]:
    """
    Detect the language of ``text`` and map it onto a pysbd language code.

    Args:
        text: Text whose language should be detected

    Returns:
        Optional[str]: The language code if it is in SUPPORTED_LANGUAGES,
        otherwise None (also when detection itself fails).
    """
    try:
        detected = detect(text)
    except Exception as e:
        # Detection is best-effort: very short or symbol-only input makes
        # langdetect raise, and the caller falls back to regex segmentation.
        logger.debug(f"Language detection failed, language not supported by pysdb: {e}")
        return None

    # langdetect reports regional variants (e.g. "zh-cn", "zh-tw") while the
    # supported set uses bare language codes, so compare on the primary subtag.
    primary = detected.split("-", 1)[0].lower()
    return primary if primary in SUPPORTED_LANGUAGES else None
def is_complete_sentence(text: str) -> bool:
    """
    Tell whether ``text`` finishes with a sentence terminator.

    Surrounding whitespace is ignored. Text that ends with one of the known
    ABBREVIATIONS is never considered complete, even though it ends with ".".

    Args:
        text: Text to check
    Returns:
        bool: True when the text ends a sentence, False otherwise
    """
    trimmed = text.strip()
    if not trimmed:
        return False
    # The abbreviation test must run first: "Dr." also ends with ".".
    if trimmed.endswith(tuple(ABBREVIATIONS)):
        return False
    return trimmed.endswith(tuple(END_PUNCTUATIONS))
def contains_comma(text: str) -> bool:
    """
    Report whether any of the COMMAS characters occurs in ``text``.

    Args:
        text: Text to check
    Returns:
        bool: True if at least one comma-like character is present
    """
    for mark in COMMAS:
        if mark in text:
            return True
    return False
def comma_splitter(text: str) -> Tuple[str, str]:
    """
    Split ``text`` at the comma that occurs earliest in it.

    Args:
        text: Text to split
    Returns:
        Tuple[str, str]: (text up to and including the comma, remaining text).
        Both parts are whitespace-stripped. When ``text`` is empty the result
        is ("", ""); when it holds no comma the result is (text, "").
    """
    if not text:
        return "", ""

    # Pick the comma with the smallest index in the text, not the first entry
    # of COMMAS that happens to occur somewhere in it.
    split_at = -1
    separator = ""
    for comma in COMMAS:
        position = text.find(comma)
        if position != -1 and (split_at == -1 or position < split_at):
            split_at = position
            separator = comma

    if split_at == -1:
        return text, ""

    head = text[:split_at].strip() + separator
    tail = text[split_at + len(separator):].strip()
    return head, tail
def has_punctuation(text: str) -> bool:
    """
    Report whether ``text`` contains any comma or sentence-ending punctuation.

    Args:
        text: Text to check
    Returns:
        bool: True if any entry of COMMAS or END_PUNCTUATIONS occurs in the text
    """
    return any(mark in text for mark in COMMAS + END_PUNCTUATIONS)
def contains_end_punctuation(text: str) -> bool:
    """
    Report whether a sentence terminator occurs anywhere in ``text``.

    Args:
        text: Text to check
    Returns:
        bool: True if at least one entry of END_PUNCTUATIONS is present
    """
    for terminator in END_PUNCTUATIONS:
        if terminator in text:
            return True
    return False
def _ends_in_abbreviation(text: str, end: int) -> bool:
    """Return True if the "." just before ``end`` belongs to a known abbreviation."""
    for abbrev in ABBREVIATIONS:
        # Abbreviations such as "e.g." contain several dots; try every
        # alignment that places one of them immediately before ``end``.
        for offset, char in enumerate(abbrev, start=1):
            if char != ".":
                continue
            start = end - offset
            if start >= 0 and text.startswith(abbrev, start):
                return True
    return False


def segment_text_by_regex(text: str) -> Tuple[List[str], str]:
    """
    Segment text into complete sentences using regex pattern matching.
    More efficient but less accurate than pysbd.

    A sentence ends at a run of END_PUNCTUATIONS entries (so "..." or "?!"
    stays attached to its sentence) unless that position belongs to one of the
    ABBREVIATIONS, in which case scanning simply continues past it.

    Args:
        text: Text to segment into sentences
    Returns:
        Tuple[List[str], str]: (list of complete sentences, remaining incomplete text)
    """
    if not text:
        return [], ""

    remaining_text = text.strip()
    # Longest terminators first so that "..." is preferred over ".".
    terminators = sorted(dict.fromkeys(END_PUNCTUATIONS), key=len, reverse=True)
    if not terminators:
        return [], remaining_text
    # An alternation is required here: inside a character class "|" would be a
    # literal and multi-character terminators would be split into characters.
    pattern = re.compile(
        "(?:" + "|".join(re.escape(mark) for mark in terminators) + ")+"
    )

    complete_sentences: List[str] = []
    search_from = 0
    while True:
        match = pattern.search(remaining_text, search_from)
        if match is None:
            break
        end_pos = match.end()
        if _ends_in_abbreviation(remaining_text, end_pos):
            # Not a real sentence end: keep the text and look further right.
            search_from = end_pos
            continue
        complete_sentences.append(remaining_text[:end_pos].strip())
        remaining_text = remaining_text[end_pos:].lstrip()
        search_from = 0
    return complete_sentences, remaining_text
def segment_text_by_pysbd(text: str) -> Tuple[List[str], str]:
    """
    Split text into finished sentences plus an unfinished tail.

    pysbd does the work when the detected language is supported; otherwise,
    and whenever anything in the pysbd path raises, the regex segmenter is
    used instead.

    Args:
        text: Text to segment into sentences
    Returns:
        Tuple[List[str], str]: (list of complete sentences, remaining incomplete text)
    """
    if not text:
        return [], ""
    try:
        lang = detect_language(text)
        if lang is None:
            # Unsupported or undetectable language
            return segment_text_by_regex(text)

        segments = pysbd.Segmenter(language=lang, clean=False).segment(text)
        if not segments:
            return [], text

        # Every segment except the final one is followed by more text, so it
        # is treated as finished; blank segments are skipped.
        *leading, tail = segments
        complete_sentences = [piece.strip() for piece in leading if piece.strip()]

        # The final segment only counts once it carries a terminator.
        tail = tail.strip()
        if is_complete_sentence(tail):
            complete_sentences.append(tail)
            remaining = ""
        else:
            remaining = tail

        logger.debug(
            f"Processed sentences: {complete_sentences}, Remaining: {remaining}"
        )
        return complete_sentences, remaining
    except Exception as e:
        logger.error(f"Error in sentence segmentation: {e}")
        # Fallback to regex on any error
        return segment_text_by_regex(text)
class TagState(Enum):
    """Kind of tag event attached to a piece of emitted text."""

    START = "start"  # an opening tag, e.g. <tag>
    INSIDE = "inside"  # text between tags
    END = "end"  # a closing tag
    SELF_CLOSING = "self"  # a self-closing tag, e.g. <tag/>
    NONE = "none"  # no tag
@dataclass
class TagInfo:
    """A tag name paired with the state it was observed in."""

    name: str
    state: TagState

    def __str__(self) -> str:
        """Render as "name:state", or just "none" for the no-tag state."""
        untagged = self.state == TagState.NONE
        return "none" if untagged else f"{self.name}:{self.state.value}"
@dataclass
class SentenceWithTags:
    """A piece of emitted text together with its tag information, supporting nested tags"""

    text: str  # the emitted text: a sentence/fragment, or the tag markup itself
    tags: List[TagInfo]  # List of tags from outermost to innermost
class SentenceDivider:
    """
    Split a stream of text chunks into sentences annotated with tag state.

    Incoming strings are accumulated in an internal buffer. As soon as the
    buffer holds a complete sentence, a recognised tag (``<name>``,
    ``</name>`` or ``<name/>`` for every name in ``valid_tags``) or - for the
    very first sentence only, when ``faster_first_response`` is set - a
    comma-terminated clause, that piece is emitted as a ``SentenceWithTags``
    and removed from the buffer.
    """

    def __init__(
        self,
        faster_first_response: bool = True,
        segment_method: str = "pysbd",
        valid_tags: Optional[List[str]] = None,
    ):
        """
        Initialize the SentenceDivider.
        Args:
            faster_first_response: Whether to split first sentence at commas
            segment_method: Method for segmenting sentences ("regex" selects
                the regex segmenter, anything else selects pysbd)
            valid_tags: List of valid tag names to detect (defaults to ["think"])
        """
        self.faster_first_response = faster_first_response
        self.segment_method = segment_method
        self.valid_tags = valid_tags or ["think"]
        self._is_first_sentence = True
        self._buffer = ""
        # Stack of currently open tags, outermost first, to handle nesting
        self._tag_stack: List[TagInfo] = []
        # Texts emitted by process_stream; created here so that
        # complete_response is usable before the first stream is processed.
        self._full_response: List[str] = []

    def _get_current_tags(self) -> List[TagInfo]:
        """
        Get all current active tags from outermost to innermost.
        Returns:
            List[TagInfo]: Fresh INSIDE-state entries, one per open tag
        """
        return [TagInfo(tag.name, TagState.INSIDE) for tag in self._tag_stack]

    def _get_current_tag(self) -> Optional[TagInfo]:
        """
        Get the current innermost active tag.
        Returns:
            TagInfo if there's an active tag, None otherwise
        """
        return self._tag_stack[-1] if self._tag_stack else None

    def _tagged(self, text: str) -> SentenceWithTags:
        """Wrap stripped ``text`` with the currently open tags (or a NONE tag)."""
        return SentenceWithTags(
            text=text.strip(),
            tags=self._get_current_tags() or [TagInfo("", TagState.NONE)],
        )

    def _find_first_tag(self, text: str) -> Optional[Tuple[int, int, str, TagState]]:
        """
        Locate the earliest valid tag in ``text``.
        Returns:
            (start index, end index, tag name, tag state) or None if no tag
            occurs. Tag names are matched literally, never as regex.
        """
        found = None
        # Same precedence as the scan order: self-closing, opening, closing.
        for prefix, suffix, state in (
            ("<", "/>", TagState.SELF_CLOSING),
            ("<", ">", TagState.START),
            ("</", ">", TagState.END),
        ):
            for tag in self.valid_tags:
                markup = prefix + tag + suffix
                position = text.find(markup)
                if position != -1 and (found is None or position < found[0]):
                    found = (position, position + len(markup), tag, state)
        return found

    def _extract_tag(self, text: str) -> Tuple[Optional[TagInfo], str]:
        """
        Extract the first tag from text if present.
        Handles nested tags by maintaining a tag stack: an opening tag is
        pushed, a closing tag pops the stack when it matches the innermost
        open tag and is only logged as a warning otherwise.

        Anything in ``text`` before the tag is discarded, so callers are
        expected to have consumed it already.
        Args:
            text: Text to check for tags
        Returns:
            Tuple of (TagInfo if tag found else None, remaining text)
        """
        found = self._find_first_tag(text)
        if found is None:
            return None, text

        _, tag_end, name, state = found
        if state is TagState.START:
            self._tag_stack.append(TagInfo(name, TagState.START))
        elif state is TagState.END:
            if self._tag_stack and self._tag_stack[-1].name == name:
                self._tag_stack.pop()
            else:
                logger.warning(f"Mismatched closing tag: {name}")
        return TagInfo(name, state), text[tag_end:].lstrip()

    async def _process_buffer(self) -> AsyncIterator[SentenceWithTags]:
        """
        Process the current buffer, yielding complete sentences with tags.
        Everything that is yielded is removed from self._buffer; text that is
        not yet a complete sentence stays buffered for the next call.
        """
        while self._buffer.strip():
            found = self._find_first_tag(self._buffer)

            if found is not None:
                tag_start = found[0]
                text_before_tag = self._buffer[:tag_start]

                if text_before_tag.strip():
                    # The tag is a hard boundary, so everything in front of it
                    # is emitted now: complete sentences first, then whatever
                    # unfinished fragment is left between them and the tag.
                    if contains_end_punctuation(text_before_tag):
                        sentences, leftover = self._segment_text(text_before_tag)
                    else:
                        sentences, leftover = [], text_before_tag
                    for piece in [*sentences, leftover]:
                        if piece.strip():
                            yield self._tagged(piece)
                    self._buffer = self._buffer[tag_start:]
                    continue

                # Only whitespace (or nothing) precedes the tag: emit the tag.
                tag_info, remaining = self._extract_tag(self._buffer)
                tag_text = self._buffer[: len(self._buffer) - len(remaining)].strip()
                yield SentenceWithTags(text=tag_text, tags=[tag_info])
                self._buffer = remaining
                continue

            # No tag in the buffer: plain text handling.
            # Handle first sentence with comma if enabled
            if (
                self._is_first_sentence
                and self.faster_first_response
                and contains_comma(self._buffer)
            ):
                sentence, remaining = comma_splitter(self._buffer)
                if sentence.strip():
                    yield self._tagged(sentence)
                    self._buffer = remaining
                    self._is_first_sentence = False
                    continue

            # Process normal sentences based on end punctuation
            if contains_end_punctuation(self._buffer):
                sentences, remaining = self._segment_text(self._buffer)
                if sentences:  # Only consume if segmentation yielded sentences
                    self._buffer = remaining
                    self._is_first_sentence = False
                    for sentence in sentences:
                        if sentence.strip():
                            yield self._tagged(sentence)
                    continue

            # Nothing could be consumed; wait for more input.
            break

    async def _flush_buffer(self) -> AsyncIterator[SentenceWithTags]:
        """
        Process and yield all remaining content in the buffer at the end of the stream.
        """
        logger.debug(f"Flushing remaining buffer: '{self._buffer}'")
        # First, run _process_buffer to yield any standard sentences/tags
        async for sentence in self._process_buffer():
            yield sentence
        # Whatever is still buffered has no terminator; emit it as a final fragment
        if self._buffer.strip():
            logger.debug(
                f"Yielding final fragment from buffer: '{self._buffer.strip()}'"
            )
            yield self._tagged(self._buffer)
        self._buffer = ""  # Clear buffer after flushing

    async def process_stream(
        self, segment_stream: AsyncIterator[Union[str, Dict[str, Any]]]
    ) -> AsyncIterator[Union[SentenceWithTags, Dict[str, Any]]]:
        """
        Process a stream of tokens (strings) and dictionaries.
        Yields complete sentences with tags (SentenceWithTags) or dictionaries directly.
        State is reset at the start, so each call handles one independent stream.
        Items of any other type are logged and skipped.
        Args:
            segment_stream: An async iterator yielding strings or dictionaries.
        Yields:
            Union[SentenceWithTags, Dict[str, Any]]: Complete sentences/tags or original dictionaries.
        """
        self._full_response = []
        self.reset()  # Ensure state is clean
        async for item in segment_stream:
            if isinstance(item, dict):
                # Before yielding the dict, emit any complete sentences formed so far
                async for sentence in self._process_buffer():
                    self._full_response.append(sentence.text)
                    yield sentence
                yield item
            elif isinstance(item, str):
                self._buffer += item
                # Process the buffer incrementally as string chunks arrive
                async for sentence in self._process_buffer():
                    self._full_response.append(sentence.text)
                    yield sentence
            else:
                logger.warning(
                    f"SentenceDivider received unexpected type: {type(item)}"
                )
        # After the stream finishes, flush any remaining text in the buffer
        async for sentence in self._flush_buffer():
            self._full_response.append(sentence.text)
            yield sentence

    @property
    def complete_response(self) -> str:
        """Get the complete response accumulated so far (emitted texts joined without separator)"""
        return "".join(self._full_response)

    def _segment_text(self, text: str) -> Tuple[List[str], str]:
        """Segment text using the configured method"""
        if self.segment_method == "regex":
            return segment_text_by_regex(text)
        return segment_text_by_pysbd(text)

    def reset(self):
        """Reset the divider state for a new conversation"""
        self._is_first_sentence = True
        self._buffer = ""
        self._tag_stack = []