Commit
·
70eeaf7
1
Parent(s):
099c588
Refactor code to update UI buttons in audio_tab()
Browse files- playground/refs/audio.m4a +0 -0
- playground/refs/audio.npy +0 -0
- playground/refs/test.ipynb +0 -0
- playground/refs/test.py +330 -0
- playground/testapp/audio.mp3 +0 -0
- playground/testapp/index.html +79 -0
- playground/testapp/main.py +292 -0
- playground/testapp/test.ipynb +478 -0
- playground/testapp/test.py +257 -0
playground/refs/audio.m4a
ADDED
|
Binary file (268 kB). View file
|
|
|
playground/refs/audio.npy
ADDED
|
Binary file (102 kB). View file
|
|
|
playground/refs/test.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
playground/refs/test.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fastapi
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from silero_vad import get_speech_timestamps, load_silero_vad
|
| 6 |
+
import whisperx
|
| 7 |
+
import edge_tts
|
| 8 |
+
import gc
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
from openai import OpenAI
|
| 12 |
+
import threading
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 16 |
+
|
| 17 |
+
# Configure FastAPI
|
| 18 |
+
app = fastapi.FastAPI()
|
| 19 |
+
|
| 20 |
+
# Load Silero VAD model
|
| 21 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 22 |
+
logging.info(f'Using device: {device}')
|
| 23 |
+
vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device
|
| 24 |
+
logging.info('Loaded Silero VAD model')
|
| 25 |
+
|
| 26 |
+
# Load WhisperX model
|
| 27 |
+
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
|
| 28 |
+
logging.info('Loaded WhisperX model')
|
| 29 |
+
|
| 30 |
+
OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C" # os.getenv("OPENAI_API_KEY")
|
| 31 |
+
if not OPENAI_API_KEY:
|
| 32 |
+
logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
|
| 33 |
+
raise ValueError("OpenAI API key not found.")
|
| 34 |
+
|
| 35 |
+
# Initialize OpenAI client
|
| 36 |
+
openai_client = OpenAI(api_key=OPENAI_API_KEY)
|
| 37 |
+
logging.info('Initialized OpenAI client')
|
| 38 |
+
|
| 39 |
+
# TTS Voice
|
| 40 |
+
TTS_VOICE = "en-GB-SoniaNeural"
|
| 41 |
+
|
| 42 |
+
# Function to check voice activity using Silero VAD
|
| 43 |
+
def check_vad(audio_data, sample_rate):
|
| 44 |
+
logging.info('Checking voice activity')
|
| 45 |
+
# Resample to 16000 Hz if necessary
|
| 46 |
+
target_sample_rate = 16000
|
| 47 |
+
if sample_rate != target_sample_rate:
|
| 48 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 49 |
+
audio_tensor = resampler(torch.from_numpy(audio_data))
|
| 50 |
+
else:
|
| 51 |
+
audio_tensor = torch.from_numpy(audio_data)
|
| 52 |
+
audio_tensor = audio_tensor.to(device)
|
| 53 |
+
|
| 54 |
+
# Log audio data details
|
| 55 |
+
logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')
|
| 56 |
+
|
| 57 |
+
# Get speech timestamps
|
| 58 |
+
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
|
| 59 |
+
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
|
| 60 |
+
return len(speech_timestamps) > 0
|
| 61 |
+
|
| 62 |
+
# Function to transcribe audio using WhisperX
|
| 63 |
+
def transcript(audio_data, sample_rate):
|
| 64 |
+
logging.info('Transcribing audio')
|
| 65 |
+
# Resample to 16000 Hz if necessary
|
| 66 |
+
target_sample_rate = 16000
|
| 67 |
+
if sample_rate != target_sample_rate:
|
| 68 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 69 |
+
audio_data = resampler(torch.from_numpy(audio_data)).numpy()
|
| 70 |
+
else:
|
| 71 |
+
audio_data = audio_data
|
| 72 |
+
|
| 73 |
+
# Transcribe
|
| 74 |
+
batch_size = 16 # Adjust as needed
|
| 75 |
+
result = whisper_model.transcribe(audio_data, batch_size=batch_size)
|
| 76 |
+
text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
|
| 77 |
+
logging.info(f'Transcription result: {text}')
|
| 78 |
+
# Clear GPU memory
|
| 79 |
+
del result
|
| 80 |
+
gc.collect()
|
| 81 |
+
if device == 'cuda':
|
| 82 |
+
torch.cuda.empty_cache()
|
| 83 |
+
return text
|
| 84 |
+
|
| 85 |
+
# Function to get streaming response from OpenAI API
|
| 86 |
+
def llm(text):
|
| 87 |
+
logging.info('Getting response from OpenAI API')
|
| 88 |
+
response = openai_client.chat.completions.create(
|
| 89 |
+
model="gpt-4o", # Updated to a more recent model
|
| 90 |
+
messages=[
|
| 91 |
+
{"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
|
| 92 |
+
{"role": "user", "content": text}
|
| 93 |
+
],
|
| 94 |
+
stream=True,
|
| 95 |
+
temperature=0.7, # Optional: Adjust as needed
|
| 96 |
+
top_p=0.9, # Optional: Adjust as needed
|
| 97 |
+
)
|
| 98 |
+
for chunk in response:
|
| 99 |
+
yield chunk.choices[0].delta.content
|
| 100 |
+
|
| 101 |
+
# Function to perform TTS per sentence using Edge-TTS
|
| 102 |
+
def tts_streaming(text_stream):
|
| 103 |
+
logging.info('Performing TTS')
|
| 104 |
+
buffer = ""
|
| 105 |
+
punctuation = {'.', '!', '?'}
|
| 106 |
+
for text_chunk in text_stream:
|
| 107 |
+
if text_chunk is not None:
|
| 108 |
+
buffer += text_chunk
|
| 109 |
+
# Check for sentence completion
|
| 110 |
+
sentences = []
|
| 111 |
+
start = 0
|
| 112 |
+
for i, char in enumerate(buffer):
|
| 113 |
+
if (char in punctuation):
|
| 114 |
+
sentences.append(buffer[start:i+1].strip())
|
| 115 |
+
start = i+1
|
| 116 |
+
buffer = buffer[start:]
|
| 117 |
+
|
| 118 |
+
for sentence in sentences:
|
| 119 |
+
if sentence:
|
| 120 |
+
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
|
| 121 |
+
for chunk in communicate.stream_sync():
|
| 122 |
+
if chunk["type"] == "audio":
|
| 123 |
+
yield chunk["data"]
|
| 124 |
+
# Process any remaining text
|
| 125 |
+
if buffer.strip():
|
| 126 |
+
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
|
| 127 |
+
for chunk in communicate.stream_sync():
|
| 128 |
+
if chunk["type"] == "audio":
|
| 129 |
+
yield chunk["data"]
|
| 130 |
+
|
| 131 |
+
# Function to handle LLM and TTS
|
| 132 |
+
def llm_and_tts(transcribed_text, state):
|
| 133 |
+
logging.info('Handling LLM and TTS')
|
| 134 |
+
# Get streaming response from LLM
|
| 135 |
+
for text_chunk in llm(transcribed_text):
|
| 136 |
+
if state.get('stop_signal'):
|
| 137 |
+
logging.info('LLM and TTS task stopped')
|
| 138 |
+
break
|
| 139 |
+
# Get audio data from TTS
|
| 140 |
+
for audio_chunk in tts_streaming([text_chunk]):
|
| 141 |
+
if state.get('stop_signal'):
|
| 142 |
+
logging.info('LLM and TTS task stopped during TTS')
|
| 143 |
+
break
|
| 144 |
+
yield np.frombuffer(audio_chunk, dtype=np.int16)
|
| 145 |
+
|
| 146 |
+
state = {
|
| 147 |
+
'mode': 'idle',
|
| 148 |
+
'chunk_queue': [],
|
| 149 |
+
'transcription': '',
|
| 150 |
+
'in_transcription': False,
|
| 151 |
+
'previous_no_vad_audio': [],
|
| 152 |
+
'llm_task': None,
|
| 153 |
+
'instream': None,
|
| 154 |
+
'stop_signal': False,
|
| 155 |
+
'args': {
|
| 156 |
+
'sample_rate': 16000,
|
| 157 |
+
'chunk_size': 0.5, # seconds
|
| 158 |
+
'transcript_chunk_size': 2, # seconds
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
def transcript_loop():
|
| 163 |
+
while True:
|
| 164 |
+
if len(state['chunk_queue']) > 0:
|
| 165 |
+
accumulated_audio = np.concatenate(state['chunk_queue'])
|
| 166 |
+
total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
|
| 167 |
+
total_duration = total_samples / state['sample_rate']
|
| 168 |
+
|
| 169 |
+
# Run transcription on the first 2 seconds if len > 3 seconds
|
| 170 |
+
if total_duration > 3.0 and state['in_transcription'] == True:
|
| 171 |
+
first_two_seconds_samples = int(2.0 * state['sample_rate'])
|
| 172 |
+
first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
|
| 173 |
+
transcribed_text = transcript(first_two_seconds_audio, state['sample_rate'])
|
| 174 |
+
state['transcription'] += transcribed_text
|
| 175 |
+
remaining_audio = accumulated_audio[first_two_seconds_samples:]
|
| 176 |
+
state['chunk_queue'] = [remaining_audio]
|
| 177 |
+
else: # Run transcription on the accumulated audio
|
| 178 |
+
transcribed_text = transcript(accumulated_audio, state['sample_rate'])
|
| 179 |
+
state['transcription'] += transcribed_text
|
| 180 |
+
state['chunk_queue'] = []
|
| 181 |
+
state['in_transcription'] = False
|
| 182 |
+
else:
|
| 183 |
+
time.sleep(0.1)
|
| 184 |
+
|
| 185 |
+
if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):
|
| 186 |
+
state['in_transcription'] = False
|
| 187 |
+
break
|
| 188 |
+
|
| 189 |
+
def process_audio(audio_chunk):
|
| 190 |
+
# returns output audio
|
| 191 |
+
|
| 192 |
+
sample_rate, audio_data = audio_chunk
|
| 193 |
+
audio_data = np.array(audio_data, dtype=np.float32)
|
| 194 |
+
|
| 195 |
+
# convert to mono if necessary
|
| 196 |
+
if audio_data.ndim > 1:
|
| 197 |
+
audio_data = np.mean(audio_data, axis=1)
|
| 198 |
+
|
| 199 |
+
mode = state['mode']
|
| 200 |
+
chunk_queue = state['chunk_queue']
|
| 201 |
+
transcription = state['transcription']
|
| 202 |
+
in_transcription = state['in_transcription']
|
| 203 |
+
previous_no_vad_audio = state['previous_no_vad_audio']
|
| 204 |
+
llm_task = state['llm_task']
|
| 205 |
+
instream = state['instream']
|
| 206 |
+
stop_signal = state['stop_signal']
|
| 207 |
+
args = state['args']
|
| 208 |
+
|
| 209 |
+
args['sample_rate'] = sample_rate
|
| 210 |
+
|
| 211 |
+
# check for voice activity
|
| 212 |
+
vad = check_vad(audio_data, sample_rate)
|
| 213 |
+
|
| 214 |
+
if vad:
|
| 215 |
+
logging.info(f'Voice activity detected in mode: {mode}')
|
| 216 |
+
if mode == 'idle':
|
| 217 |
+
mode = 'listening'
|
| 218 |
+
elif mode == 'speaking':
|
| 219 |
+
# Stop llm and tts tasks
|
| 220 |
+
if llm_task and llm_task.is_alive():
|
| 221 |
+
# Implement task cancellation logic if possible
|
| 222 |
+
logging.info('Stopping LLM and TTS tasks')
|
| 223 |
+
# Since we cannot kill threads directly, we need to handle this in the tasks
|
| 224 |
+
stop_signal = True
|
| 225 |
+
llm_task.join()
|
| 226 |
+
mode = 'listening'
|
| 227 |
+
|
| 228 |
+
if mode == 'listening':
|
| 229 |
+
if previous_no_vad_audio is not None:
|
| 230 |
+
chunk_queue.append(previous_no_vad_audio)
|
| 231 |
+
previous_no_vad_audio = None
|
| 232 |
+
# Accumulate audio chunks
|
| 233 |
+
chunk_queue.append(audio_data)
|
| 234 |
+
|
| 235 |
+
# Start transcription thread if not already running
|
| 236 |
+
if not in_transcription:
|
| 237 |
+
in_transcription = True
|
| 238 |
+
transcription_task = threading.Thread(target=transcript_loop, args=(chunk_queue, sample_rate))
|
| 239 |
+
transcription_task.start()
|
| 240 |
+
|
| 241 |
+
elif mode == 'speaking':
|
| 242 |
+
# Continue accumulating audio chunks
|
| 243 |
+
chunk_queue.append(audio_data)
|
| 244 |
+
else:
|
| 245 |
+
logging.info(f'No voice activity detected in mode: {mode}')
|
| 246 |
+
if mode == 'listening':
|
| 247 |
+
# Add the last chunk to queue
|
| 248 |
+
chunk_queue.append(audio_data)
|
| 249 |
+
|
| 250 |
+
# Change mode to processing
|
| 251 |
+
mode = 'processing'
|
| 252 |
+
|
| 253 |
+
# Wait for transcription to complete
|
| 254 |
+
while in_transcription:
|
| 255 |
+
time.sleep(0.1)
|
| 256 |
+
|
| 257 |
+
# Check if transcription is complete
|
| 258 |
+
if len(chunk_queue) == 0:
|
| 259 |
+
# Start LLM and TTS tasks
|
| 260 |
+
if not llm_task or not llm_task.is_alive():
|
| 261 |
+
stop_signal = False
|
| 262 |
+
llm_task = threading.Thread(target=llm_and_tts, args=(transcription, state))
|
| 263 |
+
llm_task.start()
|
| 264 |
+
|
| 265 |
+
if mode == 'processing':
|
| 266 |
+
# Wait for LLM and TTS tasks to start yielding audio
|
| 267 |
+
if llm_task and llm_task.is_alive():
|
| 268 |
+
mode = 'responding'
|
| 269 |
+
|
| 270 |
+
if mode == 'responding':
|
| 271 |
+
for audio_chunk in llm_task:
|
| 272 |
+
if instream is None:
|
| 273 |
+
instream = audio_chunk
|
| 274 |
+
else:
|
| 275 |
+
instream = np.concatenate((instream, audio_chunk))
|
| 276 |
+
|
| 277 |
+
# Send audio to output stream
|
| 278 |
+
yield instream
|
| 279 |
+
|
| 280 |
+
# Cleanup
|
| 281 |
+
llm_task = None
|
| 282 |
+
transcription = ''
|
| 283 |
+
mode = 'idle'
|
| 284 |
+
|
| 285 |
+
# Updaate state
|
| 286 |
+
state['mode'] = mode
|
| 287 |
+
state['chunk_queue'] = chunk_queue
|
| 288 |
+
state['transcription'] = transcription
|
| 289 |
+
state['in_transcription'] = in_transcription
|
| 290 |
+
state['previous_no_vad_audio'] = previous_no_vad_audio
|
| 291 |
+
state['llm_task'] = llm_task
|
| 292 |
+
state['instream'] = instream
|
| 293 |
+
state['stop_signal'] = stop_signal
|
| 294 |
+
state['args'] = args
|
| 295 |
+
|
| 296 |
+
# Store previous audio chunk with no voice activity
|
| 297 |
+
previous_no_vad_audio = audio_data
|
| 298 |
+
|
| 299 |
+
# Update state
|
| 300 |
+
state['mode'] = mode
|
| 301 |
+
state['chunk_queue'] = chunk_queue
|
| 302 |
+
state['transcription'] = transcription
|
| 303 |
+
state['in_transcription'] = in_transcription
|
| 304 |
+
state['previous_no_vad_audio'] = previous_no_vad_audio
|
| 305 |
+
state['llm_task'] = llm_task
|
| 306 |
+
state['instream'] = instream
|
| 307 |
+
state['stop_signal'] = stop_signal
|
| 308 |
+
state['args'] = args
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
@app.websocket('/ws')
|
| 312 |
+
def websocket_endpoint(websocket: fastapi.WebSocket):
|
| 313 |
+
logging.info('WebSocket connection established')
|
| 314 |
+
try:
|
| 315 |
+
while True:
|
| 316 |
+
time.sleep(state['args']['chunk_size'])
|
| 317 |
+
audio_chunk = websocket.receive_bytes()
|
| 318 |
+
if audio_chunk is None:
|
| 319 |
+
break
|
| 320 |
+
for audio_data in process_audio(audio_chunk):
|
| 321 |
+
websocket.send_bytes(audio_data.tobytes())
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logging.error(f'WebSocket error: {e}')
|
| 324 |
+
finally:
|
| 325 |
+
logging.info('WebSocket connection closed')
|
| 326 |
+
websocket.close()
|
| 327 |
+
|
| 328 |
+
@app.get('/')
|
| 329 |
+
def index():
|
| 330 |
+
return fastapi.FileResponse('index.html')
|
playground/testapp/audio.mp3
ADDED
|
Binary file (386 kB). View file
|
|
|
playground/testapp/index.html
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Voice Assistant</title>
|
| 7 |
+
<style>
|
| 8 |
+
body {
|
| 9 |
+
font-family: Arial, sans-serif;
|
| 10 |
+
margin: 20px;
|
| 11 |
+
}
|
| 12 |
+
#transcription {
|
| 13 |
+
margin-top: 20px;
|
| 14 |
+
padding: 10px;
|
| 15 |
+
border: 1px solid #ccc;
|
| 16 |
+
height: 150px;
|
| 17 |
+
overflow-y: auto;
|
| 18 |
+
}
|
| 19 |
+
#audio-player {
|
| 20 |
+
margin-top: 20px;
|
| 21 |
+
}
|
| 22 |
+
</style>
|
| 23 |
+
</head>
|
| 24 |
+
<body>
|
| 25 |
+
<h1>Voice Assistant</h1>
|
| 26 |
+
<button id="start-btn">Start Recording</button>
|
| 27 |
+
<button id="stop-btn" disabled>Stop Recording</button>
|
| 28 |
+
<div id="transcription"></div>
|
| 29 |
+
<audio id="audio-player" controls></audio>
|
| 30 |
+
|
| 31 |
+
<script>
|
| 32 |
+
const startBtn = document.getElementById('start-btn');
|
| 33 |
+
const stopBtn = document.getElementById('stop-btn');
|
| 34 |
+
const transcriptionDiv = document.getElementById('transcription');
|
| 35 |
+
const audioPlayer = document.getElementById('audio-player');
|
| 36 |
+
let websocket;
|
| 37 |
+
let mediaRecorder;
|
| 38 |
+
let audioChunks = [];
|
| 39 |
+
|
| 40 |
+
startBtn.addEventListener('click', async () => {
|
| 41 |
+
startBtn.disabled = true;
|
| 42 |
+
stopBtn.disabled = false;
|
| 43 |
+
|
| 44 |
+
websocket = new WebSocket('ws://localhost:8000/ws');
|
| 45 |
+
websocket.binaryType = 'arraybuffer';
|
| 46 |
+
|
| 47 |
+
websocket.onmessage = (event) => {
|
| 48 |
+
if (event.data instanceof ArrayBuffer) {
|
| 49 |
+
const audioBlob = new Blob([event.data], { type: 'audio/wav' });
|
| 50 |
+
audioPlayer.src = URL.createObjectURL(audioBlob);
|
| 51 |
+
audioPlayer.play();
|
| 52 |
+
} else {
|
| 53 |
+
transcriptionDiv.innerText += event.data + '\n';
|
| 54 |
+
}
|
| 55 |
+
};
|
| 56 |
+
|
| 57 |
+
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
| 58 |
+
mediaRecorder = new MediaRecorder(stream);
|
| 59 |
+
|
| 60 |
+
mediaRecorder.ondataavailable = (event) => {
|
| 61 |
+
if (event.data.size > 0) {
|
| 62 |
+
audioChunks.push(event.data);
|
| 63 |
+
websocket.send(event.data);
|
| 64 |
+
}
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
mediaRecorder.start(1000); // Send audio data every second
|
| 68 |
+
});
|
| 69 |
+
|
| 70 |
+
stopBtn.addEventListener('click', () => {
|
| 71 |
+
startBtn.disabled = false;
|
| 72 |
+
stopBtn.disabled = true;
|
| 73 |
+
|
| 74 |
+
mediaRecorder.stop();
|
| 75 |
+
websocket.close();
|
| 76 |
+
});
|
| 77 |
+
</script>
|
| 78 |
+
</body>
|
| 79 |
+
</html>
|
playground/testapp/main.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fastapi
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from silero_vad import get_speech_timestamps, load_silero_vad
|
| 6 |
+
import whisperx
|
| 7 |
+
import edge_tts
|
| 8 |
+
import gc
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
import os
|
| 12 |
+
from openai import AsyncOpenAI
|
| 13 |
+
import asyncio
|
| 14 |
+
|
| 15 |
+
# Configure logging
|
| 16 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 17 |
+
|
| 18 |
+
# Configure FastAPI
|
| 19 |
+
app = fastapi.FastAPI()
|
| 20 |
+
|
| 21 |
+
# Load Silero VAD model
|
| 22 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 23 |
+
logging.info(f'Using device: {device}')
|
| 24 |
+
vad_model = load_silero_vad().to(device)
|
| 25 |
+
logging.info('Loaded Silero VAD model')
|
| 26 |
+
|
| 27 |
+
# Load WhisperX model
|
| 28 |
+
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
|
| 29 |
+
logging.info('Loaded WhisperX model')
|
| 30 |
+
|
| 31 |
+
OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
|
| 32 |
+
if not OPENAI_API_KEY:
|
| 33 |
+
logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
|
| 34 |
+
raise ValueError("OpenAI API key not found.")
|
| 35 |
+
logging.info('Initialized OpenAI client')
|
| 36 |
+
aclient = AsyncOpenAI(api_key=OPENAI_API_KEY) # Corrected import
|
| 37 |
+
|
| 38 |
+
# TTS Voice
|
| 39 |
+
TTS_VOICE = "en-GB-SoniaNeural"
|
| 40 |
+
|
| 41 |
+
# Function to check voice activity using Silero VAD
|
| 42 |
+
def check_vad(audio_data, sample_rate):
|
| 43 |
+
logging.info('Checking voice activity')
|
| 44 |
+
target_sample_rate = 16000
|
| 45 |
+
if sample_rate != target_sample_rate:
|
| 46 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 47 |
+
audio_tensor = resampler(torch.from_numpy(audio_data))
|
| 48 |
+
else:
|
| 49 |
+
audio_tensor = torch.from_numpy(audio_data)
|
| 50 |
+
audio_tensor = audio_tensor.to(device)
|
| 51 |
+
|
| 52 |
+
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
|
| 53 |
+
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
|
| 54 |
+
return len(speech_timestamps) > 0
|
| 55 |
+
|
| 56 |
+
# Async function to transcribe audio using WhisperX
|
| 57 |
+
def transcript_sync(audio_data, sample_rate):
|
| 58 |
+
logging.info('Transcribing audio')
|
| 59 |
+
target_sample_rate = 16000
|
| 60 |
+
if sample_rate != target_sample_rate:
|
| 61 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 62 |
+
audio_data = resampler(torch.from_numpy(audio_data)).numpy()
|
| 63 |
+
else:
|
| 64 |
+
audio_data = audio_data
|
| 65 |
+
|
| 66 |
+
batch_size = 16 # Adjust as needed
|
| 67 |
+
result = whisper_model.transcribe(audio_data, batch_size=batch_size)
|
| 68 |
+
text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
|
| 69 |
+
logging.info(f'Transcription result: {text}')
|
| 70 |
+
del result
|
| 71 |
+
gc.collect()
|
| 72 |
+
if device == 'cuda':
|
| 73 |
+
torch.cuda.empty_cache()
|
| 74 |
+
return text
|
| 75 |
+
|
| 76 |
+
async def transcript(audio_data, sample_rate):
|
| 77 |
+
loop = asyncio.get_running_loop()
|
| 78 |
+
text = await loop.run_in_executor(None, transcript_sync, audio_data, sample_rate)
|
| 79 |
+
return text
|
| 80 |
+
|
| 81 |
+
# Async function to get streaming response from OpenAI API
|
| 82 |
+
async def llm(text):
|
| 83 |
+
logging.info('Getting response from OpenAI API')
|
| 84 |
+
response = await aclient.chat.completions.create(model="gpt-4", # Updated to a more recent model
|
| 85 |
+
messages=[
|
| 86 |
+
{"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
|
| 87 |
+
{"role": "user", "content": text}
|
| 88 |
+
],
|
| 89 |
+
stream=True,
|
| 90 |
+
temperature=0.7,
|
| 91 |
+
top_p=0.9)
|
| 92 |
+
async for chunk in response:
|
| 93 |
+
yield chunk.choices[0].delta.content
|
| 94 |
+
|
| 95 |
+
# Async function to perform TTS using Edge-TTS
|
| 96 |
+
async def tts_streaming(text_stream):
|
| 97 |
+
logging.info('Performing TTS')
|
| 98 |
+
buffer = ""
|
| 99 |
+
punctuation = {'.', '!', '?'}
|
| 100 |
+
for text_chunk in text_stream:
|
| 101 |
+
if text_chunk is not None:
|
| 102 |
+
buffer += text_chunk
|
| 103 |
+
# Check for sentence completion
|
| 104 |
+
sentences = []
|
| 105 |
+
start = 0
|
| 106 |
+
for i, char in enumerate(buffer):
|
| 107 |
+
if char in punctuation:
|
| 108 |
+
sentences.append(buffer[start:i+1].strip())
|
| 109 |
+
start = i+1
|
| 110 |
+
buffer = buffer[start:]
|
| 111 |
+
|
| 112 |
+
for sentence in sentences:
|
| 113 |
+
if sentence:
|
| 114 |
+
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
|
| 115 |
+
async for chunk in communicate.stream():
|
| 116 |
+
if chunk["type"] == "audio":
|
| 117 |
+
yield chunk["data"]
|
| 118 |
+
# Process any remaining text
|
| 119 |
+
if buffer.strip():
|
| 120 |
+
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
|
| 121 |
+
async for chunk in communicate.stream():
|
| 122 |
+
if chunk["type"] == "audio":
|
| 123 |
+
yield chunk["data"]
|
| 124 |
+
|
| 125 |
+
class Conversation:
|
| 126 |
+
def __init__(self):
|
| 127 |
+
self.mode = 'idle'
|
| 128 |
+
self.chunk_queue = []
|
| 129 |
+
self.transcription = ''
|
| 130 |
+
self.in_transcription = False
|
| 131 |
+
self.previous_no_vad_audio = None
|
| 132 |
+
self.llm_task = None
|
| 133 |
+
self.transcription_task = None
|
| 134 |
+
self.stop_signal = False
|
| 135 |
+
self.sample_rate = 16000 # default sample rate
|
| 136 |
+
self.instream = None
|
| 137 |
+
|
| 138 |
+
async def process_audio(self, audio_chunk):
|
| 139 |
+
sample_rate, audio_data = audio_chunk
|
| 140 |
+
self.sample_rate = sample_rate
|
| 141 |
+
audio_data = np.array(audio_data, dtype=np.float32)
|
| 142 |
+
|
| 143 |
+
# convert to mono if necessary
|
| 144 |
+
if audio_data.ndim > 1:
|
| 145 |
+
audio_data = np.mean(audio_data, axis=1)
|
| 146 |
+
|
| 147 |
+
# check for voice activity
|
| 148 |
+
vad = check_vad(audio_data, sample_rate)
|
| 149 |
+
|
| 150 |
+
if vad:
|
| 151 |
+
logging.info(f'Voice activity detected in mode: {self.mode}')
|
| 152 |
+
if self.mode == 'idle':
|
| 153 |
+
self.mode = 'listening'
|
| 154 |
+
elif self.mode == 'speaking':
|
| 155 |
+
# Stop llm and tts tasks
|
| 156 |
+
if self.llm_task and not self.llm_task.done():
|
| 157 |
+
logging.info('Stopping LLM and TTS tasks')
|
| 158 |
+
self.stop_signal = True
|
| 159 |
+
await self.llm_task
|
| 160 |
+
self.mode = 'listening'
|
| 161 |
+
|
| 162 |
+
if self.mode == 'listening':
|
| 163 |
+
if self.previous_no_vad_audio is not None:
|
| 164 |
+
self.chunk_queue.append(self.previous_no_vad_audio)
|
| 165 |
+
self.previous_no_vad_audio = None
|
| 166 |
+
# Accumulate audio chunks
|
| 167 |
+
self.chunk_queue.append(audio_data)
|
| 168 |
+
|
| 169 |
+
# Start transcription task if not already running
|
| 170 |
+
if not self.in_transcription:
|
| 171 |
+
self.in_transcription = True
|
| 172 |
+
self.transcription_task = asyncio.create_task(self.transcript_loop())
|
| 173 |
+
|
| 174 |
+
else:
|
| 175 |
+
logging.info(f'No voice activity detected in mode: {self.mode}')
|
| 176 |
+
if self.mode == 'listening':
|
| 177 |
+
# Add the last chunk to queue
|
| 178 |
+
self.chunk_queue.append(audio_data)
|
| 179 |
+
|
| 180 |
+
# Change mode to processing
|
| 181 |
+
self.mode = 'processing'
|
| 182 |
+
|
| 183 |
+
# Wait for transcription to complete
|
| 184 |
+
while self.in_transcription:
|
| 185 |
+
await asyncio.sleep(0.1)
|
| 186 |
+
|
| 187 |
+
# Check if transcription is complete
|
| 188 |
+
if len(self.chunk_queue) == 0:
|
| 189 |
+
# Start LLM and TTS tasks
|
| 190 |
+
if not self.llm_task or self.llm_task.done():
|
| 191 |
+
self.stop_signal = False
|
| 192 |
+
self.llm_task = self.llm_and_tts()
|
| 193 |
+
self.mode = 'responding'
|
| 194 |
+
|
| 195 |
+
if self.mode == 'responding':
|
| 196 |
+
async for audio_chunk in self.llm_task:
|
| 197 |
+
if self.instream is None:
|
| 198 |
+
self.instream = audio_chunk
|
| 199 |
+
else:
|
| 200 |
+
self.instream = np.concatenate((self.instream, audio_chunk))
|
| 201 |
+
# Send audio to output stream
|
| 202 |
+
yield self.instream
|
| 203 |
+
|
| 204 |
+
# Cleanup
|
| 205 |
+
self.llm_task = None
|
| 206 |
+
self.transcription = ''
|
| 207 |
+
self.mode = 'idle'
|
| 208 |
+
self.instream = None
|
| 209 |
+
|
| 210 |
+
# Store previous audio chunk with no voice activity
|
| 211 |
+
self.previous_no_vad_audio = audio_data
|
| 212 |
+
|
| 213 |
+
async def transcript_loop(self):
|
| 214 |
+
while True:
|
| 215 |
+
if len(self.chunk_queue) > 0:
|
| 216 |
+
accumulated_audio = np.concatenate(self.chunk_queue)
|
| 217 |
+
total_samples = len(accumulated_audio)
|
| 218 |
+
total_duration = total_samples / self.sample_rate
|
| 219 |
+
|
| 220 |
+
if total_duration > 3.0 and self.in_transcription == True:
|
| 221 |
+
first_two_seconds_samples = int(2.0 * self.sample_rate)
|
| 222 |
+
first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
|
| 223 |
+
transcribed_text = await transcript(first_two_seconds_audio, self.sample_rate)
|
| 224 |
+
self.transcription += transcribed_text
|
| 225 |
+
remaining_audio = accumulated_audio[first_two_seconds_samples:]
|
| 226 |
+
self.chunk_queue = [remaining_audio]
|
| 227 |
+
else:
|
| 228 |
+
transcribed_text = await transcript(accumulated_audio, self.sample_rate)
|
| 229 |
+
self.transcription += transcribed_text
|
| 230 |
+
self.chunk_queue = []
|
| 231 |
+
self.in_transcription = False
|
| 232 |
+
else:
|
| 233 |
+
await asyncio.sleep(0.1)
|
| 234 |
+
|
| 235 |
+
if len(self.chunk_queue) == 0 and self.mode in ['idle', 'processing']:
|
| 236 |
+
self.in_transcription = False
|
| 237 |
+
break
|
| 238 |
+
|
| 239 |
+
async def llm_and_tts(self):
|
| 240 |
+
logging.info('Handling LLM and TTS')
|
| 241 |
+
async for text_chunk in llm(self.transcription):
|
| 242 |
+
if self.stop_signal:
|
| 243 |
+
logging.info('LLM and TTS task stopped')
|
| 244 |
+
break
|
| 245 |
+
async for audio_chunk in tts_streaming([text_chunk]):
|
| 246 |
+
if self.stop_signal:
|
| 247 |
+
logging.info('LLM and TTS task stopped during TTS')
|
| 248 |
+
break
|
| 249 |
+
yield np.frombuffer(audio_chunk, dtype=np.int16)
|
| 250 |
+
|
| 251 |
+
@app.websocket('/ws')
|
| 252 |
+
async def websocket_endpoint(websocket: fastapi.WebSocket):
|
| 253 |
+
await websocket.accept()
|
| 254 |
+
logging.info('WebSocket connection established')
|
| 255 |
+
conversation = Conversation()
|
| 256 |
+
audio_buffer = []
|
| 257 |
+
buffer_duration = 0.5 # 500ms
|
| 258 |
+
try:
|
| 259 |
+
while True:
|
| 260 |
+
audio_chunk_bytes = await websocket.receive_bytes()
|
| 261 |
+
if audio_chunk_bytes is None:
|
| 262 |
+
break
|
| 263 |
+
|
| 264 |
+
audio_chunk = (conversation.sample_rate, np.frombuffer(audio_chunk_bytes, dtype=np.int16))
|
| 265 |
+
audio_buffer.append(audio_chunk[1])
|
| 266 |
+
|
| 267 |
+
# Calculate the duration of the buffered audio
|
| 268 |
+
total_samples = sum(len(chunk) for chunk in audio_buffer)
|
| 269 |
+
total_duration = total_samples / conversation.sample_rate
|
| 270 |
+
|
| 271 |
+
if total_duration >= buffer_duration:
|
| 272 |
+
# Concatenate buffered audio chunks
|
| 273 |
+
buffered_audio = np.concatenate(audio_buffer)
|
| 274 |
+
audio_buffer = [] # Reset buffer
|
| 275 |
+
|
| 276 |
+
# Process the buffered audio
|
| 277 |
+
async for audio_data in conversation.process_audio((conversation.sample_rate, buffered_audio)):
|
| 278 |
+
if audio_data is not None:
|
| 279 |
+
await websocket.send_bytes(audio_data.tobytes())
|
| 280 |
+
except Exception as e:
|
| 281 |
+
logging.error(f'WebSocket error: {e}')
|
| 282 |
+
finally:
|
| 283 |
+
logging.info('WebSocket connection closed')
|
| 284 |
+
await websocket.close()
|
| 285 |
+
|
| 286 |
+
@app.get('/')
|
| 287 |
+
def index():
|
| 288 |
+
return fastapi.responses.FileResponse('index.html')
|
| 289 |
+
|
| 290 |
+
if __name__ == '__main__':
|
| 291 |
+
import uvicorn
|
| 292 |
+
uvicorn.run(app, host='0.0.0.0', port=8000)
|
playground/testapp/test.ipynb
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import fastapi\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"import torch\n",
|
| 12 |
+
"import torchaudio\n",
|
| 13 |
+
"from silero_vad import get_speech_timestamps, load_silero_vad\n",
|
| 14 |
+
"import whisperx\n",
|
| 15 |
+
"import edge_tts\n",
|
| 16 |
+
"import gc\n",
|
| 17 |
+
"import logging\n",
|
| 18 |
+
"import time\n",
|
| 19 |
+
"from openai import OpenAI\n",
|
| 20 |
+
"import threading\n",
|
| 21 |
+
"import asyncio\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"# Configure logging\n",
|
| 24 |
+
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"# Configure FastAPI\n",
|
| 27 |
+
"app = fastapi.FastAPI()\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"# Load Silero VAD model\n",
|
| 30 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 31 |
+
"logging.info(f'Using device: {device}')\n",
|
| 32 |
+
"vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device\n",
|
| 33 |
+
"logging.info('Loaded Silero VAD model')\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"# Load WhisperX model\n",
|
| 36 |
+
"whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n",
|
| 37 |
+
"logging.info('Loaded WhisperX model')\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"# OpenAI API Key from environment variable for security\n",
|
| 40 |
+
"OPENAI_API_KEY = \"sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C\" # os.getenv(\"OPENAI_API_KEY\")\n",
|
| 41 |
+
"if not OPENAI_API_KEY:\n",
|
| 42 |
+
" logging.error(\"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.\")\n",
|
| 43 |
+
" raise ValueError(\"OpenAI API key not found.\")\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# Initialize OpenAI client\n",
|
| 46 |
+
"openai_client = OpenAI(api_key=OPENAI_API_KEY)\n",
|
| 47 |
+
"logging.info('Initialized OpenAI client')\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"# TTS Voice\n",
|
| 50 |
+
"TTS_VOICE = \"en-GB-SoniaNeural\"\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"# Function to check voice activity using Silero VAD\n",
|
| 53 |
+
"def check_vad(audio_data, sample_rate):\n",
|
| 54 |
+
" logging.info('Checking voice activity')\n",
|
| 55 |
+
" # Resample to 16000 Hz if necessary\n",
|
| 56 |
+
" target_sample_rate = 16000\n",
|
| 57 |
+
" if sample_rate != target_sample_rate:\n",
|
| 58 |
+
" resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
|
| 59 |
+
" audio_tensor = resampler(torch.from_numpy(audio_data))\n",
|
| 60 |
+
" else:\n",
|
| 61 |
+
" audio_tensor = torch.from_numpy(audio_data)\n",
|
| 62 |
+
" audio_tensor = audio_tensor.to(device)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
" # Log audio data details\n",
|
| 65 |
+
" logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" # Get speech timestamps\n",
|
| 68 |
+
" speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)\n",
|
| 69 |
+
" logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n",
|
| 70 |
+
" return len(speech_timestamps) > 0\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"# Function to transcribe audio using WhisperX\n",
|
| 73 |
+
"def transcript(audio_data, sample_rate):\n",
|
| 74 |
+
" logging.info('Transcribing audio')\n",
|
| 75 |
+
" # Resample to 16000 Hz if necessary\n",
|
| 76 |
+
" target_sample_rate = 16000\n",
|
| 77 |
+
" if sample_rate != target_sample_rate:\n",
|
| 78 |
+
" resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
|
| 79 |
+
" audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n",
|
| 80 |
+
" else:\n",
|
| 81 |
+
" audio_data = audio_data\n",
|
| 82 |
+
"\n",
|
| 83 |
+
" # Transcribe\n",
|
| 84 |
+
" batch_size = 16 # Adjust as needed\n",
|
| 85 |
+
" result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n",
|
| 86 |
+
" text = result[\"segments\"][0][\"text\"] if len(result[\"segments\"]) > 0 else \"\"\n",
|
| 87 |
+
" logging.info(f'Transcription result: {text}')\n",
|
| 88 |
+
" # Clear GPU memory\n",
|
| 89 |
+
" del result\n",
|
| 90 |
+
" gc.collect()\n",
|
| 91 |
+
" if device == 'cuda':\n",
|
| 92 |
+
" torch.cuda.empty_cache()\n",
|
| 93 |
+
" return text\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"# Function to get streaming response from OpenAI API\n",
|
| 96 |
+
"def llm(text):\n",
|
| 97 |
+
" logging.info('Getting response from OpenAI API')\n",
|
| 98 |
+
" response = openai_client.chat.completions.create(\n",
|
| 99 |
+
" model=\"gpt-4o\", # Updated to a more recent model\n",
|
| 100 |
+
" messages=[\n",
|
| 101 |
+
" {\"role\": \"system\", \"content\": \"You respond to the following transcript from the conversation that you are having with the user.\"},\n",
|
| 102 |
+
" {\"role\": \"user\", \"content\": text} \n",
|
| 103 |
+
" ],\n",
|
| 104 |
+
" stream=True,\n",
|
| 105 |
+
" temperature=0.7, # Optional: Adjust as needed\n",
|
| 106 |
+
" top_p=0.9, # Optional: Adjust as needed\n",
|
| 107 |
+
" )\n",
|
| 108 |
+
" for chunk in response:\n",
|
| 109 |
+
" yield chunk.choices[0].delta.content\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"# Function to perform TTS per sentence using Edge-TTS\n",
|
| 112 |
+
"def tts_streaming(text_stream):\n",
|
| 113 |
+
" logging.info('Performing TTS')\n",
|
| 114 |
+
" buffer = \"\"\n",
|
| 115 |
+
" punctuation = {'.', '!', '?'}\n",
|
| 116 |
+
" for text_chunk in text_stream:\n",
|
| 117 |
+
" if text_chunk is not None:\n",
|
| 118 |
+
" buffer += text_chunk\n",
|
| 119 |
+
" # Check for sentence completion\n",
|
| 120 |
+
" sentences = []\n",
|
| 121 |
+
" start = 0\n",
|
| 122 |
+
" for i, char in enumerate(buffer):\n",
|
| 123 |
+
" if (char in punctuation):\n",
|
| 124 |
+
" sentences.append(buffer[start:i+1].strip())\n",
|
| 125 |
+
" start = i+1\n",
|
| 126 |
+
" buffer = buffer[start:]\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" for sentence in sentences:\n",
|
| 129 |
+
" if sentence:\n",
|
| 130 |
+
" communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n",
|
| 131 |
+
" for chunk in communicate.stream_sync():\n",
|
| 132 |
+
" if chunk[\"type\"] == \"audio\":\n",
|
| 133 |
+
" yield chunk[\"data\"]\n",
|
| 134 |
+
" # Process any remaining text\n",
|
| 135 |
+
" if buffer.strip():\n",
|
| 136 |
+
" communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n",
|
| 137 |
+
" for chunk in communicate.stream_sync():\n",
|
| 138 |
+
" if chunk[\"type\"] == \"audio\":\n",
|
| 139 |
+
" yield chunk[\"data\"]\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"# Function to handle LLM and TTS\n",
|
| 142 |
+
"def llm_and_tts(transcribed_text):\n",
|
| 143 |
+
" logging.info('Handling LLM and TTS')\n",
|
| 144 |
+
" # Get streaming response from LLM\n",
|
| 145 |
+
" for text_chunk in llm(transcribed_text):\n",
|
| 146 |
+
" if state.get('stop_signal'):\n",
|
| 147 |
+
" logging.info('LLM and TTS task stopped')\n",
|
| 148 |
+
" break\n",
|
| 149 |
+
" # Get audio data from TTS\n",
|
| 150 |
+
" for audio_chunk in tts_streaming([text_chunk]):\n",
|
| 151 |
+
" if state.get('stop_signal'):\n",
|
| 152 |
+
" logging.info('LLM and TTS task stopped during TTS')\n",
|
| 153 |
+
" break\n",
|
| 154 |
+
" yield np.frombuffer(audio_chunk, dtype=np.int16)\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"state = {\n",
|
| 157 |
+
" 'mode': 'idle',\n",
|
| 158 |
+
" 'chunk_queue': [],\n",
|
| 159 |
+
" 'transcription': '',\n",
|
| 160 |
+
" 'in_transcription': False,\n",
|
| 161 |
+
" 'previous_no_vad_audio': [],\n",
|
| 162 |
+
" 'llm_task': None,\n",
|
| 163 |
+
" 'instream': None,\n",
|
| 164 |
+
" 'stop_signal': False,\n",
|
| 165 |
+
" 'args': {\n",
|
| 166 |
+
" 'sample_rate': 16000,\n",
|
| 167 |
+
" 'chunk_size': 0.5, # seconds\n",
|
| 168 |
+
" 'transcript_chunk_size': 2, # seconds\n",
|
| 169 |
+
" }\n",
|
| 170 |
+
"}\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"def transcript_loop():\n",
|
| 173 |
+
" while True:\n",
|
| 174 |
+
" if len(state['chunk_queue']) > 0:\n",
|
| 175 |
+
" accumulated_audio = np.concatenate(state['chunk_queue'])\n",
|
| 176 |
+
" total_samples = sum(len(chunk) for chunk in state['chunk_queue'])\n",
|
| 177 |
+
" total_duration = total_samples / state['args']['sample_rate']\n",
|
| 178 |
+
" \n",
|
| 179 |
+
" # Run transcription on the first 2 seconds if len > 3 seconds\n",
|
| 180 |
+
" if total_duration > 3.0 and state['in_transcription'] == True:\n",
|
| 181 |
+
" first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])\n",
|
| 182 |
+
" first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n",
|
| 183 |
+
" transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])\n",
|
| 184 |
+
" state['transcription'] += transcribed_text\n",
|
| 185 |
+
" remaining_audio = accumulated_audio[first_two_seconds_samples:]\n",
|
| 186 |
+
" state['chunk_queue'] = [remaining_audio]\n",
|
| 187 |
+
" else: # Run transcription on the accumulated audio\n",
|
| 188 |
+
" transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])\n",
|
| 189 |
+
" state['transcription'] += transcribed_text\n",
|
| 190 |
+
" state['chunk_queue'] = []\n",
|
| 191 |
+
" state['in_transcription'] = False\n",
|
| 192 |
+
" else:\n",
|
| 193 |
+
" time.sleep(0.1)\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):\n",
|
| 196 |
+
" state['in_transcription'] = False\n",
|
| 197 |
+
" break\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"def process_audio(audio_chunk):\n",
|
| 200 |
+
" # returns output audio\n",
|
| 201 |
+
" \n",
|
| 202 |
+
" sample_rate, audio_data = audio_chunk\n",
|
| 203 |
+
" audio_data = np.array(audio_data, dtype=np.float32)\n",
|
| 204 |
+
" \n",
|
| 205 |
+
" # convert to mono if necessary\n",
|
| 206 |
+
" if audio_data.ndim > 1:\n",
|
| 207 |
+
" audio_data = np.mean(audio_data, axis=1)\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" mode = state['mode']\n",
|
| 210 |
+
" chunk_queue = state['chunk_queue']\n",
|
| 211 |
+
" transcription = state['transcription']\n",
|
| 212 |
+
" in_transcription = state['in_transcription']\n",
|
| 213 |
+
" previous_no_vad_audio = state['previous_no_vad_audio']\n",
|
| 214 |
+
" llm_task = state['llm_task']\n",
|
| 215 |
+
" instream = state['instream']\n",
|
| 216 |
+
" stop_signal = state['stop_signal']\n",
|
| 217 |
+
" args = state['args']\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" args['sample_rate'] = sample_rate\n",
|
| 220 |
+
" \n",
|
| 221 |
+
" # check for voice activity\n",
|
| 222 |
+
" vad = check_vad(audio_data, sample_rate)\n",
|
| 223 |
+
" \n",
|
| 224 |
+
" if vad:\n",
|
| 225 |
+
" logging.info(f'Voice activity detected in mode: {mode}')\n",
|
| 226 |
+
" if mode == 'idle':\n",
|
| 227 |
+
" mode = 'listening'\n",
|
| 228 |
+
" elif mode == 'speaking':\n",
|
| 229 |
+
" # Stop llm and tts tasks\n",
|
| 230 |
+
" if llm_task and llm_task.is_alive():\n",
|
| 231 |
+
" # Implement task cancellation logic if possible\n",
|
| 232 |
+
" logging.info('Stopping LLM and TTS tasks')\n",
|
| 233 |
+
" # Since we cannot kill threads directly, we need to handle this in the tasks\n",
|
| 234 |
+
" stop_signal = True\n",
|
| 235 |
+
" llm_task.join()\n",
|
| 236 |
+
" mode = 'listening'\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" if mode == 'listening':\n",
|
| 239 |
+
" if previous_no_vad_audio is not None:\n",
|
| 240 |
+
" chunk_queue.append(previous_no_vad_audio)\n",
|
| 241 |
+
" previous_no_vad_audio = None\n",
|
| 242 |
+
" # Accumulate audio chunks\n",
|
| 243 |
+
" chunk_queue.append(audio_data)\n",
|
| 244 |
+
" \n",
|
| 245 |
+
" # Start transcription thread if not already running\n",
|
| 246 |
+
" if not in_transcription:\n",
|
| 247 |
+
" in_transcription = True\n",
|
| 248 |
+
" transcription_task = threading.Thread(target=transcript_loop)\n",
|
| 249 |
+
" transcription_task.start()\n",
|
| 250 |
+
" \n",
|
| 251 |
+
" elif mode == 'speaking':\n",
|
| 252 |
+
" # Continue accumulating audio chunks\n",
|
| 253 |
+
" chunk_queue.append(audio_data)\n",
|
| 254 |
+
" else:\n",
|
| 255 |
+
" logging.info(f'No voice activity detected in mode: {mode}')\n",
|
| 256 |
+
" if mode == 'listening':\n",
|
| 257 |
+
" # Add the last chunk to queue\n",
|
| 258 |
+
" chunk_queue.append(audio_data)\n",
|
| 259 |
+
" \n",
|
| 260 |
+
" # Change mode to processing\n",
|
| 261 |
+
" mode = 'processing'\n",
|
| 262 |
+
" \n",
|
| 263 |
+
" # Wait for transcription to complete\n",
|
| 264 |
+
" while in_transcription:\n",
|
| 265 |
+
" time.sleep(0.1)\n",
|
| 266 |
+
" \n",
|
| 267 |
+
" # Check if transcription is complete\n",
|
| 268 |
+
" if len(chunk_queue) == 0:\n",
|
| 269 |
+
" # Start LLM and TTS tasks\n",
|
| 270 |
+
" if not llm_task or not llm_task.is_alive():\n",
|
| 271 |
+
" stop_signal = False\n",
|
| 272 |
+
" llm_task = threading.Thread(target=llm_and_tts, args=(transcription))\n",
|
| 273 |
+
" llm_task.start()\n",
|
| 274 |
+
" \n",
|
| 275 |
+
" if mode == 'processing':\n",
|
| 276 |
+
" # Wait for LLM and TTS tasks to start yielding audio\n",
|
| 277 |
+
" if llm_task and llm_task.is_alive():\n",
|
| 278 |
+
" mode = 'responding'\n",
|
| 279 |
+
" \n",
|
| 280 |
+
" if mode == 'responding':\n",
|
| 281 |
+
" for audio_chunk in llm_task:\n",
|
| 282 |
+
" if instream is None:\n",
|
| 283 |
+
" instream = audio_chunk\n",
|
| 284 |
+
" else:\n",
|
| 285 |
+
" instream = np.concatenate((instream, audio_chunk))\n",
|
| 286 |
+
" \n",
|
| 287 |
+
" # Send audio to output stream\n",
|
| 288 |
+
" yield instream\n",
|
| 289 |
+
" \n",
|
| 290 |
+
" # Cleanup\n",
|
| 291 |
+
" llm_task = None\n",
|
| 292 |
+
" transcription = ''\n",
|
| 293 |
+
" mode = 'idle'\n",
|
| 294 |
+
" \n",
|
| 295 |
+
" # Updaate state\n",
|
| 296 |
+
" state['mode'] = mode\n",
|
| 297 |
+
" state['chunk_queue'] = chunk_queue\n",
|
| 298 |
+
" state['transcription'] = transcription\n",
|
| 299 |
+
" state['in_transcription'] = in_transcription\n",
|
| 300 |
+
" state['previous_no_vad_audio'] = previous_no_vad_audio\n",
|
| 301 |
+
" state['llm_task'] = llm_task\n",
|
| 302 |
+
" state['instream'] = instream\n",
|
| 303 |
+
" state['stop_signal'] = stop_signal\n",
|
| 304 |
+
" state['args'] = args\n",
|
| 305 |
+
" \n",
|
| 306 |
+
" # Store previous audio chunk with no voice activity\n",
|
| 307 |
+
" previous_no_vad_audio = audio_data\n",
|
| 308 |
+
" \n",
|
| 309 |
+
" # Update state\n",
|
| 310 |
+
" state['mode'] = mode\n",
|
| 311 |
+
" state['chunk_queue'] = chunk_queue\n",
|
| 312 |
+
" state['transcription'] = transcription\n",
|
| 313 |
+
" state['in_transcription'] = in_transcription\n",
|
| 314 |
+
" state['previous_no_vad_audio'] = previous_no_vad_audio\n",
|
| 315 |
+
" state['llm_task'] = llm_task\n",
|
| 316 |
+
" state['instream'] = instream\n",
|
| 317 |
+
" state['stop_signal'] = stop_signal\n",
|
| 318 |
+
" state['args'] = args"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"cell_type": "code",
|
| 323 |
+
"execution_count": null,
|
| 324 |
+
"metadata": {},
|
| 325 |
+
"outputs": [],
|
| 326 |
+
"source": [
|
| 327 |
+
"# 1. Load audio.mp3\n",
|
| 328 |
+
"# 2. Split audio into chunks\n",
|
| 329 |
+
"# 3. Process each chunk inside a loop\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"# Split audio into chunks of 500 ms or less\n",
|
| 332 |
+
"from pydub import AudioSegment\n",
|
| 333 |
+
"audio_segment = AudioSegment.from_file('audio.mp3')\n",
|
| 334 |
+
"chunks = [chunk for chunk in audio_segment[::500]]\n",
|
| 335 |
+
"chunks[0]\n",
|
| 336 |
+
"chunks = [(chunk.frame_rate, np.array(chunk.get_array_of_samples(), dtype=np.int16)) for chunk in chunks]\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"output_audio = []\n",
|
| 339 |
+
"# Process each chunk\n",
|
| 340 |
+
"for chunk in chunks:\n",
|
| 341 |
+
" for audio_chunk in process_audio(chunk):\n",
|
| 342 |
+
" output_audio.append(audio_chunk)"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"cell_type": "code",
|
| 347 |
+
"execution_count": null,
|
| 348 |
+
"metadata": {},
|
| 349 |
+
"outputs": [],
|
| 350 |
+
"source": [
|
| 351 |
+
"output_audio"
|
| 352 |
+
]
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"cell_type": "code",
|
| 356 |
+
"execution_count": null,
|
| 357 |
+
"metadata": {},
|
| 358 |
+
"outputs": [],
|
| 359 |
+
"source": [
|
| 360 |
+
"import asyncio\n",
|
| 361 |
+
"import websockets\n",
|
| 362 |
+
"from pydub import AudioSegment\n",
|
| 363 |
+
"import numpy as np\n",
|
| 364 |
+
"import simpleaudio as sa\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"# Constants\n",
|
| 367 |
+
"AUDIO_FILE = 'audio.mp3' # Input audio file\n",
|
| 368 |
+
"CHUNK_DURATION_MS = 250 # Duration of each chunk in milliseconds\n",
|
| 369 |
+
"WEBSOCKET_URI = 'ws://localhost:8000/ws' # WebSocket endpoint\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"async def send_audio_chunks(uri):\n",
|
| 372 |
+
" # Load audio file using pydub\n",
|
| 373 |
+
" audio = AudioSegment.from_file(AUDIO_FILE)\n",
|
| 374 |
+
"\n",
|
| 375 |
+
" # Ensure audio is mono and 16kHz\n",
|
| 376 |
+
" if audio.channels > 1:\n",
|
| 377 |
+
" audio = audio.set_channels(1)\n",
|
| 378 |
+
" if audio.frame_rate != 16000:\n",
|
| 379 |
+
" audio = audio.set_frame_rate(16000)\n",
|
| 380 |
+
" if audio.sample_width != 2: # 2 bytes for int16\n",
|
| 381 |
+
" audio = audio.set_sample_width(2)\n",
|
| 382 |
+
"\n",
|
| 383 |
+
" # Split audio into chunks\n",
|
| 384 |
+
" chunks = [audio[i:i+CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" # Store received audio data\n",
|
| 387 |
+
" received_audio_data = b''\n",
|
| 388 |
+
"\n",
|
| 389 |
+
" async with websockets.connect(uri) as websocket:\n",
|
| 390 |
+
" print(\"Connected to server.\")\n",
|
| 391 |
+
" for idx, chunk in enumerate(chunks):\n",
|
| 392 |
+
" # Get raw audio data\n",
|
| 393 |
+
" raw_data = chunk.raw_data\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" # Send audio chunk to server\n",
|
| 396 |
+
" await websocket.send(raw_data)\n",
|
| 397 |
+
" print(f\"Sent chunk {idx+1}/{len(chunks)}\")\n",
|
| 398 |
+
"\n",
|
| 399 |
+
" # Receive response (non-blocking)\n",
|
| 400 |
+
" try:\n",
|
| 401 |
+
" response = await asyncio.wait_for(websocket.recv(), timeout=0.1)\n",
|
| 402 |
+
" if isinstance(response, bytes):\n",
|
| 403 |
+
" received_audio_data += response\n",
|
| 404 |
+
" print(f\"Received audio data of length {len(response)} bytes\")\n",
|
| 405 |
+
" except asyncio.TimeoutError:\n",
|
| 406 |
+
" pass # No response received yet\n",
|
| 407 |
+
"\n",
|
| 408 |
+
" # Simulate real-time by waiting for chunk duration\n",
|
| 409 |
+
" await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)\n",
|
| 410 |
+
"\n",
|
| 411 |
+
" # Send a final empty message to indicate end of transmission\n",
|
| 412 |
+
" await websocket.send(b'')\n",
|
| 413 |
+
" print(\"Finished sending audio. Waiting for responses...\")\n",
|
| 414 |
+
"\n",
|
| 415 |
+
" # Receive any remaining responses\n",
|
| 416 |
+
" while True:\n",
|
| 417 |
+
" try:\n",
|
| 418 |
+
" response = await asyncio.wait_for(websocket.recv(), timeout=1)\n",
|
| 419 |
+
" if isinstance(response, bytes):\n",
|
| 420 |
+
" received_audio_data += response\n",
|
| 421 |
+
" print(f\"Received audio data of length {len(response)} bytes\")\n",
|
| 422 |
+
" except asyncio.TimeoutError:\n",
|
| 423 |
+
" print(\"No more responses. Closing connection.\")\n",
|
| 424 |
+
" break\n",
|
| 425 |
+
"\n",
|
| 426 |
+
" print(\"Connection closed.\")\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" # Save received audio data to a file or play it\n",
|
| 429 |
+
" if received_audio_data:\n",
|
| 430 |
+
" # Convert bytes to numpy array\n",
|
| 431 |
+
" audio_array = np.frombuffer(received_audio_data, dtype=np.int16)\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" # Play audio using simpleaudio\n",
|
| 434 |
+
" play_obj = sa.play_buffer(audio_array, 1, 2, 16000)\n",
|
| 435 |
+
" play_obj.wait_done()\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" # Optionally, save to a WAV file\n",
|
| 438 |
+
" output_audio = AudioSegment(\n",
|
| 439 |
+
" data=received_audio_data,\n",
|
| 440 |
+
" sample_width=2, # 2 bytes for int16\n",
|
| 441 |
+
" frame_rate=16000,\n",
|
| 442 |
+
" channels=1\n",
|
| 443 |
+
" )\n",
|
| 444 |
+
" output_audio.export(\"output_response.wav\", format=\"wav\")\n",
|
| 445 |
+
" print(\"Saved response audio to 'output_response.wav'\")\n",
|
| 446 |
+
" else:\n",
|
| 447 |
+
" print(\"No audio data received.\")\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"def main():\n",
|
| 450 |
+
" asyncio.run(send_audio_chunks(WEBSOCKET_URI))\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"if __name__ == '__main__':\n",
|
| 453 |
+
" main()"
|
| 454 |
+
]
|
| 455 |
+
}
|
| 456 |
+
],
|
| 457 |
+
"metadata": {
|
| 458 |
+
"kernelspec": {
|
| 459 |
+
"display_name": ".venv",
|
| 460 |
+
"language": "python",
|
| 461 |
+
"name": "python3"
|
| 462 |
+
},
|
| 463 |
+
"language_info": {
|
| 464 |
+
"codemirror_mode": {
|
| 465 |
+
"name": "ipython",
|
| 466 |
+
"version": 3
|
| 467 |
+
},
|
| 468 |
+
"file_extension": ".py",
|
| 469 |
+
"mimetype": "text/x-python",
|
| 470 |
+
"name": "python",
|
| 471 |
+
"nbconvert_exporter": "python",
|
| 472 |
+
"pygments_lexer": "ipython3",
|
| 473 |
+
"version": "3.10.12"
|
| 474 |
+
}
|
| 475 |
+
},
|
| 476 |
+
"nbformat": 4,
|
| 477 |
+
"nbformat_minor": 2
|
| 478 |
+
}
|
playground/testapp/test.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fastapi
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from silero_vad import get_speech_timestamps, load_silero_vad
|
| 6 |
+
import whisperx
|
| 7 |
+
import edge_tts
|
| 8 |
+
import gc
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
import os
|
| 12 |
+
from openai import OpenAI
|
| 13 |
+
import asyncio
|
| 14 |
+
from pydub import AudioSegment
|
| 15 |
+
from io import BytesIO
|
| 16 |
+
import threading
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 20 |
+
|
| 21 |
+
# Configure FastAPI
|
| 22 |
+
app = fastapi.FastAPI()
|
| 23 |
+
|
| 24 |
+
# Load Silero VAD model
|
| 25 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 26 |
+
logging.info(f'Using device: {device}')
|
| 27 |
+
vad_model = load_silero_vad().to(device)
|
| 28 |
+
logging.info('Loaded Silero VAD model')
|
| 29 |
+
|
| 30 |
+
# Load WhisperX model
|
| 31 |
+
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
|
| 32 |
+
logging.info('Loaded WhisperX model')
|
| 33 |
+
|
| 34 |
+
OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
|
| 35 |
+
if not OPENAI_API_KEY:
|
| 36 |
+
logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
|
| 37 |
+
raise ValueError("OpenAI API key not found.")
|
| 38 |
+
logging.info('Initialized OpenAI client')
|
| 39 |
+
llm_client = OpenAI(api_key=OPENAI_API_KEY) # Corrected import
|
| 40 |
+
|
| 41 |
+
# TTS Voice
|
| 42 |
+
TTS_VOICE = "en-GB-SoniaNeural"
|
| 43 |
+
|
| 44 |
+
# Function to check voice activity using Silero VAD
|
| 45 |
+
def check_vad(audio_data, sample_rate):
|
| 46 |
+
logging.info('Checking voice activity')
|
| 47 |
+
target_sample_rate = 16000
|
| 48 |
+
if sample_rate != target_sample_rate:
|
| 49 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 50 |
+
audio_tensor = resampler(torch.from_numpy(audio_data))
|
| 51 |
+
else:
|
| 52 |
+
audio_tensor = torch.from_numpy(audio_data)
|
| 53 |
+
audio_tensor = audio_tensor.to(device)
|
| 54 |
+
|
| 55 |
+
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
|
| 56 |
+
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
|
| 57 |
+
return len(speech_timestamps) > 0
|
| 58 |
+
|
| 59 |
+
# Async function to transcribe audio using WhisperX
|
| 60 |
+
def transcribe(audio_data, sample_rate):
|
| 61 |
+
logging.info('Transcribing audio')
|
| 62 |
+
target_sample_rate = 16000
|
| 63 |
+
if sample_rate != target_sample_rate:
|
| 64 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
| 65 |
+
audio_data = resampler(torch.from_numpy(audio_data)).numpy()
|
| 66 |
+
else:
|
| 67 |
+
audio_data = audio_data
|
| 68 |
+
|
| 69 |
+
batch_size = 16 # Adjust as needed
|
| 70 |
+
result = whisper_model.transcribe(audio_data, batch_size=batch_size)
|
| 71 |
+
text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
|
| 72 |
+
logging.info(f'Transcription result: {text}')
|
| 73 |
+
del result
|
| 74 |
+
gc.collect()
|
| 75 |
+
if device == 'cuda':
|
| 76 |
+
torch.cuda.empty_cache()
|
| 77 |
+
return text
|
| 78 |
+
|
| 79 |
+
# Function to convert text to speech using Edge TTS and stream the audio
|
| 80 |
+
def tts_streaming(text_stream):
|
| 81 |
+
logging.info('Performing TTS')
|
| 82 |
+
buffer = ""
|
| 83 |
+
punctuation = {'.', '!', '?'}
|
| 84 |
+
for text_chunk in text_stream:
|
| 85 |
+
if text_chunk is not None:
|
| 86 |
+
buffer += text_chunk
|
| 87 |
+
# Check for sentence completion
|
| 88 |
+
sentences = []
|
| 89 |
+
start = 0
|
| 90 |
+
for i, char in enumerate(buffer):
|
| 91 |
+
if char in punctuation:
|
| 92 |
+
sentences.append(buffer[start:i+1].strip())
|
| 93 |
+
start = i+1
|
| 94 |
+
buffer = buffer[start:]
|
| 95 |
+
|
| 96 |
+
for sentence in sentences:
|
| 97 |
+
if sentence:
|
| 98 |
+
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
|
| 99 |
+
for chunk in communicate.stream_sync():
|
| 100 |
+
if chunk["type"] == "audio":
|
| 101 |
+
yield chunk["data"]
|
| 102 |
+
# Process any remaining text
|
| 103 |
+
if buffer.strip():
|
| 104 |
+
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
|
| 105 |
+
for chunk in communicate.stream_sync():
|
| 106 |
+
if chunk["type"] == "audio":
|
| 107 |
+
yield chunk["data"]
|
| 108 |
+
|
| 109 |
+
# Function to perform language model completion using OpenAI API
|
| 110 |
+
def llm(text):
|
| 111 |
+
logging.info('Getting response from OpenAI API')
|
| 112 |
+
response = llm_client.chat.completions.create(
|
| 113 |
+
model="gpt-4o", # Updated to a more recent model
|
| 114 |
+
messages=[
|
| 115 |
+
{"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
|
| 116 |
+
{"role": "user", "content": text}
|
| 117 |
+
],
|
| 118 |
+
stream=True,
|
| 119 |
+
temperature=0.7,
|
| 120 |
+
top_p=0.9
|
| 121 |
+
)
|
| 122 |
+
for chunk in response:
|
| 123 |
+
yield chunk.choices[0].delta.content
|
| 124 |
+
|
| 125 |
+
class Conversation:
|
| 126 |
+
def __init__(self):
|
| 127 |
+
self.mode = 'idle' # idle, listening, speaking
|
| 128 |
+
self.audio_stream = []
|
| 129 |
+
self.valid_chunk_queue = []
|
| 130 |
+
self.first_valid_chunk = None
|
| 131 |
+
self.last_valid_chunks = []
|
| 132 |
+
self.valid_chunk_transcriptions = ''
|
| 133 |
+
self.in_transcription = False
|
| 134 |
+
self.llm_n_tts_task = None
|
| 135 |
+
self.stop_signal = False
|
| 136 |
+
self.sample_rate = 0
|
| 137 |
+
self.out_audio_stream = []
|
| 138 |
+
self.chunk_buffer = 0.5 # seconds
|
| 139 |
+
|
| 140 |
+
def llm_n_tts(self):
|
| 141 |
+
for text_chunk in llm(self.transcription):
|
| 142 |
+
if self.stop_signal:
|
| 143 |
+
break
|
| 144 |
+
for audio_chunk in tts_streaming([text_chunk]):
|
| 145 |
+
if self.stop_signal:
|
| 146 |
+
break
|
| 147 |
+
self.out_audio_stream.append(np.frombuffer(audio_chunk, dtype=np.int16))
|
| 148 |
+
|
| 149 |
+
def process_audio_chunk(self, audio_chunk):
|
| 150 |
+
# Construct audio stream
|
| 151 |
+
audio_data = AudioSegment.from_file(BytesIO(audio_chunk), format="wav")
|
| 152 |
+
audio_data = np.array(audio_data.get_array_of_samples())
|
| 153 |
+
self.sample_rate = audio_data.frame_rate
|
| 154 |
+
|
| 155 |
+
# Check for voice activity
|
| 156 |
+
vad = check_vad(audio_data, self.sample_rate)
|
| 157 |
+
|
| 158 |
+
if vad: # Voice activity detected
|
| 159 |
+
if self.first_valid_chunk is not None:
|
| 160 |
+
self.valid_chunk_queue.append(self.first_valid_chunk)
|
| 161 |
+
self.first_valid_chunk = None
|
| 162 |
+
self.valid_chunk_queue.append(audio_chunk)
|
| 163 |
+
|
| 164 |
+
if len(self.valid_chunk_queue) > 2:
|
| 165 |
+
# i.e. 3 chunks: 1 non valid chunk + 2 valid chunks
|
| 166 |
+
# this is to ensure that the speaker is speaking
|
| 167 |
+
if self.mode == 'idle':
|
| 168 |
+
self.mode = 'listening'
|
| 169 |
+
elif self.mode == 'speaking':
|
| 170 |
+
# Stop llm and tts
|
| 171 |
+
if self.llm_n_tts_task is not None:
|
| 172 |
+
self.stop_signal = True
|
| 173 |
+
self.llm_n_tts_task
|
| 174 |
+
self.stop_signal = False
|
| 175 |
+
self.mode = 'listening'
|
| 176 |
+
|
| 177 |
+
else: # No voice activity
|
| 178 |
+
if self.mode == 'listening':
|
| 179 |
+
self.last_valid_chunks.append(audio_chunk)
|
| 180 |
+
|
| 181 |
+
if len(self.last_valid_chunks) > 2:
|
| 182 |
+
# i.e. 2 chunks where the speaker stopped speaking, but we account for natural pauses
|
| 183 |
+
# so on the 1.5th second of no voice activity, we append the first 2 of the last valid chunks to the valid chunk queue
|
| 184 |
+
# stop listening and start speaking
|
| 185 |
+
self.valid_chunk_queue.extend(self.last_valid_chunks[:2])
|
| 186 |
+
self.last_valid_chunks = []
|
| 187 |
+
|
| 188 |
+
while len(self.valid_chunk_queue) > 0:
|
| 189 |
+
time.sleep(0.1)
|
| 190 |
+
|
| 191 |
+
self.mode = 'speaking'
|
| 192 |
+
self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts)
|
| 193 |
+
self.llm_n_tts_task.start()
|
| 194 |
+
|
| 195 |
+
def transcribe_loop(self):
|
| 196 |
+
while True:
|
| 197 |
+
if self.mode == 'listening':
|
| 198 |
+
if len(self.valid_chunk_queue) > 0:
|
| 199 |
+
accumulated_chunks = np.concatenate(self.valid_chunk_queue)
|
| 200 |
+
total_duration = len(accumulated_chunks) / self.sample_rate
|
| 201 |
+
|
| 202 |
+
if total_duration >= 3.0 and self.in_transcription == True:
|
| 203 |
+
# i.e. we have at least 3 seconds of audio so we can start transcribing to reduce latency
|
| 204 |
+
first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)]
|
| 205 |
+
transcribed_text = transcribe(first_2s_audio, self.sample_rate)
|
| 206 |
+
self.valid_chunk_transcriptions += transcribed_text
|
| 207 |
+
self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]]
|
| 208 |
+
|
| 209 |
+
if self.mode == any(['idle', 'speaking']):
|
| 210 |
+
# i.e. the request to stop transcription has been made
|
| 211 |
+
# so process the remaining audio
|
| 212 |
+
transcribed_text = transcribe(accumulated_chunks, self.sample_rate)
|
| 213 |
+
self.valid_chunk_transcriptions += transcribed_text
|
| 214 |
+
self.valid_chunk_queue = []
|
| 215 |
+
else:
|
| 216 |
+
time.sleep(0.1)
|
| 217 |
+
|
| 218 |
+
def stream_out_audio(self):
|
| 219 |
+
while True:
|
| 220 |
+
if len(self.out_audio_stream) > 0:
|
| 221 |
+
yield AudioSegment(data=self.out_audio_stream.pop(0), sample_width=2, frame_rate=self.sample_rate, channels=1).raw_data
|
| 222 |
+
|
| 223 |
+
@app.websocket("/ws")
|
| 224 |
+
async def websocket_endpoint(websocket: fastapi.WebSocket):
|
| 225 |
+
# Accept connection
|
| 226 |
+
await websocket.accept()
|
| 227 |
+
|
| 228 |
+
# Initialize conversation
|
| 229 |
+
conversation = Conversation()
|
| 230 |
+
|
| 231 |
+
# Start conversation threads
|
| 232 |
+
transcribe_thread = threading.Thread(target=conversation.transcribe_loop)
|
| 233 |
+
transcribe_thread.start()
|
| 234 |
+
|
| 235 |
+
# Process audio chunks
|
| 236 |
+
chunk_buffer_size = conversation.chunk_buffer
|
| 237 |
+
while True:
|
| 238 |
+
try:
|
| 239 |
+
audio_chunk = await websocket.receive_bytes()
|
| 240 |
+
conversation.process_audio_chunk(audio_chunk)
|
| 241 |
+
|
| 242 |
+
if conversation.mode == 'speaking':
|
| 243 |
+
for audio_chunk in conversation.stream_out_audio():
|
| 244 |
+
await websocket.send_bytes(audio_chunk)
|
| 245 |
+
else:
|
| 246 |
+
await websocket.send_bytes(b'')
|
| 247 |
+
except Exception as e:
|
| 248 |
+
logging.error(e)
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
+
@app.get("/")
|
| 252 |
+
async def index():
|
| 253 |
+
return fastapi.responses.FileResponse("index.html")
|
| 254 |
+
|
| 255 |
+
if __name__ == '__main__':
|
| 256 |
+
import uvicorn
|
| 257 |
+
uvicorn.run(app, host='0.0.0.0', port=8000)
|