| import re
|
| from typing import List, Tuple, AsyncIterator, Optional, Union, Dict, Any
|
| import pysbd
|
| from loguru import logger
|
| from langdetect import detect
|
| from enum import Enum
|
| from dataclasses import dataclass
|
|
|
|
|
# Comma-like separator characters. Consulted by contains_comma(),
# comma_splitter() and has_punctuation().
# NOTE(review): a few entries look identical to earlier ones as written here
# (for example "," and "،" each appear twice) — possibly width variants that
# were normalized somewhere along the way; confirm the intended characters.
COMMAS = [
    ",",
    "،",
    ",",
    "、",
    "፣",
    "၊",
    ";",
    "΄",
    "‛",
    "।",
    "﹐",
    "꓾",
    "⹁",
    "︐",
    "﹑",
    "、",
    "،",
]
|
|
|
# Strings treated as sentence terminators. Consulted by is_complete_sentence(),
# has_punctuation(), contains_end_punctuation() and segment_text_by_regex().
# NOTE(review): "!" and "?" each appear twice as written here — possibly
# full-width variants that were normalized; confirm the intended characters.
END_PUNCTUATIONS = [".", "!", "?", "。", "!", "?", "...", "。。。"]
|
# Abbreviations whose trailing period must not be read as the end of a
# sentence. Consulted by is_complete_sentence() and segment_text_by_regex().
# "Dr." is listed twice; the repeat has no effect on the suffix checks.
ABBREVIATIONS = [
    "Mr.",
    "Mrs.",
    "Dr.",
    "Prof.",
    "Inc.",
    "Ltd.",
    "Jr.",
    "Sr.",
    "e.g.",
    "i.e.",
    "vs.",
    "St.",
    "Rd.",
    "Dr.",
]
|
|
|
|
|
# Language codes that detect_language() accepts. Per that function's
# docstring these are the languages pysbd supports; any other detected
# language makes the caller fall back to regex segmentation.
SUPPORTED_LANGUAGES = {
    "am",
    "ar",
    "bg",
    "da",
    "de",
    "el",
    "en",
    "es",
    "fa",
    "fr",
    "hi",
    "hy",
    "it",
    "ja",
    "kk",
    "mr",
    "my",
    "nl",
    "pl",
    "ru",
    "sk",
    "ur",
    "zh",
}
|
|
|
|
|
def detect_language(text: str) -> Optional[str]:
    """
    Detect the language of the text and report it only if pysbd supports it.

    langdetect reports some languages with a region suffix (Chinese comes back
    as "zh-cn" or "zh-tw"), while SUPPORTED_LANGUAGES lists the bare code
    ("zh"). When the exact code is not supported, the primary subtag before
    the first "-" is tried as well.

    Args:
        text: Text whose language should be detected

    Returns:
        Optional[str]: A code from SUPPORTED_LANGUAGES, or None when the
        language is unsupported or detection fails
    """
    try:
        detected = detect(text)
    except Exception as e:
        # Best effort: a failed detection only means the caller falls back
        # to regex-based segmentation.
        logger.debug(f"Language detection failed: {e}")
        return None

    if detected in SUPPORTED_LANGUAGES:
        return detected

    # Drop a region suffix such as "-cn" and retry with the bare code.
    primary = detected.split("-", 1)[0]
    return primary if primary in SUPPORTED_LANGUAGES else None
|
|
|
|
|
def is_complete_sentence(text: str) -> bool:
    """
    Tell whether the text finishes with a sentence terminator.

    Surrounding whitespace is ignored. Text that ends with one of the known
    ABBREVIATIONS is reported as incomplete even though it ends with a period.

    Args:
        text: Text to check

    Returns:
        bool: True when the stripped text ends with an entry of
        END_PUNCTUATIONS and not with an abbreviation
    """
    stripped = text.strip()
    if not stripped:
        return False

    # str.endswith accepts a tuple and succeeds if any suffix matches.
    if stripped.endswith(tuple(ABBREVIATIONS)):
        return False

    return stripped.endswith(tuple(END_PUNCTUATIONS))
