| | |
| | |
| | |
| | |
| | |
| | import json |
| | from pydantic import BaseModel |
| | from base64 import b64encode |
| | from io import BytesIO |
| | from typing import Any, Dict |
| | from ten_ai_base.const import CONTENT_DATA_OUT_NAME, DATA_OUT_PROPERTY_END_OF_SEGMENT, DATA_OUT_PROPERTY_TEXT |
| | from ten_ai_base.llm_tool import AsyncLLMToolBaseExtension |
| | from ten_ai_base.types import LLMToolMetadata, LLMToolMetadataParameter, LLMToolResult, LLMToolResultLLMResult |
| | from .openai import OpenAIChatGPT, OpenAIChatGPTConfig |
| |
|
| | from PIL import Image |
| | from ten import ( |
| | AsyncTenEnv, |
| | AudioFrame, |
| | VideoFrame, |
| | Data |
| | ) |
| |
|
# Identifier under which the "open website" tool is registered and dispatched.
OPEN_WEBSITE_TOOL_NAME = "open_website"
# Human-readable summary of the tool, included in its advertised metadata.
OPEN_WEBSITE_TOOL_DESCRIPTION = "Open a website with given site name"
| |
|
class WebsiteEvent(BaseModel):
    """Pydantic model pairing a website's name with its URL."""

    # Display name of the website.
    website_name: str
    # Address of the website.
    website_url: str
| |
|
def rgb2base64jpeg(rgb_data, width, height):
    """Encode a raw RGBA frame as a base64 PNG data URL.

    Despite the function's name, the input buffer is interpreted as RGBA
    pixel data and the output is PNG-encoded (MIME type ``image/png``).

    :param rgb_data: Raw pixel buffer; converted with ``bytes()`` and read as RGBA
    :param width: Frame width in pixels
    :param height: Frame height in pixels
    :return: A ``data:image/png;base64,...`` URL string
    """
    # Read the raw buffer as RGBA, then drop the alpha channel.
    pil_image = Image.frombytes("RGBA", (width, height), bytes(rgb_data))
    pil_image = pil_image.convert("RGB")

    # Cap the larger dimension at 1080 px while keeping the aspect ratio.
    pil_image = resize_image_keep_aspect(pil_image, 1080)

    # Encode into an in-memory buffer only. The previous debug dump to
    # "test.png" wrote a file to the working directory on every call and
    # has been removed.
    buffered = BytesIO()
    pil_image.save(buffered, format="png")

    base64_encoded_image = b64encode(buffered.getvalue()).decode("utf-8")

    mime_type = "image/png"
    return f"data:{mime_type};base64,{base64_encoded_image}"
| |
|
| |
|
def resize_image_keep_aspect(image, max_size=512):
    """
    Resize an image while maintaining its aspect ratio, ensuring the larger dimension is max_size.
    If both dimensions are smaller than or equal to max_size, the image is not resized.

    Degenerate images (a zero width or height) are returned unchanged, and the
    computed smaller dimension is clamped to at least 1 pixel so that extreme
    aspect ratios never request a zero-sized image.

    :param image: A PIL Image object
    :param max_size: The maximum size for the larger dimension (width or height)
    :return: A PIL Image object (resized or original)
    """
    width, height = image.size

    # Already within bounds: hand back the original object untouched.
    if width <= max_size and height <= max_size:
        return image

    # A zero-area image has no aspect ratio to preserve (and would divide by
    # zero below), so there is nothing meaningful to resize.
    if width <= 0 or height <= 0:
        return image

    aspect_ratio = width / height

    if width > height:
        new_width = max_size
        # int() truncates toward zero; clamp so very wide images keep >= 1 px.
        new_height = max(1, int(max_size / aspect_ratio))
    else:
        new_height = max_size
        # Same clamp for very tall images.
        new_width = max(1, int(max_size * aspect_ratio))

    return image.resize((new_width, new_height))
