keepme-backend / src /utils /_openai_client.py
narinder1231's picture
use different prompts for voice and text conversations
173ca46
import os
import httpx
from fastapi import HTTPException
from aiortc import RTCPeerConnection, RTCSessionDescription
from openai import AsyncOpenAI
from src.config import logger
from ._openai_tools import _OPENAI_TOOLS
class OpenAIClient:
    """Async wrapper around the OpenAI chat-completions and Realtime APIs.

    Intended to be used as ``async with OpenAIClient() as client:``.
    ``__aenter__`` loads the text and voice system prompts from
    ``src/prompts`` and creates the ``AsyncOpenAI`` client; ``__aexit__``
    closes that client again.

    Environment variables read: ``OPENAI_API_KEY``, ``OPENAI_BASE_URL``,
    ``OPENAI_REALTIME_MODEL`` and ``OPENAI_CHAT_COMPLETION_MODEL``.
    """

    # Fallback used when OPENAI_BASE_URL is unset or empty; this is the public
    # OpenAI endpoint, the same default the openai SDK itself uses.
    DEFAULT_BASE_URL = "https://api.openai.com/v1"

    def __init__(self):
        # AsyncOpenAI instance; created in __aenter__.
        self.client = None
        self.model = os.getenv("OPENAI_REALTIME_MODEL")
        self.session_url, self.webrtc_url = self._realtime_urls(
            os.getenv("OPENAI_BASE_URL") or self.DEFAULT_BASE_URL, self.model
        )
        self.modalities = ["text", "audio"]
        self.voice = "alloy"
        self.transcription_model = "whisper-1"
        self.temperature = 0.8
        self.max_response_output_tokens = 1000
        # Filled in by __aenter__ from src/prompts/*.md.
        self.system_prompt_text = None
        self.system_prompt_voice = None
        self.tools = _OPENAI_TOOLS
        self.tool_choice = "auto"

    @staticmethod
    def _realtime_urls(base_url, model):
        """Return ``(session_url, webrtc_url)`` for the Realtime API.

        A trailing slash on *base_url* is dropped so the paths join without
        producing ``//``.
        """
        base = base_url.rstrip("/")
        return f"{base}/realtime/sessions", f"{base}/realtime?model={model}"

    async def __aenter__(self):
        """Load both system prompts and create the AsyncOpenAI client."""
        self.system_prompt_text = await self.prompt_loader("system_prompt_text.md")
        self.system_prompt_voice = await self.prompt_loader("system_prompt_voice.md")
        self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self

    async def __aexit__(self, *args):
        """Close the AsyncOpenAI client (and its HTTP connection pool)."""
        if self.client is not None:
            await self.client.close()

    async def prompt_loader(self, file_name: str):
        """Return the UTF-8 text of ``src/prompts/<file_name>``.

        The path is relative, so it resolves against the process's current
        working directory.
        """
        path = os.path.join("src", "prompts", file_name)
        with open(path, mode="r", encoding="utf-8") as file:
            return file.read()

    async def text_generation(self, query: str):
        """Send *query* as a single user message and return the reply text.

        Uses the model named by OPENAI_CHAT_COMPLETION_MODEL and returns the
        content of the first choice. No system prompt is sent. Requires the
        client created by ``__aenter__``.
        """
        completion = await self.client.chat.completions.create(
            model=os.getenv("OPENAI_CHAT_COMPLETION_MODEL"),
            messages=[
                {
                    "role": "user",
                    "content": query,
                },
            ],
        )
        return completion.choices[0].message.content

    async def create_text_session_event(self):
        """Build a ``session.update`` event switching a session to text-only.

        The event carries the text system prompt plus this client's tools,
        tool choice, temperature and output-token limit.
        """
        return {
            "type": "session.update",
            "session": {
                "modalities": ["text"],
                "instructions": self.system_prompt_text,
                "tools": self.tools,
                "tool_choice": self.tool_choice,
                "temperature": self.temperature,
                "max_response_output_tokens": self.max_response_output_tokens,
            },
        }

    async def create_openai_session(self, text_mode_only=False):
        """Create a Realtime session via POST to ``self.session_url``.

        Args:
            text_mode_only: when True, request only the "text" modality and
                disable input audio transcription.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: if the API answers with a 4xx/5xx status.
        """
        headers = {
            "Authorization": f'Bearer {os.getenv("OPENAI_API_KEY")}',
            "Content-Type": "application/json",
        }
        # NOTE(review): the voice prompt is sent even when text_mode_only is
        # True; callers presumably follow up with create_text_session_event()
        # to switch to the text prompt - confirm.
        payload = {
            "model": self.model,
            "instructions": self.system_prompt_voice,
            "tools": self.tools,
            "temperature": self.temperature,
            "max_response_output_tokens": self.max_response_output_tokens,
            "voice": self.voice,
            "modalities": self.modalities if not text_mode_only else ["text"],
            "input_audio_transcription": (
                {"model": self.transcription_model} if not text_mode_only else None
            ),
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.session_url, json=payload, headers=headers
            )
            response.raise_for_status()
            return response.json()

    @staticmethod
    def _error_message(response):
        """Extract a human-readable error message from a failed response.

        Prefers ``body["error"]["message"]`` (or ``body["error"]`` when it is
        a plain string) and falls back to the HTTP reason phrase when the body
        is not JSON or has no usable error field.
        """
        try:
            data = response.json()
        except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
            return response.reason_phrase
        error = data.get("error") if isinstance(data, dict) else None
        if isinstance(error, dict):
            return error.get("message") or response.reason_phrase
        if isinstance(error, str) and error:
            return error
        return response.reason_phrase

    async def create_webrtc_connection(self, ephemeral_key, offer):
        """Exchange a client SDP offer for OpenAI's SDP answer.

        Args:
            ephemeral_key: ephemeral token used as the Bearer credential.
            offer: mapping with ``"sdp"`` and ``"type"`` keys.

        Returns:
            ``{"answer": {"sdp": ..., "type": "answer"}, "ephemeral_key": ...}``.

        Raises:
            HTTPException: with the upstream status code and error message
                when the Realtime endpoint rejects the offer.
        """
        offer_desc = RTCSessionDescription(sdp=offer["sdp"], type=offer["type"])
        peer_connection = RTCPeerConnection()
        try:
            # Applying the offer locally makes aiortc parse it, so a malformed
            # SDP fails here before being forwarded (presumed intent - confirm).
            await peer_connection.setRemoteDescription(offer_desc)
        finally:
            # The peer connection is only used for that check; close it so its
            # resources are released instead of leaking per request.
            await peer_connection.close()

        headers = {
            "Authorization": f"Bearer {ephemeral_key}",
            "Content-Type": "application/sdp",
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.webrtc_url, content=offer_desc.sdp, headers=headers
            )
            if not response.is_success:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=self._error_message(response),
                )
            answer_sdp = response.text
        return {
            "answer": {"sdp": answer_sdp, "type": "answer"},
            "ephemeral_key": ephemeral_key,
        }