| | import asyncio |
| | import base64 |
| | import json |
| | import os |
| | import aiohttp |
| |
|
| | from ten import AsyncTenEnv |
| |
|
| | from typing import Any, AsyncGenerator |
| | from .struct import InputAudioBufferAppend, ClientToServerMessage, ServerToClientMessage, parse_server_message, to_json |
| |
|
def smart_str(s: str, max_field_len: int = 128) -> str:
    """Return *s* with an oversized ``delta``/``audio`` JSON field shortened.

    Intended for logging: bulky payload fields are cut down so that log lines
    stay readable.

    Args:
        s: Text that may or may not be a JSON document.
        max_field_len: Maximum number of characters kept from the payload
            field before ``"..."`` is appended.

    Returns:
        * ``s`` unchanged when it is not valid JSON, is not a JSON object, or
          has neither a ``delta`` nor an ``audio`` key.
        * Otherwise the object re-serialized with ``json.dumps``; the payload
          field (``delta`` takes precedence over ``audio``) is truncated only
          when it is a string longer than ``max_field_len``.
    """
    try:
        data = json.loads(s)
    except json.JSONDecodeError:
        return s

    # Valid JSON is not necessarily an object (e.g. ``123`` or ``["delta"]``);
    # key lookups only make sense on a dict.
    if not isinstance(data, dict):
        return s

    if "delta" in data:
        key = "delta"
    elif "audio" in data:
        key = "audio"
    else:
        return s

    value = data[key]
    # Only string payloads can be truncated; any other type (null, number,
    # nested structure) is re-serialized untouched instead of raising.
    if isinstance(value, str) and len(value) > max_field_len:
        data[key] = value[:max_field_len] + "..."
    return json.dumps(data)
| |
|
| |
|
class RealtimeApiConnection:
    """Client-side wrapper around one realtime API websocket connection.

    The instance owns an ``aiohttp.ClientSession`` and, once :meth:`connect`
    has run, the websocket opened on that session. It can be used as an async
    context manager, which connects on entry and closes on exit.
    """

    def __init__(
        self,
        ten_env: AsyncTenEnv,
        base_uri: str,
        api_key: str | None = None,
        path: str = "/v1/realtime",
        verbose: bool = False,
    ):
        """Store the connection settings; the websocket is not opened here.

        Args:
            ten_env: Environment handle; this class only uses it for logging.
            base_uri: Leading part of the endpoint URL.
            api_key: Bearer token. Falls back to the ``GLM_API_KEY``
                environment variable when omitted or empty.
            path: Appended verbatim to ``base_uri`` to form the URL.
            verbose: When true, sent and received messages are logged after
                being passed through ``smart_str``.
        """
        self.ten_env = ten_env
        self.url = f"{base_uri}{path}"
        self.api_key = api_key or os.environ.get("GLM_API_KEY")
        self.websocket: aiohttp.ClientWebSocketResponse | None = None
        self.verbose = verbose
        self.session = aiohttp.ClientSession()

    async def __aenter__(self) -> "RealtimeApiConnection":
        """Connect and return this instance."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        """Close the connection; never suppresses the in-flight exception."""
        await self.close()
        return False

    async def connect(self):
        """Open the websocket, authenticating with a bearer token.

        Raises:
            ValueError: If no API key was passed in and ``GLM_API_KEY`` is
                not set.
        """
        if not self.api_key:
            # Fail early with a clear message rather than with the TypeError
            # that concatenating "Bearer " and None would produce.
            raise ValueError(
                "No API key configured: pass api_key or set GLM_API_KEY"
            )

        if self.session.closed:
            # close() shuts the session down, so connecting again after a
            # close needs a fresh one.
            self.session = aiohttp.ClientSession()

        headers = {"Authorization": "Bearer " + self.api_key}
        self.websocket = await self.session.ws_connect(
            url=self.url,
            headers=headers,
        )

    async def send_audio_data(self, audio_data: bytes):
        """Base64-encode ``audio_data`` and send it as an audio-append message.

        audio_data is assumed to be pcm16 24kHz mono little-endian.
        """
        base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
        message = InputAudioBufferAppend(audio=base64_audio_data)
        await self.send_request(message)

    async def send_request(self, message: ClientToServerMessage):
        """Serialize ``message`` with ``to_json`` and send it as a text frame.

        Raises:
            RuntimeError: If the connection has not been opened.
        """
        websocket = self._require_websocket()
        message_str = to_json(message)
        if self.verbose:
            self.ten_env.log_info(f"-> {smart_str(message_str)}")
        await websocket.send_str(message_str)

    async def listen(self) -> AsyncGenerator[ServerToClientMessage, None]:
        """Yield parsed server messages until the websocket stops delivering.

        Only text frames are parsed. A frame that fails to parse is logged by
        :meth:`handle_server_message` and skipped. An error frame is logged
        and ends the iteration.

        Raises:
            RuntimeError: If the connection has not been opened.
        """
        websocket = self._require_websocket()
        if self.verbose:
            self.ten_env.log_info("Listening for realtimeapi messages")
        try:
            async for msg in websocket:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.verbose:
                        self.ten_env.log_info(f"<- {smart_str(msg.data)}")
                    parsed = self.handle_server_message(msg.data)
                    # Unparsable frames come back as None; do not hand those
                    # to the consumer.
                    if parsed is not None:
                        yield parsed
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    # Use the local reference: self.websocket may have been
                    # reset by a concurrent close().
                    self.ten_env.log_error(
                        f"Error during receive: {websocket.exception()}"
                    )
                    break
        except asyncio.CancelledError:
            # NOTE(review): cancellation is logged and swallowed, so the
            # generator simply ends; kept as-is because callers may rely on it.
            self.ten_env.log_info("Receive messages task cancelled")

    def handle_server_message(self, message: str) -> ServerToClientMessage | None:
        """Parse one raw server message.

        Returns:
            The parsed message, or ``None`` when ``parse_server_message``
            raised; the failure is logged and not propagated.
        """
        try:
            return parse_server_message(message)
        except Exception as e:
            self.ten_env.log_error(f"Error handling message {message} {e}")
            return None

    async def close(self):
        """Close the websocket (if open) and the underlying client session."""
        try:
            if self.websocket:
                await self.websocket.close()
        finally:
            self.websocket = None
            # The session is created by this object, so it must be released
            # here as well; otherwise it leaks.
            if not self.session.closed:
                await self.session.close()

    def _require_websocket(self) -> aiohttp.ClientWebSocketResponse:
        """Return the open websocket or raise if ``connect`` has not run."""
        if self.websocket is None:
            raise RuntimeError("Websocket is not connected; call connect() first")
        return self.websocket
| |
|