# File size: 8,176 Bytes | 5669b22
from typing import AsyncIterator, Tuple, Callable, List, Union, Dict, Any
from functools import wraps
from .output_types import Actions, SentenceOutput, DisplayText
from ..utils.tts_preprocessor import tts_filter as filter_text
from ..live2d_model import Live2dModel
from ..config_manager import TTSPreprocessorConfig
from ..utils.sentence_divider import SentenceDivider
from ..utils.sentence_divider import SentenceWithTags, TagState
from loguru import logger
def sentence_divider(
    faster_first_response: bool = True,
    segment_method: str = "pysbd",
    valid_tags: List[str] = None,
):
    """
    Build a decorator that routes a token stream through a SentenceDivider.

    The decorated async generator is called with the caller's arguments and
    its output is handed to ``SentenceDivider.process_stream``; every item
    that comes back is re-yielded unchanged.

    Args:
        faster_first_response: bool - Forwarded to SentenceDivider as-is
        segment_method: str - Forwarded to SentenceDivider as-is
        valid_tags: List[str] - Forwarded to SentenceDivider; a falsy value
            (including the default None) is replaced by an empty list
    """

    def decorator(
        func: Callable[..., AsyncIterator[Union[str, Dict[str, Any]]]],
    ) -> Callable[..., AsyncIterator[Union[SentenceWithTags, Dict[str, Any]]]]:
        @wraps(func)
        async def wrapper(
            *args, **kwargs
        ) -> AsyncIterator[Union[SentenceWithTags, Dict[str, Any]]]:
            # A fresh divider is created for every invocation of the wrapped
            # generator, before the wrapped function itself is called.
            splitter = SentenceDivider(
                faster_first_response=faster_first_response,
                segment_method=segment_method,
                valid_tags=valid_tags or [],
            )
            upstream = func(*args, **kwargs)

            async for piece in splitter.process_stream(upstream):
                # Only the two known item kinds are logged; anything else is
                # still forwarded, just silently.
                if isinstance(piece, SentenceWithTags):
                    logger.debug(f"sentence_divider yielding sentence: {piece}")
                elif isinstance(piece, dict):
                    logger.debug(f"sentence_divider yielding dict: {piece}")
                yield piece

        return wrapper

    return decorator
def actions_extractor(live2d_model: Live2dModel):
    """
    Build a decorator that pairs each sentence with an Actions object.

    For every SentenceWithTags item a new Actions instance is created. If
    none of the sentence's tags is in the START or END state, the sentence
    text is passed to ``live2d_model.extract_emotion`` and a truthy result is
    stored on ``actions.expressions``. Dict items are forwarded untouched;
    any other item type is logged as a warning and dropped.
    """

    def decorator(
        func: Callable[..., AsyncIterator[Union[SentenceWithTags, Dict[str, Any]]]],
    ) -> Callable[
        ..., AsyncIterator[Union[Tuple[SentenceWithTags, Actions], Dict[str, Any]]]
    ]:
        @wraps(func)
        async def wrapper(
            *args, **kwargs
        ) -> AsyncIterator[Union[Tuple[SentenceWithTags, Actions], Dict[str, Any]]]:
            async for element in func(*args, **kwargs):
                # Deal with everything that is not a sentence first.
                if not isinstance(element, SentenceWithTags):
                    if isinstance(element, dict):
                        yield element
                    else:
                        logger.warning(
                            f"actions_extractor received unexpected type: {type(element)}"
                        )
                    continue

                extracted = Actions()
                # Sentences carrying a START/END tag marker are skipped for
                # emotion extraction.
                is_plain_text = all(
                    tag.state not in (TagState.START, TagState.END)
                    for tag in element.tags
                )
                if is_plain_text:
                    found = live2d_model.extract_emotion(element.text)
                    if found:
                        extracted.expressions = found
                yield element, extracted

        return wrapper

    return decorator
