from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import random
import requests
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion

from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig


@dataclass
class OpenAIChatGPTConfig(BaseConfig):
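    """Configuration for the OpenAI / Azure OpenAI chat client.

    The sampling fields (temperature, top_p, penalties, max_tokens, seed) are
    passed straight through to the Chat Completions request; `vendor` selects
    between the public OpenAI API and an Azure OpenAI deployment.
    """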
    api_key: str = ""
    base_url: str = "https://api.openai.com/v1"
    model: str = "gpt-4o"
    prompt: str = (
        "You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points."
    )
    frequency_penalty: float = 0.9
    presence_penalty: float = 0.9
    top_p: float = 1.0
    temperature: float = 0.1
    max_tokens: int = 512
    seed: int = random.randint(0, 10000)
    proxy_url: str = ""
    greeting: str = "Hello, how can I help you today?"
    max_memory_length: int = 10
    vendor: str = "openai"
    azure_endpoint: str = ""
    azure_api_version: str = ""


class ReasoningMode(str, Enum):
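    """How the model delivers reasoning ("think") output.

    ModeV1 marks providers that stream a separate `reasoning_content` field on
    the delta instead of inline `<think>` tags.
    """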
    ModeV1 = "v1"


class ThinkParser:
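    """Small state machine that splits streamed output into reasoning
    ("think") text and normal response text."""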
    def __init__(self):
        self.state = "NORMAL"
        self.think_content = ""
        self.content = ""

    def process(self, new_chars):
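        """Consume one streamed chunk. A chunk that is exactly "<think>" or
        "</think>" switches the state; while in THINK, chunks are collected
        into `think_content`. Returns True if the chunk was a tag, else False.
        """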
        if new_chars == "<think>":
            self.state = "THINK"
            return True
        elif new_chars == "</think>":
            self.state = "NORMAL"
            return True
        else:
            if self.state == "THINK":
                self.think_content += new_chars
            return False

    def process_by_reasoning_content(self, reasoning_content):
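        """Consume a `reasoning_content` delta (ReasoningMode.ModeV1): enter
        THINK while reasoning text keeps arriving and drop back to NORMAL once
        it stops. Returns True when the state changed."""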
        state_changed = False
        if reasoning_content:
            if self.state == "NORMAL":
                self.state = "THINK"
                state_changed = True
            self.think_content += reasoning_content
        elif self.state == "THINK":
            self.state = "NORMAL"
            state_changed = True
        return state_changed


class OpenAIChatGPT:
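    """Async wrapper around the OpenAI / Azure OpenAI Chat Completions API,
    with streaming, tool-call accumulation and reasoning ("think") handling."""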
    client = None

    def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
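        # Pick the backing client from the configured vendor: "azure" uses
        # AsyncAzureOpenAI with the given endpoint and API version; anything
        # else goes through AsyncOpenAI against `base_url`.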
        self.config = config
        self.ten_env = ten_env
        ten_env.log_info(f"OpenAIChatGPT initialized with config: {config.api_key}")
        if self.config.vendor == "azure":
            self.client = AsyncAzureOpenAI(
                api_key=config.api_key,
                api_version=self.config.azure_api_version,
                azure_endpoint=config.azure_endpoint,
            )
            ten_env.log_info(
                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
            )
        else:
            self.client = AsyncOpenAI(
                api_key=config.api_key,
                base_url=config.base_url,
                default_headers={
                    "api-key": config.api_key,
                    "Authorization": f"Bearer {config.api_key}",
                },
            )
        self.session = requests.Session()
        if config.proxy_url:
            proxies = {
                "http": config.proxy_url,
                "https": config.proxy_url,
            }
            ten_env.log_info(f"Setting proxies: {proxies}")
            self.session.proxies.update(proxies)
        self.client.session = self.session

    async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
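        """Send a single non-streaming chat completion request built from the
        configured system prompt plus `messages` and return the raw response."""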
        req = {
            "model": self.config.model,
            "messages": [
                {
                    "role": "system",
                    "content": self.config.prompt,
                },
                *messages,
            ],
            "tools": tools,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "presence_penalty": self.config.presence_penalty,
            "frequency_penalty": self.config.frequency_penalty,
            "max_tokens": self.config.max_tokens,
            "seed": self.config.seed,
        }

        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletion failed, err: {e}") from e

        return response

    async def get_chat_completions_stream(self, messages, tools=None, listener=None):
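        """Stream a chat completion and forward events through `listener.emit`:
        `content_update` / `content_finished` for normal text,
        `reasoning_update` / `reasoning_update_finish` for think content, and
        `tool_call` for each fully accumulated tool call."""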
        req = {
            "model": self.config.model,
            "messages": [
                {
                    "role": "system",
                    "content": self.config.prompt,
                },
                *messages,
            ],
            "tools": tools,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "presence_penalty": self.config.presence_penalty,
            "frequency_penalty": self.config.frequency_penalty,
            "max_tokens": self.config.max_tokens,
            "seed": self.config.seed,
            "stream": True,
        }

        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletionStream failed, err: {e}") from e

        full_content = ""

        # Accumulate streamed tool-call fragments, keyed by tool-call index.
        tool_calls_dict = defaultdict(
            lambda: {
                "id": None,
                "function": {"arguments": "", "name": None},
                "type": None,
            }
        )

        parser = ThinkParser()
        reasoning_mode = None

        async for chat_completion in response:
            if len(chat_completion.choices) == 0:
                continue
            choice = chat_completion.choices[0]
            delta = choice.delta

            content = delta.content if delta and delta.content else ""
            reasoning_content = (
                delta.reasoning_content
                if delta
                and hasattr(delta, "reasoning_content")
                and delta.reasoning_content
                else ""
            )
            # Switch to reasoning-content mode the first time the provider
            # sends a non-empty `reasoning_content` delta; otherwise stick to
            # <think>-tag parsing on `content`.
            if reasoning_mode is None and reasoning_content:
                reasoning_mode = ReasoningMode.ModeV1

            # Forward incremental updates to the listener.
            if listener and (content or reasoning_mode == ReasoningMode.ModeV1):
                prev_state = parser.state

                if reasoning_mode == ReasoningMode.ModeV1:
                    self.ten_env.log_info("process_by_reasoning_content")
                    think_state_changed = parser.process_by_reasoning_content(
                        reasoning_content
                    )
                else:
                    think_state_changed = parser.process(content)

                if not think_state_changed:
                    if parser.state == "THINK":
                        listener.emit("reasoning_update", parser.think_content)
                    elif parser.state == "NORMAL":
                        listener.emit("content_update", content)

                # The think block just closed: flush the collected reasoning once.
                if prev_state == "THINK" and parser.state == "NORMAL":
                    listener.emit("reasoning_update_finish", parser.think_content)
                    parser.think_content = ""

            full_content += content

            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    if tool_call.id is not None:
                        tool_calls_dict[tool_call.index]["id"] = tool_call.id

                    # If the function name arrived in this chunk, record it.
                    if tool_call.function.name is not None:
                        tool_calls_dict[tool_call.index]["function"][
                            "name"
                        ] = tool_call.function.name

                    # Function arguments stream in pieces; keep appending them.
                    tool_calls_dict[tool_call.index]["function"][
                        "arguments"
                    ] += tool_call.function.arguments

                    # Record the tool-call type once it is known.
                    if tool_call.type is not None:
                        tool_calls_dict[tool_call.index]["type"] = tool_call.type

        # The stream is done; turn the accumulated fragments into a list.
        tool_calls_list = list(tool_calls_dict.values())

        # Emit each fully assembled tool call.
        if listener and tool_calls_list:
            for tool_call in tool_calls_list:
                listener.emit("tool_call", tool_call)

        # Signal that the content stream has finished.
        if listener:
            listener.emit("content_finished", full_content)