import os

import httpx
from fastapi import HTTPException
from aiortc import RTCPeerConnection, RTCSessionDescription
from openai import AsyncOpenAI

from src.config import logger
from ._openai_tools import _OPENAI_TOOLS

# Used when OPENAI_BASE_URL is unset or empty; this is the public OpenAI API
# root that the openai SDK itself falls back to.
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"


class OpenAIClient:
    """Thin async wrapper around OpenAI chat completions and the Realtime API.

    Intended to be used as an async context manager: ``__aenter__`` loads the
    text/voice system prompts from ``src/prompts`` and creates the
    ``AsyncOpenAI`` client; ``__aexit__`` closes that client.

    Configuration is read from environment variables:
    ``OPENAI_REALTIME_MODEL``, ``OPENAI_BASE_URL``, ``OPENAI_API_KEY`` and
    ``OPENAI_CHAT_COMPLETION_MODEL``.
    """

    def __init__(self):
        self.client = None
        self.model = os.getenv("OPENAI_REALTIME_MODEL")
        # `or` (not a getenv default) so an empty variable also falls back;
        # rstrip avoids a double slash when the variable ends with "/".
        base_url = (
            os.getenv("OPENAI_BASE_URL") or _DEFAULT_OPENAI_BASE_URL
        ).rstrip("/")
        self.session_url = base_url + "/realtime/sessions"
        # Must be an f-string: the model name is part of the query string.
        self.webrtc_url = f"{base_url}/realtime?model={self.model}"
        self.modalities = ["text", "audio"]
        self.voice = "alloy"
        self.transcription_model = "whisper-1"
        self.temperature = 0.8
        self.max_response_output_tokens = 1000
        self.system_prompt_text = None
        self.system_prompt_voice = None
        self.tools = _OPENAI_TOOLS
        self.tool_choice = "auto"

    async def __aenter__(self):
        """Load both system prompts and create the AsyncOpenAI client."""
        self.system_prompt_text = await self.prompt_loader("system_prompt_text.md")
        self.system_prompt_voice = await self.prompt_loader("system_prompt_voice.md")
        self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self

    async def __aexit__(self, *args):
        """Close the AsyncOpenAI client so its HTTP connection pool is released."""
        if self.client is not None:
            await self.client.close()
            self.client = None

    async def prompt_loader(self, file_name: str):
        """Return the UTF-8 text of ``src/prompts/<file_name>``.

        The path is relative to the current working directory.
        NOTE(review): the read is synchronous despite the ``async`` signature.
        """
        path = os.path.join("src", "prompts", file_name)
        with open(path, mode="r", encoding="utf-8") as file:
            return file.read()

    async def text_generation(self, query: str):
        """Send ``query`` as a single user message and return the reply text.

        Requires the client created in ``__aenter__``; the model comes from
        ``OPENAI_CHAT_COMPLETION_MODEL``.
        """
        completion = await self.client.chat.completions.create(
            model=os.getenv("OPENAI_CHAT_COMPLETION_MODEL"),
            messages=[
                {
                    "role": "user",
                    "content": query,
                },
            ],
        )
        return completion.choices[0].message.content

    async def create_text_session_event(self):
        """Build a ``session.update`` event configuring a text-only session."""
        return {
            "type": "session.update",
            "session": {
                "modalities": ["text"],
                "instructions": self.system_prompt_text,
                "tools": self.tools,
                "tool_choice": self.tool_choice,
                "temperature": self.temperature,
                "max_response_output_tokens": self.max_response_output_tokens,
            },
        }

    async def create_openai_session(self, text_mode_only=False):
        """Create a Realtime session via ``POST <base>/realtime/sessions``.

        Args:
            text_mode_only: When true, request only the ``"text"`` modality
                and send ``null`` for input audio transcription.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: If the API returns a non-2xx status.
        """
        headers = {
            "Authorization": f'Bearer {os.getenv("OPENAI_API_KEY")}',
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "instructions": self.system_prompt_voice,
            "tools": self.tools,
            "temperature": self.temperature,
            "max_response_output_tokens": self.max_response_output_tokens,
            "voice": self.voice,
            "modalities": self.modalities if not text_mode_only else ["text"],
            "input_audio_transcription": (
                {"model": self.transcription_model} if not text_mode_only else None
            ),
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.session_url, json=payload, headers=headers
            )
            response.raise_for_status()
            return response.json()

    @staticmethod
    def _extract_error_message(response):
        """Best-effort error text from a failed response, else its reason phrase."""
        try:
            error_data = response.json()
        except ValueError:  # body is not JSON (json.JSONDecodeError subclasses ValueError)
            return response.reason_phrase
        error = error_data.get("error") if isinstance(error_data, dict) else None
        if isinstance(error, dict):
            return error.get("message", response.reason_phrase)
        if isinstance(error, str) and error:
            return error
        return response.reason_phrase

    async def create_webrtc_connection(self, ephemeral_key, offer):
        """Exchange a client SDP offer for an SDP answer from the Realtime API.

        Args:
            ephemeral_key: Bearer token sent to the WebRTC endpoint.
            offer: Mapping with ``"sdp"`` and ``"type"`` keys.

        Returns:
            ``{"answer": {"sdp": ..., "type": "answer"}, "ephemeral_key": ...}``.

        Raises:
            HTTPException: With the upstream status code when the exchange fails.
        """
        offer_desc = RTCSessionDescription(sdp=offer["sdp"], type=offer["type"])
        # NOTE(review): this local peer connection is only used to apply the
        # offer (presumably as SDP validation); it is never returned, so it is
        # always closed to avoid leaking its resources.
        peer_connection = RTCPeerConnection()
        try:
            await peer_connection.setRemoteDescription(offer_desc)

            async with httpx.AsyncClient() as client:
                headers = {
                    "Authorization": f"Bearer {ephemeral_key}",
                    "Content-Type": "application/sdp",
                }
                response = await client.post(
                    self.webrtc_url, content=offer_desc.sdp, headers=headers
                )
                if not response.is_success:
                    raise HTTPException(
                        status_code=response.status_code,
                        detail=self._extract_error_message(response),
                    )
                answer_sdp = response.text
        finally:
            await peer_connection.close()

        return {
            "answer": {"sdp": answer_sdp, "type": "answer"},
            "ephemeral_key": ephemeral_key,
        }