def display_processor():
    """
    Build a decorator that attaches a DisplayText to each (sentence, actions) pair.

    The display text defaults to the sentence text. A "think" tag in the
    START state replaces it with "(" and one in the END state replaces it
    with ")"; when several such tags are present the last one wins. Dict
    items are forwarded untouched; any other item is logged as a warning and
    dropped.
    """

    def decorator(
        func: Callable[
            ..., AsyncIterator[Union[Tuple[SentenceWithTags, Actions], Dict[str, Any]]]
        ],
    ) -> Callable[
        ...,
        AsyncIterator[
            Union[Tuple[SentenceWithTags, DisplayText, Actions], Dict[str, Any]]
        ],
    ]:
        @wraps(func)
        async def wrapper(
            *args, **kwargs
        ) -> AsyncIterator[
            Union[Tuple[SentenceWithTags, DisplayText, Actions], Dict[str, Any]]
        ]:
            async for element in func(*args, **kwargs):
                is_sentence_pair = (
                    isinstance(element, tuple)
                    and len(element) == 2
                    and isinstance(element[0], SentenceWithTags)
                )
                if not is_sentence_pair:
                    if isinstance(element, dict):
                        yield element
                    else:
                        logger.warning(
                            f"display_processor received unexpected type: {type(element)}"
                        )
                    continue

                sentence, actions = element
                shown = sentence.text
                # Think-tag boundaries are rendered as parentheses.
                for tag in sentence.tags:
                    if tag.name != "think":
                        continue
                    if tag.state == TagState.START:
                        shown = "("
                    elif tag.state == TagState.END:
                        shown = ")"

                yield sentence, DisplayText(text=shown), actions

        return wrapper

    return decorator
def tts_filter(
    tts_preprocessor_config: TTSPreprocessorConfig = None,
):
    """
    Build a decorator that turns (sentence, display, actions) triples into SentenceOutput.

    When any tag on the sentence is named "think" the TTS text is the empty
    string; otherwise the display text is run through ``filter_text`` using
    the flags of the preprocessor config. If no config is given (or it is
    falsy), a default ``TTSPreprocessorConfig()`` is created on each
    invocation of the wrapped generator. Dict items are forwarded untouched;
    any other item is logged as a warning and dropped.
    """

    def decorator(
        func: Callable[
            ...,
            AsyncIterator[
                Union[Tuple[SentenceWithTags, DisplayText, Actions], Dict[str, Any]]
            ],
        ],
    ) -> Callable[..., AsyncIterator[Union[SentenceOutput, Dict[str, Any]]]]:
        @wraps(func)
        async def wrapper(
            *args, **kwargs
        ) -> AsyncIterator[Union[SentenceOutput, Dict[str, Any]]]:
            upstream = func(*args, **kwargs)
            settings = tts_preprocessor_config or TTSPreprocessorConfig()

            async for element in upstream:
                is_display_triple = (
                    isinstance(element, tuple)
                    and len(element) == 3
                    and isinstance(element[1], DisplayText)
                )
                if not is_display_triple:
                    if isinstance(element, dict):
                        yield element
                    else:
                        logger.warning(
                            f"tts_filter received unexpected type: {type(element)}"
                        )
                    continue

                sentence, display, actions = element
                has_think_tag = any(tag.name == "think" for tag in sentence.tags)
                # Sentences with a think tag get an empty TTS string.
                tts = (
                    ""
                    if has_think_tag
                    else filter_text(
                        text=display.text,
                        remove_special_char=settings.remove_special_char,
                        ignore_brackets=settings.ignore_brackets,
                        ignore_parentheses=settings.ignore_parentheses,
                        ignore_asterisks=settings.ignore_asterisks,
                        ignore_angle_brackets=settings.ignore_angle_brackets,
                    )
                )

                logger.debug(f"[{display.name}] display: {display.text}")
                logger.debug(f"[{display.name}] tts: {tts}")

                yield SentenceOutput(
                    display_text=display,
                    tts_text=tts,
                    actions=actions,
                )

        return wrapper

    return decorator
|