voice-agent-examples / web-ui /src /hooks /useACEController.ts
fciannella's picture
Added the healthcare example
2f49513
raw
history blame
4.4 kB
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { useRef, useState } from "react";
import logger from "../utils/logger";
import extractWavSampleRate from "../utils/extractWavSampleRate";
import pcmToWav from "../utils/pcmToWav";
/** Lifecycle states of the ACE Controller websocket as exposed by the hook. */
type ConnectionStatus = "disconnected" | "connected" | "connecting";

/** Configuration and event callbacks accepted by `useACEController`. */
interface Params {
  /** WebSocket URL the hook connects to. */
  url: string;
  /** Called when the hook detects a connection problem. */
  onError: (error: Error) => void;
  /** Receives each binary frame after it has been decoded into an `AudioBuffer`. */
  onAudioChunk: (chunk: AudioBuffer) => void;
  /** Receives the `tts` text of every `"tts_update"` message. */
  onTTS: (transcript: string) => void;
  /** Receives the `asr` text of every `"asr_update"` message. */
  onASR: (transcript: string) => void;
}

/**
 * Sends one chunk of raw audio samples to the server. The hook wraps the
 * chunk in a WAV container (via `pcmToWav`) before sending it.
 */
type SendAudioChunk = (
  chunk: ArrayBuffer,
  sampleRate: number,
  numChannels: number
) => void;

/** Values returned by `useACEController`. */
interface Output {
  connectionStatus: ConnectionStatus;
  connect: () => void;
  sendAudioChunk: SendAudioChunk;
}

/** JSON frame with `type: "tts_update"`; its `tts` text is forwarded to `onTTS`. */
type TTSMessage = { type: "tts_update"; tts: string };

/** JSON frame with `type: "asr_update"`; its `asr` text is forwarded to `onASR`. */
type ASRMessage = { type: "asr_update"; asr: string };
/**
 * React hook that manages a WebSocket session with an ACE Controller.
 *
 * Binary frames from the server are decoded as WAV audio with an
 * `AudioContext` whose sample rate matches the WAV header, then passed to
 * `params.onAudioChunk`. Text frames are parsed as JSON and dispatched on
 * their `type` field: `"tts_update"` goes to `params.onTTS` and
 * `"asr_update"` goes to `params.onASR`.
 *
 * @param params - Server URL and event callbacks.
 * @returns The current connection status; `connect`, which opens a socket
 *   (a no-op while one is already tracked); and `sendAudioChunk`, which wraps
 *   raw audio in WAV via `pcmToWav` and sends it only while the socket is open.
 */
