Update with temp work
Browse files- app.py +6 -1
- internal_demo_simuleval_transcoder.py +272 -0
- requirements.txt +16 -2
- simuleval_transcoder.py +178 -0
app.py
CHANGED
|
@@ -10,6 +10,8 @@ from seamless_communication.models.inference.translator import Translator
|
|
| 10 |
|
| 11 |
|
| 12 |
from m4t_app import *
|
|
|
|
|
|
|
| 13 |
|
| 14 |
from pydub import AudioSegment
|
| 15 |
import time
|
|
@@ -19,6 +21,7 @@ from time import sleep
|
|
| 19 |
|
| 20 |
USE_M4T = True
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
def translate_audio_file_segment(audio_file):
|
| 24 |
print("translate_m4t state")
|
|
@@ -90,7 +93,9 @@ def blocks():
|
|
| 90 |
)
|
| 91 |
|
| 92 |
most_recent_input_audio_segment = gr.Audio(
|
| 93 |
-
label="Recent Input Audio Segment segments",
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
# TODO: Should add combined input audio segments...
|
| 96 |
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
from m4t_app import *
|
| 13 |
+
from simuleval_transcoder import *
|
| 14 |
+
# from simuleval_transcoder import *
|
| 15 |
|
| 16 |
from pydub import AudioSegment
|
| 17 |
import time
|
|
|
|
| 21 |
|
| 22 |
USE_M4T = True
|
| 23 |
|
| 24 |
+
Transcoder = SimulevalTranscoder()
|
| 25 |
|
| 26 |
def translate_audio_file_segment(audio_file):
|
| 27 |
print("translate_m4t state")
|
|
|
|
| 93 |
)
|
| 94 |
|
| 95 |
most_recent_input_audio_segment = gr.Audio(
|
| 96 |
+
label="Recent Input Audio Segment segments",
|
| 97 |
+
format="bytes",
|
| 98 |
+
streaming=True
|
| 99 |
)
|
| 100 |
# TODO: Should add combined input audio segments...
|
| 101 |
|
internal_demo_simuleval_transcoder.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from simuleval.utils.agent import build_system_from_dir
|
| 2 |
+
from typing import Any, Tuple
|
| 3 |
+
import numpy as np
|
| 4 |
+
import soundfile
|
| 5 |
+
from fairseq.data.audio.audio_utils import convert_waveform
|
| 6 |
+
import io
|
| 7 |
+
import asyncio
|
| 8 |
+
from simuleval.data.segments import SpeechSegment, EmptySegment
|
| 9 |
+
import threading
|
| 10 |
+
import math
|
| 11 |
+
import logging
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import time
|
| 15 |
+
from g2p_en import G2p
|
| 16 |
+
import torch
|
| 17 |
+
import traceback
|
| 18 |
+
import time
|
| 19 |
+
import random
|
| 20 |
+
|
| 21 |
+
from .speech_and_text_output import SpeechAndTextOutput
|
| 22 |
+
|
| 23 |
+
MODEL_SAMPLE_RATE = 16_000
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger()
|
| 26 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SimulevalTranscoder:
|
| 30 |
+
def __init__(self, agent, sample_rate, debug, buffer_limit):
|
| 31 |
+
self.agent = agent
|
| 32 |
+
self.input_queue = asyncio.Queue()
|
| 33 |
+
self.output_queue = asyncio.Queue()
|
| 34 |
+
self.states = self.agent.build_states()
|
| 35 |
+
if debug:
|
| 36 |
+
self.states[0].debug = True
|
| 37 |
+
self.incoming_sample_rate = sample_rate
|
| 38 |
+
self.close = False
|
| 39 |
+
self.g2p = G2p()
|
| 40 |
+
|
| 41 |
+
# buffer all outgoing translations within this amount of time
|
| 42 |
+
self.output_buffer_idle_ms = 5000
|
| 43 |
+
self.output_buffer_size_limit = (
|
| 44 |
+
buffer_limit # phonemes for text, seconds for speech
|
| 45 |
+
)
|
| 46 |
+
self.output_buffer_cur_size = 0
|
| 47 |
+
self.output_buffer = []
|
| 48 |
+
self.speech_output_sample_rate = None
|
| 49 |
+
|
| 50 |
+
self.last_output_ts = time.time() * 1000
|
| 51 |
+
self.timeout_ms = (
|
| 52 |
+
30000 # close the transcoder thread after this amount of silence
|
| 53 |
+
)
|
| 54 |
+
self.first_input_ts = None
|
| 55 |
+
self.first_output_ts = None
|
| 56 |
+
self.output_data_type = None # speech or text
|
| 57 |
+
self.debug = debug
|
| 58 |
+
self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
|
| 59 |
+
if self.debug:
|
| 60 |
+
debug_folder = Path(__file__).resolve().parent.parent / "debug"
|
| 61 |
+
self.test_incoming_wav = soundfile.SoundFile(
|
| 62 |
+
debug_folder / f"{self.debug_ts}_test_incoming.wav",
|
| 63 |
+
mode="w+",
|
| 64 |
+
format="WAV",
|
| 65 |
+
subtype="PCM_16",
|
| 66 |
+
samplerate=self.incoming_sample_rate,
|
| 67 |
+
channels=1,
|
| 68 |
+
)
|
| 69 |
+
self.states[0].test_input_segments_wav = soundfile.SoundFile(
|
| 70 |
+
debug_folder / f"{self.debug_ts}_test_input_segments.wav",
|
| 71 |
+
mode="w+",
|
| 72 |
+
format="WAV",
|
| 73 |
+
samplerate=MODEL_SAMPLE_RATE,
|
| 74 |
+
channels=1,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def debug_log(self, *args):
|
| 78 |
+
if self.debug:
|
| 79 |
+
logger.info(*args)
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def build_agent(cls, model_path):
|
| 83 |
+
logger.info(f"Building simuleval agent: {model_path}")
|
| 84 |
+
agent = build_system_from_dir(
|
| 85 |
+
Path(__file__).resolve().parent.parent / f"models/{model_path}",
|
| 86 |
+
config_name="vad_main.yaml",
|
| 87 |
+
)
|
| 88 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 89 |
+
agent.to(device, fp16=True)
|
| 90 |
+
logger.info(
|
| 91 |
+
f"Successfully built simuleval agent {model_path} on device {device}"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
return agent
|
| 95 |
+
|
| 96 |
+
def process_incoming_bytes(self, incoming_bytes):
|
| 97 |
+
segment, _sr = self._preprocess_wav(incoming_bytes)
|
| 98 |
+
# # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
|
| 99 |
+
self.input_queue.put_nowait(segment)
|
| 100 |
+
|
| 101 |
+
def get_input_segment(self):
|
| 102 |
+
if self.input_queue.empty():
|
| 103 |
+
return None
|
| 104 |
+
chunk = self.input_queue.get_nowait()
|
| 105 |
+
self.input_queue.task_done()
|
| 106 |
+
return chunk
|
| 107 |
+
|
| 108 |
+
def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
|
| 109 |
+
segment, sample_rate = soundfile.read(
|
| 110 |
+
io.BytesIO(data),
|
| 111 |
+
dtype="float32",
|
| 112 |
+
always_2d=True,
|
| 113 |
+
frames=-1,
|
| 114 |
+
start=0,
|
| 115 |
+
format="RAW",
|
| 116 |
+
subtype="PCM_16",
|
| 117 |
+
samplerate=self.incoming_sample_rate,
|
| 118 |
+
channels=1,
|
| 119 |
+
)
|
| 120 |
+
if self.debug:
|
| 121 |
+
self.test_incoming_wav.seek(0, soundfile.SEEK_END)
|
| 122 |
+
self.test_incoming_wav.write(segment)
|
| 123 |
+
|
| 124 |
+
segment = segment.T
|
| 125 |
+
segment, new_sample_rate = convert_waveform(
|
| 126 |
+
segment,
|
| 127 |
+
sample_rate,
|
| 128 |
+
normalize_volume=False,
|
| 129 |
+
to_mono=True,
|
| 130 |
+
to_sample_rate=MODEL_SAMPLE_RATE,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
assert MODEL_SAMPLE_RATE == new_sample_rate
|
| 134 |
+
segment = segment.squeeze(axis=0)
|
| 135 |
+
return segment, new_sample_rate
|
| 136 |
+
|
| 137 |
+
def process_pipeline_impl(self, input_segment):
|
| 138 |
+
try:
|
| 139 |
+
output_segment = self.agent.pushpop(input_segment, self.states)
|
| 140 |
+
if (
|
| 141 |
+
self.states[0].first_input_ts is not None
|
| 142 |
+
and self.first_input_ts is None
|
| 143 |
+
):
|
| 144 |
+
# TODO: this is hacky
|
| 145 |
+
self.first_input_ts = self.states[0].first_input_ts
|
| 146 |
+
|
| 147 |
+
if not output_segment.is_empty:
|
| 148 |
+
self.output_queue.put_nowait(output_segment)
|
| 149 |
+
|
| 150 |
+
if output_segment.finished:
|
| 151 |
+
self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
|
| 152 |
+
|
| 153 |
+
for state in self.states:
|
| 154 |
+
state.reset()
|
| 155 |
+
|
| 156 |
+
if self.debug:
|
| 157 |
+
# when we rebuild states, this value is reset to whatever
|
| 158 |
+
# is in the system dir config, which defaults debug=False.
|
| 159 |
+
self.states[0].debug = True
|
| 160 |
+
except Exception as e:
|
| 161 |
+
logger.error(f"Got exception while processing pipeline: {e}")
|
| 162 |
+
traceback.print_exc()
|
| 163 |
+
return input_segment
|
| 164 |
+
|
| 165 |
+
def process_pipeline_loop(self):
|
| 166 |
+
if self.close:
|
| 167 |
+
return # closes the thread
|
| 168 |
+
|
| 169 |
+
self.debug_log("processing_pipeline")
|
| 170 |
+
while not self.close:
|
| 171 |
+
input_segment = self.get_input_segment()
|
| 172 |
+
if input_segment is None:
|
| 173 |
+
if self.states[0].is_fresh_state: # TODO: this is hacky
|
| 174 |
+
time.sleep(0.3)
|
| 175 |
+
else:
|
| 176 |
+
time.sleep(0.03)
|
| 177 |
+
continue
|
| 178 |
+
self.process_pipeline_impl(input_segment)
|
| 179 |
+
self.debug_log("finished processing_pipeline")
|
| 180 |
+
|
| 181 |
+
def process_pipeline_once(self):
|
| 182 |
+
if self.close:
|
| 183 |
+
return
|
| 184 |
+
|
| 185 |
+
self.debug_log("processing pipeline once")
|
| 186 |
+
input_segment = self.get_input_segment()
|
| 187 |
+
if input_segment is None:
|
| 188 |
+
return
|
| 189 |
+
self.process_pipeline_impl(input_segment)
|
| 190 |
+
self.debug_log("finished processing_pipeline_once")
|
| 191 |
+
|
| 192 |
+
def get_output_segment(self):
|
| 193 |
+
if self.output_queue.empty():
|
| 194 |
+
return None
|
| 195 |
+
|
| 196 |
+
output_chunk = self.output_queue.get_nowait()
|
| 197 |
+
self.output_queue.task_done()
|
| 198 |
+
return output_chunk
|
| 199 |
+
|
| 200 |
+
def start(self):
|
| 201 |
+
self.debug_log("starting transcoder in a thread")
|
| 202 |
+
threading.Thread(target=self.process_pipeline_loop).start()
|
| 203 |
+
|
| 204 |
+
def first_translation_time(self):
|
| 205 |
+
return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
|
| 206 |
+
|
| 207 |
+
def get_buffered_output(self) -> SpeechAndTextOutput:
|
| 208 |
+
now = time.time() * 1000
|
| 209 |
+
self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
|
| 210 |
+
while not self.output_queue.empty():
|
| 211 |
+
tmp_out = self.get_output_segment()
|
| 212 |
+
if tmp_out and len(tmp_out.content) > 0:
|
| 213 |
+
if not self.output_data_type:
|
| 214 |
+
self.output_data_type = tmp_out.data_type
|
| 215 |
+
if len(self.output_buffer) == 0:
|
| 216 |
+
self.last_output_ts = now
|
| 217 |
+
self._populate_output_buffer(tmp_out)
|
| 218 |
+
self._increment_output_buffer_size(tmp_out)
|
| 219 |
+
|
| 220 |
+
if tmp_out.finished:
|
| 221 |
+
res = self._gather_output_buffer_data(final=True)
|
| 222 |
+
self.output_buffer = []
|
| 223 |
+
self.increment_output_buffer_size = 0
|
| 224 |
+
self.last_output_ts = now
|
| 225 |
+
self.first_output_ts = now
|
| 226 |
+
return res
|
| 227 |
+
|
| 228 |
+
if len(self.output_buffer) > 0 and (
|
| 229 |
+
now - self.last_output_ts >= self.output_buffer_idle_ms
|
| 230 |
+
or self.output_buffer_cur_size >= self.output_buffer_size_limit
|
| 231 |
+
):
|
| 232 |
+
self.last_output_ts = now
|
| 233 |
+
res = self._gather_output_buffer_data(final=False)
|
| 234 |
+
self.output_buffer = []
|
| 235 |
+
self.output_buffer_phoneme_count = 0
|
| 236 |
+
self.first_output_ts = now
|
| 237 |
+
return res
|
| 238 |
+
else:
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
def _gather_output_buffer_data(self, final):
|
| 242 |
+
if self.output_data_type == "text":
|
| 243 |
+
return SpeechAndTextOutput(text=" ".join(self.output_buffer), final=final)
|
| 244 |
+
elif self.output_data_type == "speech":
|
| 245 |
+
return SpeechAndTextOutput(
|
| 246 |
+
speech_samples=self.output_buffer,
|
| 247 |
+
speech_sample_rate=MODEL_SAMPLE_RATE,
|
| 248 |
+
final=final,
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(
|
| 252 |
+
f"Invalid output buffer data type: {self.output_data_type}"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _increment_output_buffer_size(self, segment):
|
| 256 |
+
if segment.data_type == "text":
|
| 257 |
+
self.output_buffer_cur_size += self._compute_phoneme_count(segment.content)
|
| 258 |
+
elif segment.data_type == "speech":
|
| 259 |
+
self.output_buffer_cur_size += (
|
| 260 |
+
len(segment.content) / MODEL_SAMPLE_RATE
|
| 261 |
+
) # seconds
|
| 262 |
+
|
| 263 |
+
def _populate_output_buffer(self, segment):
|
| 264 |
+
if segment.data_type == "text":
|
| 265 |
+
self.output_buffer.append(segment.content)
|
| 266 |
+
elif segment.data_type == "speech":
|
| 267 |
+
self.output_buffer += segment.content
|
| 268 |
+
else:
|
| 269 |
+
raise ValueError(f"Invalid segment data type: {segment.data_type}")
|
| 270 |
+
|
| 271 |
+
def _compute_phoneme_count(self, string: str) -> int:
|
| 272 |
+
return len([x for x in self.g2p(string) if x != " "])
|
requirements.txt
CHANGED
|
@@ -1,9 +1,23 @@
|
|
| 1 |
# fairseq2==0.1.0
|
|
|
|
|
|
|
| 2 |
git+https://github.com/mduppes/fairseq2.git@93420c86ba01349ee8f90d7adda439b666b50557
|
| 3 |
-
git+https://github.com/facebookresearch/seamless_communication
|
|
|
|
|
|
|
|
|
|
| 4 |
gradio==3.41.0
|
| 5 |
huggingface_hub==0.16.4
|
| 6 |
torch==2.0.1
|
| 7 |
torchaudio==2.0.2
|
| 8 |
transformers==4.32.1
|
| 9 |
-
pydub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# fairseq2==0.1.0
|
| 2 |
+
|
| 3 |
+
# Temp to skip
|
| 4 |
git+https://github.com/mduppes/fairseq2.git@93420c86ba01349ee8f90d7adda439b666b50557
|
| 5 |
+
# git+https://github.com/facebookresearch/seamless_communication
|
| 6 |
+
./seamless_communication
|
| 7 |
+
# comment this out to test fairseq1 first
|
| 8 |
+
# git+https://github.com/facebookresearch/SimulEval.git
|
| 9 |
gradio==3.41.0
|
| 10 |
huggingface_hub==0.16.4
|
| 11 |
torch==2.0.1
|
| 12 |
torchaudio==2.0.2
|
| 13 |
transformers==4.32.1
|
| 14 |
+
pydub
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Can't install the fairseq1 pipeline requirements alongside the above; it causes a pip dependency conflict:
|
| 18 |
+
#The conflict is caused by:
|
| 19 |
+
# The user requested simuleval 1.1.0 (from git+ssh://****@github.com/facebookresearch/SimulEval.git@tree_pipeline)
|
| 20 |
+
# seamless-communication 1.0.0 depends on simuleval 1.0.3.dev36+gd84fa60 (from git+https://github.com/mduppes/SimulEval.git@main)
|
| 21 |
+
# From fairseq1 pipeline
|
| 22 |
+
# git+ssh://git@github.com/fairinternal/fairseq-py.git@emma_incremental_decoder
|
| 23 |
+
# git+ssh://git@github.com/facebookresearch/SimulEval.git@tree_pipeline
|
simuleval_transcoder.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from fairseq2.assets.card import AssetCard
|
| 8 |
+
from fairseq2.data import Collater
|
| 9 |
+
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
|
| 10 |
+
from fairseq2.data.text.text_tokenizer import TextTokenizer
|
| 11 |
+
from fairseq2.data.typing import StringLike
|
| 12 |
+
from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
|
| 13 |
+
from fairseq2.memory import MemoryBlock
|
| 14 |
+
from fairseq2.typing import DataType, Device
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from enum import Enum, auto
|
| 17 |
+
from seamless_communication.models.inference.ngram_repeat_block_processor import (
|
| 18 |
+
NGramRepeatBlockProcessor,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from seamless_communication.models.unity import (
|
| 22 |
+
UnitTokenizer,
|
| 23 |
+
UnitYGenerator,
|
| 24 |
+
UnitYModel,
|
| 25 |
+
load_unity_model,
|
| 26 |
+
load_unity_text_tokenizer,
|
| 27 |
+
load_unity_unit_tokenizer,
|
| 28 |
+
)
|
| 29 |
+
from seamless_communication.models.unity.generator import SequenceToUnitOutput
|
| 30 |
+
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
from seamless_communication.models.streaming.agents import (
|
| 35 |
+
SileroVADAgent,
|
| 36 |
+
TestTimeWaitKS2TVAD,
|
| 37 |
+
TestTimeWaitKUnityV1M4T
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
### From test_pipeline
|
| 41 |
+
import math
|
| 42 |
+
import soundfile
|
| 43 |
+
from argparse import Namespace, ArgumentParser
|
| 44 |
+
from simuleval.data.segments import SpeechSegment, EmptySegment
|
| 45 |
+
from simuleval.utils import build_system_from_dir
|
| 46 |
+
from pathlib import Path
|
| 47 |
+
import numpy as np
|
| 48 |
+
|
| 49 |
+
class AudioFrontEnd:
    """Replays a wav file as a stream of fixed-duration segments,
    mimicking the front-end behavior of simuleval's instance.py."""

    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        self.samples = self.samples.tolist()
        self.segment_size = segment_size  # segment duration, milliseconds
        self.step = 0  # read cursor, in samples

    def send_segment(self):
        """
        This is the front-end logic in simuleval instance.py
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        print("self.segment_size", self.segment_size)
        print("num_samples is", num_samples)
        print("self.sample_rate is", self.sample_rate)

        total = len(self.samples)
        if self.step >= total:
            # Audio exhausted: report a final, empty segment.
            return EmptySegment(
                index=self.step / self.sample_rate * 1000,
                finished=True,
            )

        end = self.step + num_samples
        is_finished = end >= total
        chunk = self.samples[self.step:] if is_finished else self.samples[self.step:end]
        self.step = min(end, total)
        # index is the cursor position after this segment, in milliseconds
        return SpeechSegment(
            index=self.step / self.sample_rate * 1000,
            content=chunk,
            sample_rate=self.sample_rate,
            finished=is_finished,
        )
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def load_model_for_inference(
    load_model_fn: Callable[..., nn.Module],
    model_name_or_card: Union[str, AssetCard],
    device: Device,
    dtype: DataType,
) -> nn.Module:
    """Load a model via *load_model_fn* and put it in inference mode.

    The loader is called with the card/name plus device and dtype as
    keywords; the returned module is switched to eval() (disabling
    dropout / training-only behavior) before being handed back.
    """
    loaded = load_model_fn(model_name_or_card, device=device, dtype=dtype)
    loaded.eval()
    return loaded
|
| 100 |
+
|
| 101 |
+
class SimulevalTranscoder:
    """Demo transcoder: loads a seamlessM4T model, streams a sample wav
    file through a simuleval pipeline, and writes the VAD-segmented input
    audio back out as wav files.

    NOTE(review): throwaway/demo code — input and output wav paths are
    hard-coded to /checkpoint/mduppes/samples, and the whole demo runs
    inside __init__.
    """

    # def __init__(self, agent, sample_rate, debug, buffer_limit):
    def __init__(self):
        print("MDUPPES in here", SileroVADAgent, TestTimeWaitKS2TVAD)
        # Demo runs on CPU; the CUDA path is intentionally disabled here.
        device = "cpu"
        print("DEVICE", device)
        model_name_or_card = "seamlessM4T_medium"
        vocoder_name_or_card = "vocoder_36langs"  # currently unused in this demo
        # For CPU mode we must use float32; float16 causes errors downstream.
        # (Fixed: original had the chained-assignment typo `dtype=dtype=torch.float32`.)
        dtype = torch.float32

        model: UnitYModel = load_model_for_inference(
            load_unity_model, model_name_or_card, device, dtype
        )

        print(model, type(model))
        source_segment_size = 320  # milliseconds
        audio_frontend = AudioFrontEnd(
            wav_file="/checkpoint/mduppes/samples/marta.wav",
            segment_size=source_segment_size,
        )

        # mostly taken from S2S first agent: OnlineFeatureExtractorAgent defaults
        SHIFT_SIZE = 10
        WINDOW_SIZE = 25
        SAMPLE_RATE = 16000
        FEATURE_DIM = 80  # NOTE(review): unused — feature_dim below is 160

        # args, converted to a Namespace so fields can be accessed via `.`
        args = {
            "shift_size": SHIFT_SIZE,
            "window_size": WINDOW_SIZE,
            "sample_rate": audio_frontend.sample_rate,
            "feature_dim": 160,  # from Wav2Vec2Frontend
            "denormalize": False,  # not sure..
            "global_stats": None,  # default file path containing cmvn stats..
        }
        print(args)
        args = Namespace(**args)

        pipeline = TestTimeWaitKUnityV1M4T(model, args)
        system_states = pipeline.build_states()
        print('system states')
        print(system_states)
        # Accumulates the raw input audio matching the current VAD segment.
        input_segment = np.empty(0, dtype=np.int16)
        segments = []
        while True:
            speech_segment = audio_frontend.send_segment()
            input_segment = np.concatenate((input_segment, np.array(speech_segment.content)))
            # Translation happens here
            output_segment = pipeline.pushpop(speech_segment, system_states)
            print('pushpop result')
            print(output_segment)
            if output_segment.finished:
                # One VAD segment fully translated: keep its input audio
                # and reset the pipeline for the next segment.
                segments.append(input_segment)
                input_segment = np.empty(0, dtype=np.int16)
                print("Resetting states")
                for state in system_states:
                    state.reset()
            if speech_segment.finished:
                break
        # The VAD-segmented samples from the full input audio
        for i, seg in enumerate(segments):
            with soundfile.SoundFile(
                Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav",
                mode="w+",
                format="WAV",
                samplerate=16000,
                channels=1,
            ) as f:
                f.seek(0, soundfile.SEEK_END)
                f.write(seg)
|
| 178 |
+
|