|
|
|
|
|
def contains_comma(text: str) -> bool:
    """
    Report whether any character from COMMAS occurs in the text.

    Args:
        text: Text to check

    Returns:
        bool: True if at least one comma-like character is present
    """
    for mark in COMMAS:
        if mark in text:
            return True
    return False
|
|
|
|
|
def comma_splitter(text: str) -> Tuple[str, str]:
    """
    Split the text at the first comma it contains.

    The comma that occurs earliest in the text is used, whichever kind of
    comma from COMMAS it is. The comma stays attached to the first part, and
    both parts are stripped of surrounding whitespace.

    Args:
        text: Text to split

    Returns:
        Tuple[str, str]: (text up to and including the comma, remaining text).
        When there is no comma the result is (text, ""); empty input gives
        ("", "").
    """
    if not text:
        # Keep the declared Tuple[str, str] shape for empty input.
        return "", ""

    # Find the comma with the smallest offset rather than the first kind
    # listed in COMMAS, so mixed punctuation splits at the true first comma.
    split_at = -1
    found = ""
    for comma in COMMAS:
        idx = text.find(comma)
        if idx != -1 and (split_at == -1 or idx < split_at):
            split_at, found = idx, comma

    if split_at == -1:
        return text, ""

    head = text[:split_at].strip() + found
    tail = text[split_at + len(found):].strip()
    return head, tail
|
|
|
|
|
def has_punctuation(text: str) -> bool:
    """
    Report whether the text contains any known punctuation mark.

    Both the comma-like characters in COMMAS and the sentence terminators in
    END_PUNCTUATIONS count.

    Args:
        text: Text to check

    Returns:
        bool: True if at least one punctuation mark occurs in the text
    """
    return any(mark in text for mark in COMMAS + END_PUNCTUATIONS)
|
|
|
|
|
def contains_end_punctuation(text: str) -> bool:
    """
    Report whether a sentence terminator occurs anywhere in the text.

    Args:
        text: Text to check

    Returns:
        bool: True if any entry of END_PUNCTUATIONS is found in the text
    """
    for terminator in END_PUNCTUATIONS:
        if terminator in text:
            return True
    return False
|
|
|
|
|
def _abbreviation_end(text: str, index: int) -> int:
    """Return the end offset of an abbreviation covering text[index], or -1."""
    for abbrev in ABBREVIATIONS:
        if not abbrev:
            continue
        # An occurrence starting at i covers `index` when
        # i <= index < i + len(abbrev); the find() bounds enforce exactly that.
        start = text.find(abbrev, max(0, index - len(abbrev) + 1), index + len(abbrev))
        if start != -1:
            return start + len(abbrev)
    return -1


def segment_text_by_regex(text: str) -> Tuple[List[str], str]:
    """
    Segment text into complete sentences using regex pattern matching.
    More efficient but less accurate than pysbd.

    A sentence ends at a run of one or more END_PUNCTUATIONS entries, so
    "..." or "?!" close a single sentence. A terminator that belongs to one
    of the ABBREVIATIONS does not end the sentence; the abbreviation stays
    part of the sentence being built.

    Args:
        text: Text to segment into sentences

    Returns:
        Tuple[List[str], str]: (list of complete sentences, remaining incomplete text)
    """
    if not text:
        return [], ""

    remaining_text = text.strip()

    # Longest terminators first so multi-character marks are tried before
    # their single-character prefixes.
    marks = sorted({p for p in END_PUNCTUATIONS if p}, key=lambda p: (-len(p), p))
    if not remaining_text or not marks:
        return [], remaining_text

    # Alternation (not a character class): each entry is matched literally.
    terminator = re.compile("(?:" + "|".join(re.escape(p) for p in marks) + ")+")

    complete_sentences: List[str] = []
    search_from = 0
    while remaining_text:
        match = terminator.search(remaining_text, search_from)
        if match is None:
            break

        skip_to = _abbreviation_end(remaining_text, match.start())
        if skip_to != -1:
            # Keep the abbreviation in the pending sentence and look for the
            # next terminator after it.
            search_from = skip_to
            continue

        complete_sentences.append(remaining_text[: match.end()].strip())
        remaining_text = remaining_text[match.end():].lstrip()
        search_from = 0

    return complete_sentences, remaining_text
