File size: 4,419 Bytes
b2fc939
 
eb474ee
 
2efe331
eb474ee
b2fc939
a3aa6c1
b2fc939
 
 
 
2efe331
a3aa6c1
b97def2
a3aa6c1
b97def2
a3aa6c1
b2fc939
 
 
 
 
173ca46
 
b2fc939
d4f6849
b2fc939
 
173ca46
 
2efe331
b2fc939
 
 
 
 
6aa833d
 
 
 
 
 
2efe331
 
a3aa6c1
2efe331
 
 
 
 
 
 
 
 
d4f6849
 
 
 
 
173ca46
d4f6849
 
 
 
 
 
 
17f9ecb
b2fc939
 
 
 
 
 
173ca46
b2fc939
 
 
17f9ecb
 
 
 
 
b2fc939
 
 
 
 
 
 
 
eb474ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f2f6d9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import httpx
from fastapi import HTTPException
from aiortc import RTCPeerConnection, RTCSessionDescription
from openai import AsyncOpenAI

from src.config import logger
from ._openai_tools import _OPENAI_TOOLS


class OpenAIClient:
    """Async helper for OpenAI chat completions and the Realtime API.

    Intended to be used as an async context manager: entering loads the text
    and voice system prompts from ``src/prompts`` and creates an
    ``AsyncOpenAI`` client; exiting closes that client.

    Configuration comes from the environment: ``OPENAI_API_KEY``,
    ``OPENAI_BASE_URL`` (optional), ``OPENAI_REALTIME_MODEL`` and
    ``OPENAI_CHAT_COMPLETION_MODEL``.
    """

    # Fallback used when OPENAI_BASE_URL is unset or empty; this is the same
    # default base URL the ``openai`` SDK uses.
    DEFAULT_BASE_URL = "https://api.openai.com/v1"

    def __init__(self):
        self.client = None  # AsyncOpenAI instance, created in __aenter__
        self.model = os.getenv("OPENAI_REALTIME_MODEL")
        base_url = self._base_url()
        self.session_url = f"{base_url}/realtime/sessions"
        # Must be an f-string so the configured model name is interpolated
        # into the query string rather than sent as the text "{self.model}".
        self.webrtc_url = f"{base_url}/realtime?model={self.model}"
        self.modalities = ["text", "audio"]
        self.voice = "alloy"
        self.transcription_model = "whisper-1"
        self.temperature = 0.8
        self.max_response_output_tokens = 1000
        self.system_prompt_text = None  # loaded in __aenter__
        self.system_prompt_voice = None  # loaded in __aenter__
        self.tools = _OPENAI_TOOLS
        self.tool_choice = "auto"

    @classmethod
    def _base_url(cls) -> str:
        """Return the API base URL from OPENAI_BASE_URL without a trailing slash.

        Falls back to ``DEFAULT_BASE_URL`` when the variable is unset or empty,
        instead of failing with ``TypeError`` on ``None + str``.
        """
        return (os.getenv("OPENAI_BASE_URL") or cls.DEFAULT_BASE_URL).rstrip("/")

    async def __aenter__(self):
        """Load both system prompts, create the OpenAI client, return ``self``."""
        self.system_prompt_text = await self.prompt_loader("system_prompt_text.md")
        self.system_prompt_voice = await self.prompt_loader("system_prompt_voice.md")
        self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self

    async def __aexit__(self, *args):
        """Close the ``AsyncOpenAI`` client created in ``__aenter__``, if any.

        Returns ``None``, so exceptions raised inside the ``async with`` body
        propagate unchanged.
        """
        if self.client is not None:
            await self.client.close()

    async def prompt_loader(self, file_name: str) -> str:
        """Return the UTF-8 contents of ``src/prompts/<file_name>``.

        The path is relative to the current working directory. The file read
        itself is synchronous and blocks the event loop while it runs.
        """
        path = os.path.join("src", "prompts", file_name)
        with open(path, mode="r", encoding="utf-8") as file:
            return file.read()

    async def text_generation(self, query: str):
        """Send ``query`` as a single user message and return the reply content.

        Uses the chat-completions model named by OPENAI_CHAT_COMPLETION_MODEL
        and requires the client created in ``__aenter__``.
        """
        completion = await self.client.chat.completions.create(
            model=os.getenv("OPENAI_CHAT_COMPLETION_MODEL"),
            messages=[
                {
                    "role": "user",
                    "content": query,
                },
            ],
        )
        return completion.choices[0].message.content

    async def create_text_session_event(self) -> dict:
        """Build a ``session.update`` event that switches a session to text only.

        The event carries the text system prompt together with the configured
        tools, tool choice, temperature and output-token limit.
        """
        return {
            "type": "session.update",
            "session": {
                "modalities": ["text"],
                "instructions": self.system_prompt_text,
                "tools": self.tools,
                "tool_choice": self.tool_choice,
                "temperature": self.temperature,
                "max_response_output_tokens": self.max_response_output_tokens,
            },
        }

    async def create_openai_session(self, text_mode_only: bool = False) -> dict:
        """Create a Realtime session and return the parsed JSON response.

        Args:
            text_mode_only: When true, request only the ``"text"`` modality
                and disable input-audio transcription.

        Raises:
            httpx.HTTPStatusError: If the sessions endpoint returns a non-2xx
                status (via ``raise_for_status``).
        """
        headers = {
            "Authorization": f'Bearer {os.getenv("OPENAI_API_KEY")}',
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "instructions": self.system_prompt_voice,
            "tools": self.tools,
            # Sent here too so the session matches create_text_session_event.
            "tool_choice": self.tool_choice,
            "temperature": self.temperature,
            "max_response_output_tokens": self.max_response_output_tokens,
            "voice": self.voice,
            "modalities": ["text"] if text_mode_only else self.modalities,
            "input_audio_transcription": (
                None if text_mode_only else {"model": self.transcription_model}
            ),
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.session_url, json=payload, headers=headers
            )
            response.raise_for_status()
            return response.json()

    async def create_webrtc_connection(self, ephemeral_key, offer) -> dict:
        """Forward a client SDP offer to OpenAI and return its SDP answer.

        Args:
            ephemeral_key: Key sent as the Bearer token to the Realtime WebRTC
                endpoint; echoed back in the result.
            offer: Mapping with ``"sdp"`` and ``"type"`` entries describing the
                client's session description.

        Returns:
            ``{"answer": {"sdp": <answer sdp>, "type": "answer"},
            "ephemeral_key": ephemeral_key}``.

        Raises:
            HTTPException: If OpenAI rejects the offer. It carries the upstream
                status code and the upstream error message when one is
                available, otherwise the HTTP reason phrase.
        """
        offer_desc = RTCSessionDescription(sdp=offer["sdp"], type=offer["type"])

        # The local peer connection is only used for setRemoteDescription,
        # which parses the offer, so it is closed straight away rather than
        # leaked.
        peer_connection = RTCPeerConnection()
        try:
            await peer_connection.setRemoteDescription(offer_desc)
        finally:
            await peer_connection.close()

        headers = {
            "Authorization": f"Bearer {ephemeral_key}",
            "Content-Type": "application/sdp",
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.webrtc_url, content=offer_desc.sdp, headers=headers
            )
            if not response.is_success:
                # Error bodies are not guaranteed to be JSON of the form
                # {"error": {"message": ...}}; fall back to the reason phrase.
                error_message = response.reason_phrase
                try:
                    error_data = response.json()
                except ValueError:
                    error_data = None
                if isinstance(error_data, dict):
                    error = error_data.get("error")
                    if isinstance(error, dict) and error.get("message"):
                        error_message = error["message"]
                raise HTTPException(
                    status_code=response.status_code, detail=error_message
                )
            return {
                "answer": {"sdp": response.text, "type": "answer"},
                "ephemeral_key": ephemeral_key,
            }