export default function useACEController(params: Params): Output {
  // `| null` yields mutable ref objects under both React 18 and React 19
  // typings. `useRef<T>(null)` has a read-only `current` in React 18's types.
  const websocketRef = useRef<WebSocket | null>(null);
  const audioCtxRef = useRef<AudioContext | null>(null);
  const [connectionStatus, setConnectionStatus] =
    useState<ConnectionStatus>("disconnected");

  /** Reports `error` to the caller and tears down the tracked socket. */
  function onError(error: Error) {
    setConnectionStatus("disconnected");
    params.onError(error);
    websocketRef.current?.close(1000);
    // Untracking the socket means onClose ignores its later close event,
    // so the same failure is not reported twice.
    websocketRef.current = null;
  }

  function onOpen() {
    setConnectionStatus("connected");
  }

  /**
   * Handles a socket's close event. An unclean close is reported through
   * onError. A clean close handshake (e.g. the one started on page unload)
   * only resets the state.
   */
  function onClose(e: CloseEvent): void {
    // Ignore sockets that are no longer tracked. Either onError already
    // handled them, or a newer socket from connect() has replaced them and
    // must not be clobbered.
    if (e.target !== websocketRef.current) {
      return;
    }
    if (!e.wasClean) {
      onError(new Error("Websocket closed unexpectedly"));
    }
    setConnectionStatus("disconnected");
    websocketRef.current = null;
  }

  /** Closes the tracked socket with code 1000 when the page is unloading. */
  function onWindowUnload() {
    setConnectionStatus("disconnected");
    websocketRef.current?.close(1000);
  }

  /**
   * Decodes one WAV-encoded binary frame and forwards the resulting
   * AudioBuffer. A chunk that fails to decode is logged and dropped.
   */
  async function handleAudioMessage(data: ArrayBuffer) {
    try {
      const sampleRate = extractWavSampleRate(data);
      // Reuse the context unless the incoming audio uses a different rate.
      // NOTE(review): the replaced AudioContext is never closed; consider
      // closing it once no decode on it can still be in flight.
      if (
        !audioCtxRef.current ||
        audioCtxRef.current.sampleRate !== sampleRate
      ) {
        audioCtxRef.current = new AudioContext({ sampleRate });
      }
      const audioBuffer = await audioCtxRef.current.decodeAudioData(data);
      params.onAudioChunk(audioBuffer);
    } catch (error) {
      logger.warn("Error decoding audio chunk. The chunk was discarded", error);
    }
  }

  function handleTTSUpdate(data: TTSMessage) {
    params.onTTS(data.tts);
  }

  function handleASRUpdate(data: ASRMessage) {
    params.onASR(data.asr);
  }

  /**
   * Dispatches one incoming frame. Binary frames are treated as audio. Text
   * frames are parsed as JSON, and frames that fail to parse or carry an
   * unknown `type` are logged and discarded.
   */
  function onMessage(event: MessageEvent): void {
    if (event.data instanceof ArrayBuffer) {
      // handleAudioMessage catches its own errors, so it is deliberately not awaited.
      void handleAudioMessage(event.data);
      return;
    }
    let data: { type?: unknown } | null;
    try {
      data = JSON.parse(event.data);
    } catch (error) {
      // Previously a malformed frame threw out of the message handler and
      // surfaced as an unhandled rejection.
      logger.warn("Malformed message. Discarded", error);
      return;
    }
    // Optional chaining keeps a JSON `null` payload from throwing.
    switch (data?.type) {
      case "tts_update":
        handleTTSUpdate(data as TTSMessage);
        break;
      case "asr_update":
        handleASRUpdate(data as ASRMessage);
        break;
      default:
        logger.warn("Unrecognized message. Discarded", data);
    }
  }

  /** Opens a new socket to `params.url` unless one is already tracked. */
  function connect() {
    if (websocketRef.current) {
      return;
    }
    setConnectionStatus("connecting");
    let ws: WebSocket;
    try {
      ws = new WebSocket(params.url);
    } catch (error) {
      // The constructor throws synchronously (e.g. on an invalid URL). Without
      // this, the status would stay "connecting" forever.
      onError(error instanceof Error ? error : new Error(String(error)));
      return;
    }
    websocketRef.current = ws;
    ws.binaryType = "arraybuffer";
    ws.onmessage = onMessage;
    ws.onopen = onOpen;
    ws.onclose = onClose;
    ws.onerror = () => {
      // An error from a socket that is no longer tracked must not tear down
      // its replacement.
      if (websocketRef.current !== ws) {
        return;
      }
      onError(
        new Error(
          `Failed to establish a websocket connection. Is the ACE Controller running at ${params.url}?`
        )
      );
    };
    window.onbeforeunload = onWindowUnload;
  }

  /** Sends a chunk of raw audio as WAV, or drops it if the socket is not open. */
  function sendAudioChunk(
    chunk: ArrayBuffer,
    sampleRate: number,
    numChannels: number
  ): void {
    if (websocketRef.current?.readyState !== WebSocket.OPEN) {
      logger.warn("Websocket is not open. Discarding audio chunk");
      return;
    }
    websocketRef.current.send(pcmToWav(chunk, sampleRate, numChannels));
  }

  return { connectionStatus, connect, sendAudioChunk };
}