fciannella's picture
Working with service run on 7860
53ea588
raw
history blame
17.4 kB
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""OpenAI Realtime speech services (ASR/TTS).
Implements streaming Speech-to-Text (ASR) over OpenAI's Realtime WebSocket API
using server-side VAD and input transcription events, plus Text-to-Speech (TTS)
over OpenAI's REST streaming speech API. The Realtime TTS class is a stub.
"""
from __future__ import annotations
import asyncio
import audioop
import base64
import json
import os
from collections.abc import AsyncGenerator
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
ErrorFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.services.stt_service import STTService
from pipecat.services.tts_service import TTSService
from pipecat.utils.time import time_now_iso8601
from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced
def _get_env(name: str, default: str | None = None) -> str | None:
    """Look up environment variable ``name``.

    Returns the variable's value when it is set (an empty string counts as
    set and is returned unchanged); otherwise returns ``default``.
    """
    return os.environ.get(name, default)
@traceable
class OpenAIRealtimeASRService(STTService):
    """OpenAI Realtime ASR (Speech-to-Text) service.

    Streams PCM16 audio chunks to the OpenAI Realtime API over a WebSocket and
    emits interim and final transcription frames. Incoming audio is taken to be
    at ``sample_rate`` (16 kHz by default) and is resampled to the 24 kHz that
    the API's ``pcm16`` input format expects before it is sent.

    Lifecycle: ``start`` launches a background task that owns the WebSocket.
    That task spawns a sender (audio queue -> ``input_audio_buffer.append``
    events) and a receiver (server events -> pipeline frames). ``stop`` and
    ``cancel`` tear all three tasks down.
    """

    # Sample rate (Hz) required by the Realtime API for "pcm16" input audio.
    OPENAI_INPUT_SAMPLE_RATE = 24000

    def __init__(
        self,
        *,
        api_key: str | None = None,
        realtime_model: str | None = None,
        stt_model: str | None = None,
        sample_rate: int = 16000,
        audio_channel_count: int = 1,
        server_vad: bool = True,
        vad_threshold: float | None = 0.5,
        vad_prefix_padding_ms: int | None = 300,
        vad_silence_duration_ms: int | None = 600,
        ws_ping_interval: int = 20,
        ws_ping_timeout: int = 20,
        **kwargs,
    ):
        """Store configuration and initialise runtime state (no network I/O).

        Args:
            api_key: OpenAI API key; falls back to ``OPENAI_API_KEY``.
            realtime_model: Model named in the WebSocket URL; falls back to
                ``OPENAI_REALTIME_MODEL``, then ``gpt-4o-realtime-preview``.
            stt_model: Transcription model requested in the session config;
                falls back to ``OPENAI_STT_MODEL``, then ``gpt-4o-transcribe``.
            sample_rate: Sample rate (Hz) of the audio passed to ``run_stt``.
            audio_channel_count: Channel count of that audio (used when
                resampling).
            server_vad: Request server-side voice activity detection.
            vad_threshold: Optional server VAD threshold; ``None`` omits it.
            vad_prefix_padding_ms: Optional server VAD prefix padding;
                ``None`` omits it.
            vad_silence_duration_ms: Optional server VAD silence duration;
                ``None`` omits it.
            ws_ping_interval: ``ping_interval`` passed to ``websockets.connect``.
            ws_ping_timeout: ``ping_timeout`` passed to ``websockets.connect``.
            **kwargs: Forwarded to ``STTService``.
        """
        super().__init__(**kwargs)
        # Configuration
        self._api_key = api_key or _get_env("OPENAI_API_KEY")
        if not self._api_key:
            logger.warning("OpenAIRealtimeASRService: OPENAI_API_KEY not set; service will fail to connect.")
        # The trailing literal covers an env var that is set but empty.
        self._realtime_model = (
            realtime_model
            or _get_env("OPENAI_REALTIME_MODEL", "gpt-4o-realtime-preview")
            or "gpt-4o-realtime-preview"
        )
        self._stt_model = stt_model or _get_env("OPENAI_STT_MODEL", "gpt-4o-transcribe") or "gpt-4o-transcribe"
        # NOTE(review): the STTService base class may also manage `_sample_rate`
        # (e.g. from the StartFrame) -- confirm which value wins at runtime.
        self._sample_rate = sample_rate
        self._audio_channel_count = audio_channel_count
        self._server_vad = server_vad
        self._vad_threshold = vad_threshold
        self._vad_prefix_padding_ms = vad_prefix_padding_ms
        self._vad_silence_duration_ms = vad_silence_duration_ms
        self._ws_ping_interval = ws_ping_interval
        self._ws_ping_timeout = ws_ping_timeout
        # Runtime state
        self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        self._ws = None
        self._ws_task: asyncio.Task | None = None
        self._sender_task: asyncio.Task | None = None
        self._receiver_task: asyncio.Task | None = None
        self._partial_buffer: str = ""
        # audioop.ratecv filter state, carried between chunks of one connection.
        self._ratecv_state = None
        # Construct WS URL and headers
        self._ws_url = f"wss://api.openai.com/v1/realtime?model={self._realtime_model}"
        self._ws_headers = [
            ("Authorization", f"Bearer {self._api_key}"),
            ("OpenAI-Beta", "realtime=v1"),
        ]

    def can_generate_metrics(self) -> bool:
        """Report that this service does not produce metrics."""
        return False

    async def start(self, frame: StartFrame):
        """Start the service and launch the WebSocket task if none is running."""
        await super().start(frame)
        if self._ws_task is None or self._ws_task.done():
            self._ws_task = self.create_task(self._ws_main())

    async def stop(self, frame: EndFrame):
        """Stop the service and tear down the background tasks."""
        await super().stop(frame)
        await self._shutdown()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and tear down the background tasks."""
        await super().cancel(frame)
        await self._shutdown()

    async def _shutdown(self):
        """Cancel sender and receiver first, then the task that owns the socket."""
        if self._sender_task and not self._sender_task.done():
            await self.cancel_task(self._sender_task)
        if self._receiver_task and not self._receiver_task.done():
            await self.cancel_task(self._receiver_task)
        self._sender_task = None
        self._receiver_task = None
        # Cancelling _ws_task exits its `async with`, which closes the websocket.
        if self._ws_task and not self._ws_task.done():
            await self.cancel_task(self._ws_task)
        self._ws_task = None

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Queue ``audio`` for the sender task.

        Transcription frames are pushed asynchronously by the receiver task,
        so this generator only yields a single ``None``.
        """
        if self._ws_task is None or self._ws_task.done():
            # Nothing will ever drain the queue (not started, or the connection
            # task has ended), so drop the audio instead of growing without bound.
            pass
        else:
            try:
                await self._audio_queue.put(audio)
            except Exception as e:
                logger.error(f"{self.name} failed to enqueue audio: {e}")
        yield None

    async def _ws_main(self):
        """Own the websocket connection and supervise the sender/receiver tasks.

        Returns as soon as either worker finishes (for example when the server
        closes the connection), cancelling the other one.
        """
        try:
            # Late import to avoid hard dependency errors during initialization
            import websockets

            connect_kwargs = {
                "extra_headers": self._ws_headers,
                "ping_interval": self._ws_ping_interval,
                "ping_timeout": self._ws_ping_timeout,
                "max_size": 10_000_000,
            }
            logger.info(f"Connecting to OpenAI Realtime WS model={self._realtime_model}")
            async with websockets.connect(self._ws_url, **connect_kwargs) as ws:
                self._ws = ws
                # A new connection is a new audio stream: restart the resampler.
                self._ratecv_state = None
                await self._configure_session()
                self._sender_task = self.create_task(self._ws_sender())
                self._receiver_task = self.create_task(self._ws_receiver())
                workers = {self._sender_task, self._receiver_task}
                try:
                    # FIRST_COMPLETED, not FIRST_EXCEPTION: the receiver ends
                    # *normally* on a clean close while the sender loops forever,
                    # so waiting for an exception would hang here indefinitely.
                    done, _ = await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Also runs when this task itself is cancelled mid-wait.
                    for worker in workers:
                        if not worker.done():
                            worker.cancel()
                for worker in done:
                    if worker.cancelled():
                        continue
                    exc = worker.exception()
                    if exc is not None:
                        logger.error(f"OpenAIRealtimeASRService websocket worker failed: {exc}")
        except asyncio.CancelledError:
            logger.debug("OpenAIRealtimeASRService websocket task cancelled")
            raise
        except Exception as e:
            logger.error(f"OpenAIRealtimeASRService websocket error: {e}")
        finally:
            self._ws = None
            self._sender_task = None
            self._receiver_task = None

    async def _configure_session(self):
        """Send the ``session.update`` event that selects audio format, VAD and STT model."""
        if not self._ws:
            return
        turn_detection: dict | None = None
        if self._server_vad:
            turn_detection = {
                "type": "server_vad",
                "create_response": False,
            }
            # Inject VAD knob values only when provided
            if self._vad_threshold is not None:
                turn_detection["threshold"] = float(self._vad_threshold)
            if self._vad_prefix_padding_ms is not None:
                turn_detection["prefix_padding_ms"] = int(self._vad_prefix_padding_ms)
            if self._vad_silence_duration_ms is not None:
                turn_detection["silence_duration_ms"] = int(self._vad_silence_duration_ms)
        payload: dict = {
            "type": "session.update",
            "session": {
                "input_audio_format": "pcm16",
                # Serialised as JSON null when server VAD is disabled.
                "turn_detection": turn_detection,
                "input_audio_transcription": {
                    "model": self._stt_model,
                },
            },
        }
        await self._ws.send(json.dumps(payload))

    def _to_api_rate(self, chunk: bytes) -> bytes:
        """Resample a PCM16 chunk from the input rate to the API's 24 kHz.

        Returns ``chunk`` unchanged when no conversion is needed or when the
        conversion fails (e.g. the chunk is not a whole number of frames).
        """
        src_rate = self._sample_rate
        if not src_rate or src_rate == self.OPENAI_INPUT_SAMPLE_RATE:
            return chunk
        try:
            converted, self._ratecv_state = audioop.ratecv(
                chunk,
                2,  # bytes per sample (PCM16)
                self._audio_channel_count,
                src_rate,
                self.OPENAI_INPUT_SAMPLE_RATE,
                self._ratecv_state,
            )
        except audioop.error as e:
            logger.warning(f"ASR resample failed, sending chunk at {src_rate} Hz: {e}")
            return chunk
        return converted

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="asr")
    async def _ws_sender(self):
        """Read PCM16 bytes from the queue and send append events to the server.

        Runs until cancelled. A failed send is logged and the chunk is dropped.
        """
        while True:
            chunk: bytes = await self._audio_queue.get()
            ws = self._ws
            if ws is None or not chunk:
                continue
            try:
                b64 = base64.b64encode(self._to_api_rate(chunk)).decode("ascii")
                await ws.send(
                    json.dumps({
                        "type": "input_audio_buffer.append",
                        "audio": b64,
                    })
                )
            except Exception as e:
                logger.warning(f"Failed to send audio chunk: {e}")

    async def _push_interim(self, text: str):
        """Record ``text`` as the current partial transcript and push it if it changed.

        The buffer is always updated, so whitespace-only deltas are not lost.
        The frame is only pushed when the text differs beyond trailing
        whitespace, which avoids emitting duplicate interim frames.
        """
        if not text:
            return
        changed = self._partial_buffer.rstrip() != text.rstrip()
        self._partial_buffer = text
        if not changed:
            return
        await self.push_frame(
            InterimTranscriptionFrame(text, "", time_now_iso8601(), None)
        )

    async def _push_final(self, text: str):
        """Clear the partial transcript and push ``text`` as a final transcription.

        The buffer is cleared even when ``text`` is empty so stale partial text
        cannot leak into the next utterance; no frame is pushed in that case.
        """
        self._partial_buffer = ""
        if not text:
            return
        await self.push_frame(
            TranscriptionFrame(text, "", time_now_iso8601(), None)
        )

    async def _ws_receiver(self):
        """Receive server events and translate them into pipeline frames.

        Handles transcription deltas/completions, error events, and the
        speech-started / buffer-committed events. Runs until the websocket
        iteration ends or the task is cancelled.
        """
        async for raw in self._ws:
            try:
                event = json.loads(raw)
            except (ValueError, TypeError):
                # Not JSON (or not text/bytes): ignore the message.
                continue
            if not isinstance(event, dict):
                continue
            etype = event.get("type", "")
            if etype == "conversation.item.input_audio_transcription.delta":
                delta = event.get("delta", "")
                if delta:
                    # Compose into cumulative partial text
                    await self._push_interim(self._partial_buffer + delta)
            elif etype == "conversation.item.input_audio_transcription.completed":
                await self._push_final(event.get("transcript", ""))
            elif etype == "error":
                logger.error(f"OpenAI Realtime error: {event}")
            elif etype == "input_audio_buffer.speech_started":
                try:
                    await self.push_frame(UserStartedSpeakingFrame())
                except Exception as e:
                    # Best effort: a failed notification must not stop transcription.
                    logger.warning(f"Failed to push UserStartedSpeakingFrame: {e}")
            elif etype == "input_audio_buffer.committed":
                try:
                    await self.push_frame(UserStoppedSpeakingFrame())
                except Exception as e:
                    logger.warning(f"Failed to push UserStoppedSpeakingFrame: {e}")
class OpenAIRealtimeTTSService(TTSService):
    """Stub for OpenAI Realtime TTS service.

    Constructing it only logs a warning; synthesis is not implemented.
    """

    def __init__(self, **kwargs):  # noqa: D401
        """Forward ``kwargs`` to ``TTSService`` and warn that this is a stub."""
        super().__init__(**kwargs)
        logger.warning("OpenAIRealtimeTTSService is a stub and not wired to Realtime yet.")

    def can_generate_metrics(self) -> bool:
        """Report that this service does not produce metrics."""
        return False

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:  # pragma: no cover - stub
        """Always fail: the Realtime TTS path is not implemented.

        Raises:
            NotImplementedError: On the first iteration of the returned generator.
        """
        raise NotImplementedError("OpenAIRealtimeTTSService Realtime path not implemented.")
        # Unreachable, but its presence makes this function an async generator,
        # matching the annotation. Without it `async for` over the result fails
        # with a TypeError instead of the NotImplementedError above.
        yield  # pragma: no cover
@traceable
class OpenAITTSService(TTSService):
    """OpenAI Text-to-Speech over REST streaming.

    Uses OpenAI's `audio.speech` streaming API to synthesize PCM audio, and
    resamples it with ``audioop.ratecv`` when the configured output sample rate
    differs from ``OPENAI_SAMPLE_RATE``.
    """

    # Sample rate (Hz) of the PCM audio requested from the API.
    OPENAI_SAMPLE_RATE = 24000
    # Voices accepted without falling back to "alloy".
    _VALID_VOICES = {
        "alloy",
        "ash",
        "ballad",
        "coral",
        "echo",
        "fable",
        "onyx",
        "nova",
        "sage",
        "shimmer",
        "verse",
    }

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        voice: str = "shimmer",
        model: str = "gpt-4o-mini-tts",
        sample_rate: int | None = None,
        instructions: str | None = None,
        speed: float | None = None,
        **kwargs,
    ):
        """Store configuration; the API client is created lazily on first use.

        Args:
            api_key: OpenAI API key; falls back to ``OPENAI_API_KEY``.
            base_url: API base URL; falls back to ``OPENAI_BASE_URL``.
            voice: Voice name. ``OPENAI_TTS_VOICE`` takes precedence when set
                and non-empty. Unknown voices fall back to ``alloy``.
            model: TTS model name.
            sample_rate: Output sample rate; defaults to ``OPENAI_SAMPLE_RATE``.
            instructions: Optional ``instructions`` request parameter.
            speed: Optional ``speed`` request parameter. ``OPENAI_TTS_SPEED``
                takes precedence when it parses as a float.
            **kwargs: Forwarded to ``TTSService``.
        """
        # Default to native 24 kHz to avoid quality loss from resampling unless caller overrides
        sr = sample_rate or self.OPENAI_SAMPLE_RATE
        if sr != self.OPENAI_SAMPLE_RATE:
            logger.info(
                f"OpenAI TTS renders at {self.OPENAI_SAMPLE_RATE} Hz; resampling to {sr} Hz for output."
            )
        super().__init__(sample_rate=sr, **kwargs)
        # Client is created lazily to avoid import errors if openai not installed at import-time
        self._api_key = api_key or _get_env("OPENAI_API_KEY")
        self._base_url = base_url or _get_env("OPENAI_BASE_URL")
        self._client = None
        # Allow env overrides for voice/speed
        env_voice = _get_env("OPENAI_TTS_VOICE")
        selected_voice = env_voice if env_voice else voice
        env_speed = _get_env("OPENAI_TTS_SPEED")
        selected_speed = speed
        if env_speed is not None:
            try:
                selected_speed = float(env_speed)
            except ValueError:
                logger.warning(f"Ignoring invalid OPENAI_TTS_SPEED value '{env_speed}'")
        self._speed = selected_speed
        self.set_model_name(model)
        # Validate voice
        if selected_voice not in self._VALID_VOICES:
            logger.warning(f"Unknown OpenAI TTS voice '{selected_voice}', defaulting to 'alloy'")
            selected_voice = "alloy"
        self.set_voice(selected_voice)
        self._instructions = instructions
        # audioop.ratecv filter state, carried between chunks of one utterance.
        self._ratecv_state = None
        # Carry a leftover byte to ensure PCM16 sample alignment across frames
        self._pcm_carry: bytes = b""

    def _ensure_client(self):
        """Create the ``AsyncOpenAI`` client on first use.

        Raises:
            RuntimeError: If the ``openai`` package cannot be imported.
        """
        if self._client is None:
            try:
                from openai import AsyncOpenAI  # type: ignore
            except Exception as e:  # noqa: BLE001
                raise RuntimeError(
                    "openai package is required for OpenAITTSService. Install 'openai'"
                ) from e
            self._client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    async def start(self, frame: StartFrame):
        """Start the service (delegates to the base class)."""
        await super().start(frame)
        # No warning: we resample to the configured sample rate

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="tts")
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` and yield the resulting frames.

        On success yields ``TTSStartedFrame``, one ``TTSAudioRawFrame`` per
        non-empty aligned chunk, then ``TTSStoppedFrame``. On a non-200 response
        or an exception yields an ``ErrorFrame``; if the started frame had
        already been emitted, a closing ``TTSStoppedFrame`` follows the error.

        Raises:
            RuntimeError: If the ``openai`` package is not installed.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")
        self._ensure_client()
        # Each call is an independent audio stream: never reuse the previous
        # utterance's dangling byte or resampler state.
        self._pcm_carry = b""
        self._ratecv_state = None
        started = False
        try:
            await self.start_ttfb_metrics()
            create_params = {
                "input": text,
                "model": self.model_name,
                "voice": self._voice_id,
                "response_format": "pcm",
            }
            if self._instructions:
                create_params["instructions"] = self._instructions
            if self._speed is not None:
                create_params["speed"] = self._speed
            # Stream response
            async with self._client.audio.speech.with_streaming_response.create(**create_params) as r:
                if getattr(r, "status_code", 200) != 200:
                    error_text = await r.text()
                    logger.error(f"{self} error getting audio (status: {r.status_code}, error: {error_text})")
                    yield ErrorFrame(f"OpenAI TTS error (status: {r.status_code}, error: {error_text})")
                    return
                await self.start_tts_usage_metrics(text)
                yield TTSStartedFrame()
                started = True
                first_chunk = True
                chunk_size = getattr(self, "chunk_size", 4096)
                async for chunk in r.iter_bytes(chunk_size):
                    if not chunk:
                        continue
                    # Ensure 16-bit sample alignment by carrying last odd byte to next chunk
                    if self._pcm_carry:
                        chunk = self._pcm_carry + chunk
                        self._pcm_carry = b""
                    if len(chunk) % 2 == 1:
                        self._pcm_carry = chunk[-1:]
                        chunk = chunk[:-1]
                    if first_chunk:
                        await self.stop_ttfb_metrics()
                        first_chunk = False
                    if not chunk:
                        # Only a single byte arrived; wait for the rest of the sample.
                        continue
                    out_bytes = chunk
                    out_rate = self.sample_rate
                    if out_rate != self.OPENAI_SAMPLE_RATE:
                        try:
                            out_bytes, self._ratecv_state = audioop.ratecv(
                                chunk,
                                2,
                                1,
                                self.OPENAI_SAMPLE_RATE,
                                out_rate,
                                self._ratecv_state,
                            )
                        except Exception as e:
                            logger.warning(f"TTS resample failed, passing through 24kHz: {e}")
                            out_bytes = chunk
                            # Label the pass-through audio with its real rate so
                            # it is not played back at the wrong speed.
                            out_rate = self.OPENAI_SAMPLE_RATE
                    yield TTSAudioRawFrame(out_bytes, out_rate, 1)
                # Clear the flag first so an error while closing the response
                # cannot produce a second TTSStoppedFrame.
                started = False
                yield TTSStoppedFrame()
        except Exception as e:  # noqa: BLE001
            logger.exception(f"{self} error generating TTS: {e}")
            # Yield an error frame so downstream can react
            try:
                yield ErrorFrame(str(e))
                if started:
                    # Balance the TTSStartedFrame that was already emitted.
                    yield TTSStoppedFrame()
            except Exception:
                # If yielding after exception fails, just swallow
                pass