|
|
|
|
|
def segment_text_by_pysbd(text: str) -> Tuple[List[str], str]:
    """
    Split text into finished sentences plus a trailing unfinished part.

    The language is detected first. Supported languages are segmented with
    pysbd; anything else, and any error along the way, is handled by
    segment_text_by_regex instead.

    Args:
        text: Text to segment into sentences

    Returns:
        Tuple[List[str], str]: (list of complete sentences, remaining incomplete text)
    """
    if not text:
        return [], ""

    try:
        lang = detect_language(text)
        if lang is None:
            # Language unknown or unsupported by pysbd.
            return segment_text_by_regex(text)

        sentences = pysbd.Segmenter(language=lang, clean=False).segment(text)
        if not sentences:
            return [], text

        # Every segment but the last is followed by another sentence, so it
        # is taken as complete; blank segments are skipped.
        stripped_head = (part.strip() for part in sentences[:-1])
        complete_sentences = [part for part in stripped_head if part]

        # The final segment only counts once it ends with a terminator.
        last_sent = sentences[-1].strip()
        remaining = ""
        if is_complete_sentence(last_sent):
            complete_sentences.append(last_sent)
        else:
            remaining = last_sent

        logger.debug(
            f"Processed sentences: {complete_sentences}, Remaining: {remaining}"
        )
        return complete_sentences, remaining

    except Exception as e:
        logger.error(f"Error in sentence segmentation: {e}")
        return segment_text_by_regex(text)
|
|
|
|
|
class TagState(Enum):
    """State of a tag in text"""

    # Opening markup of a tag; recorded on the tag stack by _extract_tag
    START = "start"
    # Text located inside one or more open tags (see _get_current_tags)
    INSIDE = "inside"
    # Closing markup of a tag
    END = "end"
    # Self-closing markup of a tag; does not touch the tag stack
    SELF_CLOSING = "self"
    # Text that is not inside any tag
    NONE = "none"
|
|
|
|
|
@dataclass
class TagInfo:
    """Name of a tag paired with the state it is in"""

    name: str
    state: TagState

    def __str__(self) -> str:
        """Render as "name:state", or as "none" for the NONE state"""
        no_tag = self.state == TagState.NONE
        return "none" if no_tag else f"{self.name}:{self.state.value}"
|
|
|
|
|
@dataclass
class SentenceWithTags:
    """A sentence with its tag information, supporting nested tags"""

    # Sentence text, or the tag markup itself when the item represents a tag
    text: str
    # Tags applying to the text: the enclosing tags (outermost first) for
    # ordinary text, a single NONE entry for untagged text, or the tag itself
    # for a markup item
    tags: List[TagInfo]
