# mana-tts / sentence_splitter.py
# abreza - feat: improved number handling and audio processing (da2ee9a)
import re
from typing import List
class PersianSentenceSplitter:
    """Normalize Persian text and cut it into chunks of bounded length.

    Splitting is attempted in three passes, each finer than the last:
    sentence terminators, then weak boundaries (commas / semicolons), then
    whitespace between words.

    Attributes:
        max_chars: Upper bound on the length of each returned chunk.
        min_chars: Stored for callers; no method of this class reads it.
        sentence_endings: Regex character class of sentence terminators.
        weak_boundaries: Regex character class of clause separators.
    """

    # One translation table for every single-character normalization, built
    # once instead of on each clean_text() call.
    _NORMALIZATION_TABLE = {
        **str.maketrans('ك', 'ک'),  # Arabic kaf -> Persian kaf
        **str.maketrans('ي', 'ی'),  # Arabic yeh -> Persian yeh
        **str.maketrans('۰۱۲۳۴۵۶۷۸۹', '0123456789'),  # Persian digits
        **str.maketrans('٠١٢٣٤٥٦٧٨٩', '0123456789'),  # Arabic digits
    }

    def __init__(self, max_chars: int = 200, min_chars: int = 50):
        """Store the length limits and the boundary patterns."""
        self.max_chars = max_chars
        self.min_chars = min_chars
        self.sentence_endings = r'[.!?؟۔]'
        self.weak_boundaries = r'[،,;؛]'

    def clean_text(self, text: str) -> str:
        """Return *text* with whitespace, letters and digits normalized.

        Runs of whitespace collapse to one space, underscores become
        zero-width non-joiners, Arabic kaf/yeh become their Persian forms,
        and Persian / Arabic-Indic digits become ASCII digits. Leading and
        trailing whitespace is removed.
        """
        text = re.sub(r'\s+', ' ', text)
        text = text.replace('_', '\u200c')
        return text.translate(self._NORMALIZATION_TABLE).strip()

    def split_by_punctuation(self, text: str) -> List[str]:
        """Split *text* after each run of sentence terminators.

        Consecutive terminators ("...", "?!") stay attached to the sentence
        they end, and a single "." between two digits is treated as a decimal
        point rather than a boundary. Empty pieces are dropped.
        """
        boundary = re.compile(f'(?:{self.sentence_endings})+')
        sentences = []
        start = 0
        for match in boundary.finditer(text):
            end = match.end()
            is_decimal_point = (
                match.group() == '.'
                and match.start() > 0
                and end < len(text)
                and text[match.start() - 1].isdigit()
                and text[end].isdigit()
            )
            if is_decimal_point:
                # Keep numbers such as "3.14" in one piece.
                continue
            sentence = text[start:end].strip()
            if sentence:
                sentences.append(sentence)
            start = end

        # Whatever follows the last terminator is a sentence of its own.
        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return sentences

    def split_long_sentence(self, sentence: str) -> List[str]:
        """Break a sentence longer than ``max_chars`` at weak boundaries.

        A sentence that already fits is returned unchanged as a one-element
        list. Otherwise clauses (text plus its trailing separator) are packed
        greedily into chunks; any chunk that is still too long is handed to
        ``force_split_by_words``.
        """
        if len(sentence) <= self.max_chars:
            return [sentence]

        # Cut *after* each separator so that a separator always stays with
        # the clause it closes and never starts the next chunk.
        clauses = []
        start = 0
        for match in re.finditer(self.weak_boundaries, sentence):
            clauses.append(sentence[start:match.end()])
            start = match.end()
        if start < len(sentence):
            clauses.append(sentence[start:])

        chunks = []
        current = ''
        for clause in clauses:
            # Measure the stripped text, because that is what gets emitted.
            if current and len((current + clause).strip()) > self.max_chars:
                chunks.append(current.strip())
                current = clause
            else:
                current += clause
        chunks.append(current.strip())

        result = []
        for chunk in chunks:
            if not chunk:
                continue
            if len(chunk) > self.max_chars:
                result.extend(self.force_split_by_words(chunk))
            else:
                result.append(chunk)
        return result

    def force_split_by_words(self, text: str) -> List[str]:
        """Pack whitespace-separated words into chunks of at most ``max_chars``.

        Words are joined with single spaces. When ``max_chars`` is positive,
        a word that is itself longer than the limit is cut into
        ``max_chars``-sized slices so no chunk exceeds the limit.
        """
        limit = self.max_chars
        chunks = []
        current = []
        current_len = 0  # length of ' '.join(current)

        for word in text.split():
            if limit > 0 and len(word) > limit:
                # The word cannot fit in any chunk: flush and hard-split it.
                if current:
                    chunks.append(' '.join(current))
                pieces = [word[i:i + limit] for i in range(0, len(word), limit)]
                chunks.extend(pieces[:-1])
                # The last slice may still have room for following words.
                current = [pieces[-1]]
                current_len = len(pieces[-1])
                continue

            # Joined length if this word is appended (one space between words,
            # none after the last one).
            joined_len = current_len + 1 + len(word) if current else len(word)
            if current and joined_len > limit:
                chunks.append(' '.join(current))
                current = [word]
                current_len = len(word)
            else:
                current.append(word)
                current_len = joined_len

        if current:
            chunks.append(' '.join(current))
        return chunks

    def split(self, text: str) -> List[str]:
        """Clean *text* and return it as a list of non-empty chunks.

        Returns an empty list for text that is empty after cleaning, and a
        one-element list when the cleaned text already fits in ``max_chars``.
        """
        text = self.clean_text(text)
        if not text:
            return []
        if len(text) <= self.max_chars:
            return [text]

        segments = []
        for sentence in self.split_by_punctuation(text):
            # split_long_sentence passes short sentences through unchanged.
            segments.extend(self.split_long_sentence(sentence))
        return [seg.strip() for seg in segments if seg.strip()]