| |
|
class ComputerToolExtension(AsyncLLMToolBaseExtension):
    """LLM tool extension that exposes an ``open_website`` tool.

    When the tool is invoked, a ``browse_website`` action is serialized to
    JSON and sent out as a data message. The extension also keeps the most
    recently received video frame buffer and its dimensions.
    """

    def __init__(self, name: str) -> None:
        super().__init__(name)
        # Created in on_start; stays None when no API key is configured.
        self.openai_chatgpt = None
        self.config = None
        # NOTE(review): loop, memory and max_memory_length are not used by
        # the code visible in this class; kept for compatibility.
        self.loop = None
        self.memory = []
        self.max_memory_length = 10
        # Latest video frame buffer and its dimensions (see on_video_frame).
        self.image_data = None
        self.image_width = 0
        self.image_height = 0

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Log the lifecycle event and delegate to the base class."""
        ten_env.log_debug("on_init")
        await super().on_init(ten_env)

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load the OpenAI configuration and build the client if a key is set."""
        ten_env.log_debug("on_start")
        await super().on_start(ten_env)

        self.config = await OpenAIChatGPTConfig.create_async(ten_env=ten_env)

        # Without an API key no client is created; openai_chatgpt stays None.
        if not self.config.api_key:
            ten_env.log_info("API key is missing, exiting on_start")
            return

        self.openai_chatgpt = OpenAIChatGPT(ten_env, self.config)

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Log the lifecycle event and delegate to the base class."""
        ten_env.log_debug("on_stop")
        await super().on_stop(ten_env)

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        """Log the lifecycle event and delegate to the base class."""
        ten_env.log_debug("on_deinit")
        await super().on_deinit(ten_env)

    async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None:
        """Log the name of the incoming audio frame; the audio is not processed."""
        audio_frame_name = audio_frame.get_name()
        ten_env.log_debug(f"on_audio_frame name {audio_frame_name}")

    async def on_video_frame(self, ten_env: AsyncTenEnv, video_frame: VideoFrame) -> None:
        """Remember the latest video frame buffer together with its size."""
        video_frame_name = video_frame.get_name()
        ten_env.log_debug(f"on_video_frame name {video_frame_name}")

        self.image_data = video_frame.get_buf()
        self.image_width = video_frame.get_width()
        self.image_height = video_frame.get_height()

    def get_tool_metadata(self, _: AsyncTenEnv) -> list[LLMToolMetadata]:
        """Describe the single ``open_website`` tool and its two required parameters."""
        return [
            LLMToolMetadata(
                name=OPEN_WEBSITE_TOOL_NAME,
                description=OPEN_WEBSITE_TOOL_DESCRIPTION,
                parameters=[
                    LLMToolMetadataParameter(
                        name="name",
                        type="string",
                        description="The name of the website to open",
                        required=True,
                    ),
                    LLMToolMetadataParameter(
                        name="url",
                        type="string",
                        description="The url of the given website, get based on name",
                        required=True,
                    ),
                ],
            ),
        ]

    async def run_tool(self, ten_env: AsyncTenEnv, name: str, args: dict) -> LLMToolResult:
        """Dispatch a tool call by name.

        :param ten_env: Environment used for logging and sending data
        :param name: Name of the tool being invoked
        :param args: Tool arguments; ``name`` and ``url`` are read for open_website
        :return: An ``llmresult`` whose content is the JSON-encoded outcome, or
            ``None`` when ``name`` is not a tool handled by this extension
        """
        if name == OPEN_WEBSITE_TOOL_NAME:
            site_name = args.get("name")
            site_url = args.get("url")
            ten_env.log_info(f"open site {site_name} {site_url}")
            result = await self._open_website(site_name, site_url, ten_env)
            return LLMToolResultLLMResult(
                type="llmresult",
                content=json.dumps(result),
            )

        # Unknown tool name: same outcome as before (None), now explicit.
        return None

    async def _open_website(self, site_name: str, site_url: str, ten_env: AsyncTenEnv) -> Any:
        """Emit a ``browse_website`` action and report whether it was sent.

        :return: ``{"result": "success"}`` when the action was sent, otherwise
            ``{"result": "failed"}`` so the caller is not told a failed send
            succeeded
        """
        sent = await self._send_data(
            ten_env, "browse_website", {"name": site_name, "url": site_url}
        )
        return {"result": "success"} if sent else {"result": "failed"}

    async def _send_data(self, ten_env: AsyncTenEnv, action: str, data: Dict[str, Any]) -> bool:
        """Serialize an action payload and send it as an end-of-segment data message.

        Best-effort: any error is logged as a warning instead of being raised.

        :param ten_env: Environment used to send the data and log failures
        :param action: Action name placed in the payload
        :param data: Action-specific payload; must be JSON-serializable
        :return: ``True`` if the message was sent, ``False`` if an error occurred
        """
        try:
            action_data = json.dumps({
                "type": "action",
                "data": {
                    "action": action,
                    "data": data,
                },
            })

            output_data = Data.create(CONTENT_DATA_OUT_NAME)
            output_data.set_property_string(DATA_OUT_PROPERTY_TEXT, action_data)
            output_data.set_property_bool(DATA_OUT_PROPERTY_END_OF_SEGMENT, True)
            await ten_env.send_data(output_data)
        except Exception as err:
            ten_env.log_warn(f"send data error {err}")
            return False
        return True