|
|
|
|
|
class SentenceDivider:
    """
    Split a stream of text fragments into sentences annotated with tag state.

    Incoming strings are accumulated in an internal buffer. Markup of the
    configured tags (opening, closing and self-closing forms) is emitted as an
    item of its own, and text is emitted sentence by sentence together with
    the tags that are open at that point.
    """

    def __init__(
        self,
        faster_first_response: bool = True,
        segment_method: str = "pysbd",
        valid_tags: Optional[List[str]] = None,
    ):
        """
        Initialize the SentenceDivider.

        Args:
            faster_first_response: Whether to split first sentence at commas
            segment_method: Method for segmenting sentences ("regex" selects
                the regex segmenter, anything else the pysbd one)
            valid_tags: List of valid tag names to detect; defaults to
                ["think"] when omitted or empty
        """
        self.faster_first_response = faster_first_response
        self.segment_method = segment_method
        self.valid_tags = valid_tags or ["think"]
        self._is_first_sentence = True
        self._buffer = ""

        # Currently open tags, outermost first
        self._tag_stack: List[TagInfo] = []
        # Text of every item emitted by the current/last process_stream run.
        # Created here so complete_response works before any stream is seen.
        self._full_response: List[str] = []

    def _get_current_tags(self) -> List[TagInfo]:
        """
        Get all current active tags from outermost to innermost.

        Returns:
            List[TagInfo]: Fresh INSIDE entries, one per open tag
        """
        return [TagInfo(tag.name, TagState.INSIDE) for tag in self._tag_stack]

    def _get_current_tag(self) -> Optional[TagInfo]:
        """
        Get the current innermost active tag.

        Returns:
            TagInfo if there's an active tag, None otherwise
        """
        return self._tag_stack[-1] if self._tag_stack else None

    def _text_tags(self) -> List[TagInfo]:
        """Return the tags to attach to plain text: open tags or one NONE entry."""
        return self._get_current_tags() or [TagInfo("", TagState.NONE)]

    def _locate_first_tag(
        self, text: str
    ) -> Optional[Tuple[int, str, TagState, str]]:
        """
        Find the earliest markup of any valid tag in the text.

        Tag names are matched literally, so names containing characters that
        are special in regular expressions are handled correctly.

        Returns:
            (offset, tag name, tag state, literal markup) of the earliest
            match, or None when the text contains no known tag markup
        """
        # Order gives the precedence if two candidates share an offset:
        # self-closing, then opening, then closing.
        forms = (
            (TagState.SELF_CLOSING, "<", "/>"),
            (TagState.START, "<", ">"),
            (TagState.END, "</", ">"),
        )
        best: Optional[Tuple[int, str, TagState, str]] = None
        for state, opener, closer in forms:
            for tag in self.valid_tags:
                markup = f"{opener}{tag}{closer}"
                pos = text.find(markup)
                if pos != -1 and (best is None or pos < best[0]):
                    best = (pos, tag, state, markup)
        return best

    def _extract_tag(self, text: str) -> Tuple[Optional[TagInfo], str]:
        """
        Extract the first tag from text if present.
        Handles nested tags by maintaining a tag stack.

        An opening tag is pushed on the stack. A closing tag pops the stack
        when it matches the innermost open tag; otherwise a warning is logged
        and the stack is left alone. Any text in front of the tag is not part
        of the returned remainder.

        Args:
            text: Text to check for tags

        Returns:
            Tuple of (TagInfo if tag found else None, remaining text)
        """
        located = self._locate_first_tag(text)
        if located is None:
            return None, text

        pos, matched_tag, tag_type, markup = located

        if tag_type == TagState.START:
            self._tag_stack.append(TagInfo(matched_tag, TagState.START))
        elif tag_type == TagState.END:
            if not self._tag_stack or self._tag_stack[-1].name != matched_tag:
                logger.warning(f"Mismatched closing tag: {matched_tag}")
            else:
                self._tag_stack.pop()

        return TagInfo(matched_tag, tag_type), text[pos + len(markup):].lstrip()

    async def _process_buffer(self) -> AsyncIterator[SentenceWithTags]:
        """
        Process the current buffer, yielding complete sentences with tags.

        Everything that is yielded is removed from self._buffer; text that
        does not form a complete sentence yet stays in the buffer. The buffer
        is always updated before the corresponding items are yielded.
        """
        while self._buffer.strip():
            located = self._locate_first_tag(self._buffer)

            if located is not None:
                tag_pos = located[0]
                text_before_tag = self._buffer[:tag_pos]

                if text_before_tag.strip():
                    # A tag boundary closes whatever text precedes it, so the
                    # whole prefix is emitted now: complete sentences first,
                    # then the unterminated remainder, so no text is lost.
                    if contains_end_punctuation(text_before_tag):
                        sentences, leftover = self._segment_text(text_before_tag)
                        pieces = [*sentences, leftover]
                    else:
                        pieces = [text_before_tag]

                    self._buffer = self._buffer[tag_pos:]
                    for piece in pieces:
                        if piece.strip():
                            yield SentenceWithTags(
                                text=piece.strip(), tags=self._text_tags()
                            )
                    continue

                # The tag is at the front (possibly after whitespace).
                tag_info, remaining = self._extract_tag(self._buffer)
                if tag_info is not None:
                    consumed = len(self._buffer) - len(remaining)
                    markup_text = self._buffer[:consumed].strip()
                    self._buffer = remaining
                    yield SentenceWithTags(text=markup_text, tags=[tag_info])
                    continue

            # No tag in the buffer: plain text handling.
            if (
                self._is_first_sentence
                and self.faster_first_response
                and contains_comma(self._buffer)
            ):
                # Emit the very first chunk as early as its first comma.
                sentence, remaining = comma_splitter(self._buffer)
                self._buffer = remaining
                self._is_first_sentence = False
                if sentence.strip():
                    yield SentenceWithTags(
                        text=sentence.strip(), tags=self._text_tags()
                    )
                continue

            if contains_end_punctuation(self._buffer):
                sentences, remaining = self._segment_text(self._buffer)
                if sentences:
                    self._buffer = remaining
                    self._is_first_sentence = False
                    for sentence in sentences:
                        if sentence.strip():
                            yield SentenceWithTags(
                                text=sentence.strip(), tags=self._text_tags()
                            )
                    continue

            # Nothing more can be emitted until further input arrives.
            break

    async def _flush_buffer(self) -> AsyncIterator[SentenceWithTags]:
        """
        Process and yield all remaining content in the buffer at the end of the stream.

        Complete sentences and tags are emitted first; whatever is left is
        then emitted as one final fragment and the buffer is cleared.
        """
        logger.debug(f"Flushing remaining buffer: '{self._buffer}'")

        async for sentence in self._process_buffer():
            yield sentence

        final_fragment = self._buffer.strip()
        if final_fragment:
            logger.debug(f"Yielding final fragment from buffer: '{final_fragment}'")
            self._buffer = ""
            yield SentenceWithTags(text=final_fragment, tags=self._text_tags())

    async def process_stream(
        self, segment_stream: AsyncIterator[Union[str, Dict[str, Any]]]
    ) -> AsyncIterator[Union[SentenceWithTags, Dict[str, Any]]]:
        """
        Process a stream of tokens (strings) and dictionaries.
        Yields complete sentences with tags (SentenceWithTags) or dictionaries directly.

        State from any previous stream is discarded first. Strings are
        appended to the buffer; a dictionary is passed through after the
        sentences that are already complete. Items of any other type are
        logged and skipped.

        Args:
            segment_stream: An async iterator yielding strings or dictionaries.

        Yields:
            Union[SentenceWithTags, Dict[str, Any]]: Complete sentences/tags or original dictionaries.
        """
        self._full_response = []
        self.reset()

        async for item in segment_stream:
            is_dict = isinstance(item, dict)
            if isinstance(item, str):
                self._buffer += item
            elif not is_dict:
                logger.warning(
                    f"SentenceDivider received unexpected type: {type(item)}"
                )
                continue

            async for sentence in self._process_buffer():
                self._full_response.append(sentence.text)
                yield sentence

            if is_dict:
                yield item

        # End of stream: emit whatever is still buffered.
        async for sentence in self._flush_buffer():
            self._full_response.append(sentence.text)
            yield sentence

    @property
    def complete_response(self) -> str:
        """Get the complete response accumulated so far"""
        return "".join(self._full_response)

    def _segment_text(self, text: str) -> Tuple[List[str], str]:
        """Segment text using the configured method"""
        if self.segment_method == "regex":
            return segment_text_by_regex(text)
        return segment_text_by_pysbd(text)

    def reset(self):
        """Reset the divider state for a new conversation"""
        self._is_first_sentence = True
        self._buffer = ""
        self._tag_stack = []
|
|
|