| | from simuleval.utils.agent import build_system_from_dir |
| | from typing import Any, List, Optional, Tuple, Union |
| | import numpy as np |
| | import soundfile |
| | import io |
| | import asyncio |
| | from simuleval.agents.pipeline import TreeAgentPipeline |
| | from simuleval.agents.states import AgentStates |
| | from simuleval.data.segments import Segment, EmptySegment, SpeechSegment |
| | import threading |
| | import math |
| | import logging |
| | import sys |
| | from pathlib import Path |
| | import time |
| | from g2p_en import G2p |
| | import torch |
| | import traceback |
| | import time |
| | import random |
| | import colorlog |
| |
|
| | from .speech_and_text_output import SpeechAndTextOutput |
| |
|
| | MODEL_SAMPLE_RATE = 16_000 |
| |
|
| | logger = logging.getLogger(__name__) |
| | |
| | handler = colorlog.StreamHandler(stream=sys.stdout) |
| | formatter = colorlog.ColoredFormatter( |
| | "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s", |
| | reset=True, |
| | log_colors={ |
| | "DEBUG": "cyan", |
| | "INFO": "green", |
| | "WARNING": "yellow", |
| | "ERROR": "red", |
| | "CRITICAL": "red,bg_white", |
| | }, |
| | ) |
| | handler.setFormatter(formatter) |
| | logger.addHandler(handler) |
| | logger.setLevel(logging.WARNING) |
| |
|
| |
|
| | class OutputSegments: |
| | def __init__(self, segments: Union[List[Segment], Segment]): |
| | if isinstance(segments, Segment): |
| | segments = [segments] |
| | self.segments: List[Segment] = [s for s in segments] |
| |
|
| | @property |
| | def is_empty(self): |
| | return all(segment.is_empty for segment in self.segments) |
| |
|
| | @property |
| | def finished(self): |
| | return all(segment.finished for segment in self.segments) |
| |
|
| | def compute_length(self, g2p): |
| | lengths = [] |
| | for segment in self.segments: |
| | if segment.data_type == "text": |
| | lengths.append(len([x for x in g2p(segment.content) if x != " "])) |
| | elif segment.data_type == "speech": |
| | lengths.append(len(segment.content) / MODEL_SAMPLE_RATE) |
| | elif isinstance(segment, EmptySegment): |
| | continue |
| | else: |
| | logger.warning( |
| | f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'" |
| | ) |
| | return max(lengths) |
| |
|
| | @classmethod |
| | def join_output_buffer( |
| | cls, buffer: List[List[Segment]], output: SpeechAndTextOutput |
| | ): |
| | num_segments = len(buffer[0]) |
| | for i in range(num_segments): |
| | segment_list = [ |
| | buffer[j][i] |
| | for j in range(len(buffer)) |
| | if buffer[j][i].data_type is not None |
| | ] |
| | if len(segment_list) == 0: |
| | continue |
| | if len(set(segment.data_type for segment in segment_list)) != 1: |
| | logger.warning( |
| | f"Data type mismatch at {i}: {set(segment.data_type for segment in segment_list)}" |
| | ) |
| | continue |
| | data_type = segment_list[0].data_type |
| | if data_type == "text": |
| | if output.text is not None: |
| | logger.warning("Multiple text outputs, overwriting!") |
| | output.text = " ".join([segment.content for segment in segment_list]) |
| | elif data_type == "speech": |
| | if output.speech_samples is not None: |
| | logger.warning("Multiple speech outputs, overwriting!") |
| | speech_out = [] |
| | for segment in segment_list: |
| | speech_out += segment.content |
| | output.speech_samples = speech_out |
| | output.speech_sample_rate = segment.sample_rate |
| | elif isinstance(segment_list[0], EmptySegment): |
| | continue |
| | else: |
| | logger.warning( |
| | f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text" |
| | ) |
| |
|
| | return output |
| |
|
| | def __repr__(self) -> str: |
| | repr_str = str(self.segments) |
| | return f"{self.__class__.__name__}(\n\t{repr_str}\n)" |
| |
|
| |
|
| | class SimulevalTranscoder: |
| | def __init__(self, agent, sample_rate, debug, buffer_limit): |
| | self.agent = agent.agent |
| | self.has_expressive = agent.has_expressive |
| | self.input_queue = asyncio.Queue() |
| | self.output_queue = asyncio.Queue() |
| | self.states = self.agent.build_states() |
| | if debug: |
| | self.get_states_root().debug = True |
| | self.incoming_sample_rate = sample_rate |
| | self.close = False |
| | self.g2p = G2p() |
| |
|
| | |
| | self.output_buffer_idle_ms = 5000 |
| | self.output_buffer_size_limit = ( |
| | buffer_limit |
| | ) |
| | self.output_buffer_cur_size = 0 |
| | self.output_buffer: List[List[Segment]] = [] |
| | self.speech_output_sample_rate = None |
| |
|
| | self.last_output_ts = time.time() * 1000 |
| | self.timeout_ms = ( |
| | 30000 |
| | ) |
| | self.first_input_ts = None |
| | self.first_output_ts = None |
| | self.debug = debug |
| | self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}" |
| | if self.debug: |
| | debug_folder = Path(__file__).resolve().parent.parent / "debug" |
| | self.test_incoming_wav = soundfile.SoundFile( |
| | debug_folder / f"{self.debug_ts}_test_incoming.wav", |
| | mode="w+", |
| | format="WAV", |
| | subtype="PCM_16", |
| | samplerate=self.incoming_sample_rate, |
| | channels=1, |
| | ) |
| | self.get_states_root().test_input_segments_wav = soundfile.SoundFile( |
| | debug_folder / f"{self.debug_ts}_test_input_segments.wav", |
| | mode="w+", |
| | format="WAV", |
| | samplerate=MODEL_SAMPLE_RATE, |
| | channels=1, |
| | ) |
| |
|
| | def get_states_root(self) -> AgentStates: |
| | if isinstance(self.agent, TreeAgentPipeline): |
| | |
| | return self.states[self.agent.source_module] |
| | else: |
| | |
| | return self.states[0] |
| |
|
| | def reset_states(self): |
| | if isinstance(self.agent, TreeAgentPipeline): |
| | states_iter = self.states.values() |
| | else: |
| | states_iter = self.states |
| | for state in states_iter: |
| | state.reset() |
| |
|
| | def debug_log(self, *args): |
| | if self.debug: |
| | logger.info(*args) |
| |
|
| | @classmethod |
| | def build_agent(cls, model_path, config_name): |
| | logger.info(f"Building simuleval agent: {model_path}, {config_name}") |
| | agent = build_system_from_dir( |
| | Path(__file__).resolve().parent.parent / f"models/{model_path}", |
| | config_name=config_name, |
| | ) |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | agent.to(device, fp16=True) |
| | logger.info( |
| | f"Successfully built simuleval agent {model_path} on device {device}" |
| | ) |
| |
|
| | return agent |
| |
|
| | def process_incoming_bytes(self, incoming_bytes, dynamic_config): |
| | |
| | segment, sr = self._preprocess_wav(incoming_bytes) |
| | segment = SpeechSegment( |
| | content=segment, |
| | sample_rate=sr, |
| | tgt_lang=dynamic_config.get("targetLanguage"), |
| | config=dynamic_config, |
| | ) |
| | if dynamic_config.get("expressive") is True and self.has_expressive is False: |
| | logger.warning( |
| | "Passing 'expressive' but the agent does not support expressive output!" |
| | ) |
| | |
| | self.input_queue.put_nowait(segment) |
| |
|
| | def get_input_segment(self): |
| | if self.input_queue.empty(): |
| | return None |
| | chunk = self.input_queue.get_nowait() |
| | self.input_queue.task_done() |
| | return chunk |
| |
|
| | def convert_waveform( |
| | self, |
| | waveform: Union[np.ndarray, torch.Tensor], |
| | sample_rate: int, |
| | normalize_volume: bool = False, |
| | to_mono: bool = False, |
| | to_sample_rate: Optional[int] = None, |
| | ) -> Tuple[Union[np.ndarray, torch.Tensor], int]: |
| | """convert a waveform: |
| | - to a target sample rate |
| | - from multi-channel to mono channel |
| | - volume normalization |
| | |
| | Args: |
| | waveform (numpy.ndarray or torch.Tensor): 2D original waveform |
| | (channels x length) |
| | sample_rate (int): original sample rate |
| | normalize_volume (bool): perform volume normalization |
| | to_mono (bool): convert to mono channel if having multiple channels |
| | to_sample_rate (Optional[int]): target sample rate |
| | Returns: |
| | waveform (numpy.ndarray): converted 2D waveform (channels x length) |
| | sample_rate (float): target sample rate |
| | """ |
| | try: |
| | import torchaudio.sox_effects as ta_sox |
| | except ImportError: |
| | raise ImportError("Please install torchaudio: pip install torchaudio") |
| |
|
| | effects = [] |
| | if normalize_volume: |
| | effects.append(["gain", "-n"]) |
| | if to_sample_rate is not None and to_sample_rate != sample_rate: |
| | effects.append(["rate", f"{to_sample_rate}"]) |
| | if to_mono and waveform.shape[0] > 1: |
| | effects.append(["channels", "1"]) |
| | if len(effects) > 0: |
| | is_np_input = isinstance(waveform, np.ndarray) |
| | _waveform = torch.from_numpy(waveform) if is_np_input else waveform |
| | converted, converted_sample_rate = ta_sox.apply_effects_tensor( |
| | _waveform, sample_rate, effects |
| | ) |
| | if is_np_input: |
| | converted = converted.numpy() |
| | return converted, converted_sample_rate |
| | return waveform, sample_rate |
| |
|
| | def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]: |
| | segment, sample_rate = soundfile.read( |
| | io.BytesIO(data), |
| | dtype="float32", |
| | always_2d=True, |
| | frames=-1, |
| | start=0, |
| | format="RAW", |
| | subtype="PCM_16", |
| | samplerate=self.incoming_sample_rate, |
| | channels=1, |
| | ) |
| | if self.debug: |
| | self.test_incoming_wav.seek(0, soundfile.SEEK_END) |
| | self.test_incoming_wav.write(segment) |
| |
|
| | segment = segment.T |
| | segment, new_sample_rate = self.convert_waveform( |
| | segment, |
| | sample_rate, |
| | normalize_volume=False, |
| | to_mono=True, |
| | to_sample_rate=MODEL_SAMPLE_RATE, |
| | ) |
| |
|
| | assert MODEL_SAMPLE_RATE == new_sample_rate |
| | segment = segment.squeeze(axis=0) |
| | return segment, new_sample_rate |
| |
|
| | def process_pipeline_impl(self, input_segment): |
| | try: |
| | with torch.no_grad(): |
| | output_segment = OutputSegments( |
| | self.agent.pushpop(input_segment, self.states) |
| | ) |
| | if ( |
| | self.get_states_root().first_input_ts is not None |
| | and self.first_input_ts is None |
| | ): |
| | |
| | self.first_input_ts = self.get_states_root().first_input_ts |
| |
|
| | if not output_segment.is_empty: |
| | self.output_queue.put_nowait(output_segment) |
| |
|
| | if output_segment.finished: |
| | self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.") |
| |
|
| | self.reset_states() |
| |
|
| | if self.debug: |
| | |
| | |
| | self.get_states_root().debug = True |
| | except Exception as e: |
| | logger.error(f"Got exception while processing pipeline: {e}") |
| | traceback.print_exc() |
| | return input_segment |
| |
|
| | def process_pipeline_loop(self): |
| | if self.close: |
| | return |
| |
|
| | self.debug_log("processing_pipeline") |
| | while not self.close: |
| | input_segment = self.get_input_segment() |
| | if input_segment is None: |
| | if self.get_states_root().is_fresh_state: |
| | time.sleep(0.3) |
| | else: |
| | time.sleep(0.03) |
| | continue |
| | self.process_pipeline_impl(input_segment) |
| | self.debug_log("finished processing_pipeline") |
| |
|
| | def process_pipeline_once(self): |
| | if self.close: |
| | return |
| |
|
| | self.debug_log("processing pipeline once") |
| | input_segment = self.get_input_segment() |
| | if input_segment is None: |
| | return |
| | self.process_pipeline_impl(input_segment) |
| | self.debug_log("finished processing_pipeline_once") |
| |
|
| | def get_output_segment(self): |
| | if self.output_queue.empty(): |
| | return None |
| |
|
| | output_chunk = self.output_queue.get_nowait() |
| | self.output_queue.task_done() |
| | return output_chunk |
| |
|
| | def start(self): |
| | self.debug_log("starting transcoder in a thread") |
| | threading.Thread(target=self.process_pipeline_loop).start() |
| |
|
| | def first_translation_time(self): |
| | return round((self.first_output_ts - self.first_input_ts) / 1000, 2) |
| |
|
| | def get_buffered_output(self) -> SpeechAndTextOutput: |
| | now = time.time() * 1000 |
| | self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}") |
| | while not self.output_queue.empty(): |
| | tmp_out = self.get_output_segment() |
| | if tmp_out and tmp_out.compute_length(self.g2p) > 0: |
| | if len(self.output_buffer) == 0: |
| | self.last_output_ts = now |
| | self._populate_output_buffer(tmp_out) |
| | self._increment_output_buffer_size(tmp_out) |
| |
|
| | if tmp_out.finished: |
| | self.debug_log("tmp_out.finished") |
| | res = self._gather_output_buffer_data(final=True) |
| | self.debug_log(f"gathered output data: {res}") |
| | self.output_buffer = [] |
| | self.increment_output_buffer_size = 0 |
| | self.last_output_ts = now |
| | self.first_output_ts = now |
| | return res |
| | else: |
| | self.debug_log("tmp_out.compute_length is not > 0") |
| |
|
| | if len(self.output_buffer) > 0 and ( |
| | now - self.last_output_ts >= self.output_buffer_idle_ms |
| | or self.output_buffer_cur_size >= self.output_buffer_size_limit |
| | ): |
| | self.debug_log( |
| | "[get_buffered_output] output_buffer is not empty. getting res to return." |
| | ) |
| | self.last_output_ts = now |
| | res = self._gather_output_buffer_data(final=False) |
| | self.debug_log(f"gathered output data: {res}") |
| | self.output_buffer = [] |
| | self.output_buffer_phoneme_count = 0 |
| | self.first_output_ts = now |
| | return res |
| | else: |
| | self.debug_log("[get_buffered_output] output_buffer is empty...") |
| | return None |
| |
|
| | def _gather_output_buffer_data(self, final): |
| | output = SpeechAndTextOutput() |
| | output.final = final |
| | output = OutputSegments.join_output_buffer(self.output_buffer, output) |
| | return output |
| |
|
| | def _increment_output_buffer_size(self, segment: OutputSegments): |
| | self.output_buffer_cur_size += segment.compute_length(self.g2p) |
| |
|
| | def _populate_output_buffer(self, segment: OutputSegments): |
| | self.output_buffer.append(segment.segments) |
| |
|
| | def _compute_phoneme_count(self, string: str) -> int: |
| | return len([x for x in self.g2p(string) if x != " "]) |
| |
|