Spaces:
Paused
Paused
| # base seamless imports | |
| # --------------------------------- | |
| import io | |
| import json | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| import mmap | |
| import numpy as np | |
| import soundfile | |
| import torchaudio | |
| import torch | |
| from pydub import AudioSegment | |
| # --------------------------------- | |
| # seamless-streaming specific imports | |
| # --------------------------------- | |
| import math | |
| from simuleval.data.segments import SpeechSegment, EmptySegment | |
| from seamless_communication.streaming.agents.seamless_streaming_s2st import ( | |
| SeamlessStreamingS2STVADAgent, | |
| ) | |
| from simuleval.utils.arguments import cli_argument_list | |
| from simuleval import options | |
| from typing import Union, List | |
| from simuleval.data.segments import Segment, TextSegment | |
| from simuleval.agents.pipeline import TreeAgentPipeline | |
| from simuleval.agents.states import AgentStates | |
| # --------------------------------- | |
| # seamless setup | |
| # source: https://colab.research.google.com/github/kauterry/seamless_communication/blob/main/Seamless_Tutorial.ipynb? | |
| SAMPLE_RATE = 16000 | |
| # PM - This class simulates the audio frontend in the seamless streaming pipeline. | |
| # It needs to be replaced with the actual audio frontend. | |
| # TODO: write a replacement class that takes in PCM-16 bytes and returns a SpeechSegment | |
class AudioFrontEnd:
    """Simulated streaming audio source.

    Reads a waveform with ``soundfile`` and hands it out in consecutive
    chunks of ``segment_size`` milliseconds, mirroring the front-end logic
    in SimulEval's ``instance.py``.

    Attributes:
        samples: Buffered audio samples as returned by ``soundfile.read``
            (reset to an empty list once the buffer has been drained).
        sample_rate: Sample rate reported by ``soundfile.read``.
        segment_size: Duration of one emitted chunk, in milliseconds.
        step: Index into ``samples`` of the next sample to emit.
    """

    def __init__(self, wav_file, segment_size) -> None:
        """Load ``wav_file`` and prepare to emit ``segment_size``-ms chunks.

        Args:
            wav_file: Audio source passed straight to ``soundfile.read``.
            segment_size: Chunk duration in milliseconds.

        Raises:
            AssertionError: If the audio's sample rate differs from
                ``SAMPLE_RATE``.
        """
        self.samples, self.sample_rate = soundfile.read(wav_file)
        print(self.sample_rate, "sample rate")
        assert self.sample_rate == SAMPLE_RATE
        self.segment_size = segment_size
        self.step = 0

    def send_segment(self):
        """Return the next chunk of buffered audio.

        Returns:
            A ``SpeechSegment`` holding up to ``segment_size`` milliseconds
            of samples; its ``finished`` flag is set when the chunk reaches
            the end of the buffer. Once the buffer is exhausted, an
            ``EmptySegment(finished=True)`` is returned and the buffer and
            read position are reset.
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        total = len(self.samples)

        if self.step >= total:
            # Finished reading this audio: clear the buffer so that
            # add_segments() starts from a clean slate.
            self.step = 0
            self.samples = []
            return EmptySegment(
                finished=True,
            )

        end = self.step + num_samples
        is_finished = end >= total
        # Slicing past the end is safe and yields the remaining tail.
        samples = self.samples[self.step : end]
        # ``step`` is an absolute index into ``self.samples``; the buffer is
        # deliberately NOT truncated here. Advancing the index *and*
        # dropping the consumed prefix made every later call skip audio.
        self.step = min(end, total)
        return SpeechSegment(
            content=samples,
            sample_rate=self.sample_rate,
            finished=is_finished,
        )

    def add_segments(self, wav):
        """Append the audio read from ``wav`` to the end of the buffer.

        Args:
            wav: Audio source passed straight to ``soundfile.read``.
        """
        # NOTE(review): the sample rate of the new audio is discarded and
        # not checked against ``self.sample_rate`` - confirm callers always
        # supply audio at the same rate.
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))
class OutputSegments:
    """List-like wrapper around the segment(s) produced by one pipeline step.

    ``is_empty`` and ``finished`` are regular methods here, so they must be
    called (``obj.is_empty()``) to obtain a boolean.
    """

    def __init__(self, segments: Union[List[Segment], Segment]):
        """Store ``segments`` as a list, wrapping a lone ``Segment`` if needed."""
        if isinstance(segments, Segment):
            normalized = [segments]
        else:
            normalized = list(segments)
        self.segments: List[Segment] = normalized

    def is_empty(self):
        """Return True when every wrapped segment reports ``is_empty``."""
        for item in self.segments:
            if not item.is_empty:
                return False
        return True

    def finished(self):
        """Return True when every wrapped segment reports ``finished``."""
        for item in self.segments:
            if not item.finished:
                return False
        return True
def get_audiosegment(samples, sr):
    """Build a pydub ``AudioSegment`` from raw samples.

    The samples are written as WAV into an in-memory buffer with
    ``soundfile.write`` and that buffer is then loaded by pydub.

    Args:
        samples: Audio samples accepted by ``soundfile.write``.
        sr: Sample rate passed to ``soundfile.write``.

    Returns:
        The ``AudioSegment`` loaded from the in-memory WAV data.
    """
    wav_buffer = io.BytesIO()
    soundfile.write(wav_buffer, samples, samplerate=sr, format="wav")
    # Rewind so the reader starts at the beginning of the WAV data.
    wav_buffer.seek(0)
    audio = AudioSegment.from_file(wav_buffer)
    return audio
def reset_states(system, states):
    """Call ``reset()`` on every state object in ``states``.

    Args:
        system: The streaming system; for a ``TreeAgentPipeline`` the
            states are iterated via ``states.values()``, otherwise
            ``states`` is iterated directly.
        states: The container of state objects to reset.
    """
    is_tree = isinstance(system, TreeAgentPipeline)
    module_states = states.values() if is_tree else states
    for module_state in module_states:
        module_state.reset()
def get_states_root(system, states) -> AgentStates:
    """Return the state object of the pipeline's entry (source) module.

    Args:
        system: The streaming system.
        states: The states passed alongside ``system``; indexed by module
            for a ``TreeAgentPipeline``, by position otherwise.

    Returns:
        ``states[system.source_module]`` for a ``TreeAgentPipeline``,
        otherwise the first element of ``states``.
    """
    if isinstance(system, TreeAgentPipeline):
        # states is keyed by module
        return states[system.source_module]
    # states is a positional sequence. Use the ``states`` argument rather
    # than ``system.states`` so the caller's own state objects - the ones
    # it also passes to reset_states() - are the ones returned.
    return states[0]
def build_streaming_system(model_configs, agent_class):
    """Instantiate ``agent_class`` from ``model_configs``.

    A SimulEval general parser is extended with the agent's own arguments,
    ``model_configs`` is converted to a CLI-style argument list, and the
    parsed namespace is handed to ``agent_class.from_args``.

    Args:
        model_configs: Configuration passed to ``cli_argument_list``.
        agent_class: Class providing ``add_args`` and ``from_args``.

    Returns:
        The object returned by ``agent_class.from_args``.
    """
    arg_parser = options.general_parser()
    # Extra "-f" option so the parser accepts the argument ipython injects.
    arg_parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
    agent_class.add_args(arg_parser)

    argv = cli_argument_list(model_configs)
    # Unknown arguments are tolerated and ignored.
    parsed_args, _unknown = arg_parser.parse_known_args(argv)
    return agent_class.from_args(parsed_args)
def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    """Feed audio from ``audio_frontend`` through ``system`` chunk by chunk.

    NOTE: For visualization, delays are offsets from the audio *BEFORE* VAD
    segmentation. SimulEval evaluation instead assumes pre-segmented audio,
    and its Average Lagging / End Offset metrics are based on those
    segments, so the delays returned here are *NOT* comparable to SimulEval
    per-segment delays.

    Args:
        system: Streaming system exposing ``pushpop(segment, states)``.
        audio_frontend: Object whose ``send_segment()`` yields input segments.
        system_states: States passed to ``system.pushpop`` and reset whenever
            a finished output is produced.
        tgt_lang: Target language tag attached to every input segment.

    Returns:
        A tuple ``(delays, prediction_lists, speech_durations,
        target_sample_rate)``. ``delays`` and ``prediction_lists`` are dicts
        with ``"s2st"`` and ``"s2tt"`` lists; delays and durations are in
        milliseconds. ``target_sample_rate`` stays ``None`` if no speech was
        produced. If the front end yields an ``EmptySegment``, the tuple
        ``(None, None, None, None)`` is returned instead.
    """
    delays = {"s2st": [], "s2tt": []}
    prediction_lists = {"s2st": [], "s2tt": []}
    speech_durations = []
    curr_delay = 0
    target_sample_rate = None

    while True:
        input_segment = audio_frontend.send_segment()
        input_segment.tgt_lang = tgt_lang
        curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000

        if input_segment.finished:
            # a hack, we expect a real stream to end with silence
            get_states_root(system, system_states).source_finished = True

        if isinstance(input_segment, EmptySegment):
            return None, None, None, None

        # Translation happens here
        output_segments = OutputSegments(system.pushpop(input_segment, system_states))
        segments = output_segments.segments

        # Evaluate the per-segment flags directly. ``OutputSegments.is_empty``
        # and ``.finished`` are methods, so referencing them without calling
        # them is always truthy: outputs were never collected and states were
        # reset on every chunk.
        if not all(segment.is_empty for segment in segments):
            for segment in segments:
                # NOTE: another difference from SimulEval evaluation -
                # delays are accumulated per-token
                if isinstance(segment, SpeechSegment):
                    pred_duration = 1000 * len(segment.content) / segment.sample_rate
                    speech_durations.append(pred_duration)
                    delays["s2st"].append(curr_delay)
                    prediction_lists["s2st"].append(segment.content)
                    target_sample_rate = segment.sample_rate
                elif isinstance(segment, TextSegment):
                    delays["s2tt"].append(curr_delay)
                    prediction_lists["s2tt"].append(segment.content)
                    print(curr_delay, segment.content)

        if all(segment.finished for segment in segments):
            reset_states(system, system_states)

        if input_segment.finished:
            # an assumption of SimulEval agents -
            # once source_finished=True, generate until output translation is finished
            break

    return delays, prediction_lists, speech_durations, target_sample_rate
def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    """Lay out the S2ST predictions on a timeline, padding gaps with silence.

    Each prediction starts at its delay, or right after the previous
    prediction ends if that is later. Gaps - including the lead-in before
    the first prediction - are filled with ``0.0`` samples.

    Args:
        delays: Dict whose ``"s2st"`` entry lists emission delays in ms.
        target_sample_rate: Sample rate of the predicted speech; may be
            ``None`` when there are no S2ST predictions.
        prediction_lists: Dict whose ``"s2st"`` entry lists the predicted
            sample sequences, parallel to ``delays["s2st"]``.
        speech_durations: Duration in ms of each prediction, parallel to
            ``delays["s2st"]``.

    Returns:
        ``(target_samples, intervals)``: the concatenated sample list and a
        list of ``[start_ms, duration_ms]`` pairs, one per prediction. Both
        are empty when there are no S2ST predictions.
    """
    s2st_delays = delays["s2st"]
    if not s2st_delays:
        # No speech was produced: nothing to place, and target_sample_rate
        # may be None, so avoid indexing or multiplying with it.
        return [], []

    s2st_predictions = prediction_lists["s2st"]
    intervals = []
    prev_end = s2st_delays[0]
    # Leading silence up to the first prediction.
    target_samples = [0.0] * int(target_sample_rate * prev_end / 1000)

    for i, delay in enumerate(s2st_delays):
        start = max(prev_end, delay)
        if start > prev_end:
            # Wait source speech, add discontinuity with silence
            target_samples += [0.0] * int(
                target_sample_rate * (start - prev_end) / 1000
            )
        target_samples += s2st_predictions[i]
        duration = speech_durations[i]
        prev_end = start + duration
        intervals.append([start, duration])

    return target_samples, intervals