| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import aiohttp |
| |
|
| | from typing import Any |
| | from dataclasses import dataclass |
| |
|
| | from ten import Cmd |
| |
|
| | from ten.async_ten_env import AsyncTenEnv |
| | from ten_ai_base.config import BaseConfig |
| | from ten_ai_base.types import LLMToolMetadata, LLMToolMetadataParameter, LLMToolResult, LLMToolResultLLMResult |
| | from ten_ai_base.llm_tool import AsyncLLMToolBaseExtension |
| |
|
# Command names exchanged with the tool-calling side.
CMD_TOOL_REGISTER, CMD_TOOL_CALL = "tool_register", "tool_call"
# Property keys carried on those commands.
CMD_PROPERTY_NAME, CMD_PROPERTY_ARGS = "name", "args"
| |
|
# Keys of a tool-registration payload.
TOOL_REGISTER_PROPERTY_NAME = "name"
TOOL_REGISTER_PROPERTY_DESCRIPTION = "description"
# Misspelled legacy name, kept as an alias so existing importers keep working.
TOOL_REGISTER_PROPERTY_DESCRIPTON = TOOL_REGISTER_PROPERTY_DESCRIPTION
TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters"
TOOL_CALLBACK = "callback"
| |
|
# Tool: current weather at a single location.
CURRENT_TOOL_NAME = "get_current_weather"
CURRENT_TOOL_DESCRIPTION = "Determine current weather in user's location."
# JSON-Schema-style description of the tool's arguments.
CURRENT_TOOL_PARAMETERS = dict(
    type="object",
    properties=dict(
        location=dict(
            type="string",
            description="The city and state (use only English) e.g. San Francisco, CA",
        ),
    ),
    required=["location"],
)
| |
|
| | |
# Tool: weather for a given day within the past week.
HISTORY_TOOL_NAME = "get_past_weather"
HISTORY_TOOL_DESCRIPTION = "Determine weather within past 7 days in user's location."
# JSON-Schema-style description of the tool's arguments; both are mandatory.
HISTORY_TOOL_PARAMETERS = dict(
    type="object",
    properties=dict(
        location=dict(
            type="string",
            description="The city and state (use only English) e.g. San Francisco, CA",
        ),
        datetime=dict(
            type="string",
            description="The datetime user is referring in date format e.g. 2024-10-09",
        ),
    ),
    required=["location", "datetime"],
)
| |
|
| | |
# Tool: forecast for the next few days at a single location.
FORECAST_TOOL_NAME = "get_future_weather"
FORECAST_TOOL_DESCRIPTION = "Determine weather in next 3 days in user's location."
# JSON-Schema-style description of the tool's arguments.
FORECAST_TOOL_PARAMETERS = dict(
    type="object",
    properties=dict(
        location=dict(
            type="string",
            description="The city and state (use only English) e.g. San Francisco, CA",
        ),
    ),
    required=["location"],
)

# Name of the extension property holding the weather API key.
PROPERTY_API_KEY = "api_key"
| |
|
| |
|
@dataclass
class WeatherToolConfig(BaseConfig):
    """Configuration for the weather tool extension.

    Attributes:
        api_key: Key sent to the weatherapi.com endpoints; defaults to an
            empty string, which the extension treats as "not configured".
    """

    api_key: str = ""
| |
|
| |
|
class WeatherToolExtension(AsyncLLMToolBaseExtension):
    """LLM tool extension answering weather questions via api.weatherapi.com.

    Exposes three tools: current weather, weather on a past day, and a 3-day
    forecast. The base-class start (and, presumably, tool registration) only
    runs when an ``api_key`` is configured.
    """

    # Shared prefix of every weatherapi.com endpoint used below.
    # NOTE(review): the API key travels in the query string over plain HTTP;
    # consider https once confirmed for the account's plan.
    _API_BASE_URL = "http://api.weatherapi.com/v1"

    def __init__(self, name: str) -> None:
        super().__init__(name)
        # HTTP session: created in on_init, closed in on_stop.
        self.session: aiohttp.ClientSession | None = None
        # Env captured in on_start so helper coroutines can log.
        self.ten_env: AsyncTenEnv | None = None
        # Loaded from the extension properties in on_start.
        self.config: WeatherToolConfig | None = None

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        ten_env.log_debug("on_init")
        self.session = aiohttp.ClientSession()

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load configuration and start the tool base only when a key exists."""
        ten_env.log_debug("on_start")
        # Capture the env before the base class starts, so helpers that log
        # through self.ten_env work as soon as a tool call can arrive.
        self.ten_env = ten_env

        self.config = await WeatherToolConfig.create_async(ten_env=ten_env)
        # Do not log the config repr: the dataclass repr includes api_key.
        if self.config.api_key:
            ten_env.log_info("config loaded, api_key is set")
            await super().on_start(ten_env)
        else:
            ten_env.log_error(
                "config loaded, api_key is empty; weather tools are not started"
            )

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        ten_env.log_debug("on_stop")

        if self.session:
            await self.session.close()
            self.session = None

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        ten_env.log_debug("on_deinit")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        """Log the incoming command name and delegate to the base class."""
        cmd_name = cmd.get_name()
        ten_env.log_debug("on_cmd name {}".format(cmd_name))

        await super().on_cmd(ten_env, cmd)

    def get_tool_metadata(self, ten_env: AsyncTenEnv) -> list[LLMToolMetadata]:
        """Describe the three weather tools and their arguments."""
        return [
            LLMToolMetadata(
                name=CURRENT_TOOL_NAME,
                description=CURRENT_TOOL_DESCRIPTION,
                parameters=[
                    LLMToolMetadataParameter(
                        name="location",
                        type="string",
                        description="The city and state (use only English) e.g. San Francisco, CA",
                        required=True,
                    ),
                ],
            ),
            LLMToolMetadata(
                name=HISTORY_TOOL_NAME,
                description=HISTORY_TOOL_DESCRIPTION,
                parameters=[
                    LLMToolMetadataParameter(
                        name="location",
                        type="string",
                        description="The city and state (use only English) e.g. San Francisco, CA",
                        required=True,
                    ),
                    LLMToolMetadataParameter(
                        name="datetime",
                        type="string",
                        description="The datetime user is referring in date format e.g. 2024-10-09",
                        required=True,
                    ),
                ],
            ),
            LLMToolMetadata(
                name=FORECAST_TOOL_NAME,
                description=FORECAST_TOOL_DESCRIPTION,
                parameters=[
                    LLMToolMetadataParameter(
                        name="location",
                        type="string",
                        description="The city and state (use only English) e.g. San Francisco, CA",
                        required=True,
                    ),
                ],
            ),
        ]

    async def run_tool(
        self, ten_env: AsyncTenEnv, name: str, args: dict
    ) -> LLMToolResult | None:
        """Dispatch a tool call by name and wrap its JSON-encoded result.

        Returns None for an unknown tool name. A helper that fails logs the
        error and yields None, which is sent to the LLM as the JSON ``null``.
        """
        ten_env.log_info(f"run_tool name: {name}, args: {args}")
        handlers = {
            CURRENT_TOOL_NAME: self._get_current_weather,
            HISTORY_TOOL_NAME: self._get_past_weather,
            FORECAST_TOOL_NAME: self._get_future_weather,
        }
        handler = handlers.get(name)
        if handler is None:
            ten_env.log_error(f"run_tool unknown tool name: {name}")
            return None

        result = await handler(args)
        return LLMToolResultLLMResult(
            type="llmresult",
            content=json.dumps(result),
        )

    async def _fetch_json(self, endpoint: str, params: dict[str, Any]) -> Any:
        """GET ``<base>/<endpoint>`` with URL-encoded query params and decode JSON.

        The values come from LLM-supplied arguments, so they are passed via
        ``params=`` instead of being interpolated into the URL. Otherwise an
        ``&``/``#``/``=`` in them would corrupt or inject query parameters.
        """
        query = {"key": self.config.api_key, **params}
        async with self.session.get(
            f"{self._API_BASE_URL}/{endpoint}", params=query
        ) as response:
            return await response.json()

    async def _get_current_weather(self, args: dict) -> Any:
        """Return a compact current-conditions dict, or None on request failure.

        Raises:
            ValueError: if ``location`` is missing from ``args``.
        """
        if "location" not in args:
            raise ValueError("Failed to get property")

        try:
            result = await self._fetch_json(
                "current.json", {"q": args["location"], "aqi": "no"}
            )
            current = result.get("current", {})
            return {
                "location": result.get("location", {}).get("name", ""),
                "temperature": current.get("temp_c", ""),
                "humidity": current.get("humidity", ""),
                "wind_speed": current.get("wind_kph", ""),
            }
        except Exception as e:
            self.ten_env.log_error(f"Failed to get current weather: {e}")
            return None

    async def _get_past_weather(self, args: dict) -> Any:
        """Return the history.json payload without hourly data, or None on failure.

        Raises:
            ValueError: if ``location`` or ``datetime`` is missing from ``args``.
        """
        if "location" not in args or "datetime" not in args:
            raise ValueError("Failed to get property")

        try:
            result = await self._fetch_json(
                "history.json", {"q": args["location"], "dt": args["datetime"]}
            )
            # Drop the per-hour breakdown to keep the payload handed to the
            # LLM small; the day-level summary remains.
            for day in result.get("forecast", {}).get("forecastday") or []:
                day.pop("hour", None)
            return result
        except Exception as e:
            self.ten_env.log_error(f"Failed to get past weather: {e}")
            return None

    async def _get_future_weather(self, args: dict) -> Any:
        """Return a 3-day forecast payload without hourly/current data, or None.

        Raises:
            ValueError: if ``location`` is missing from ``args``.
        """
        if "location" not in args:
            raise ValueError("Failed to get property")

        try:
            result = await self._fetch_json(
                "forecast.json",
                {"q": args["location"], "days": 3, "aqi": "no", "alerts": "no"},
            )
            self.ten_env.log_info(f"get result {result}")

            # Hourly entries and the "current" section are stripped to shrink
            # the payload to the per-day forecast.
            for day in result.get("forecast", {}).get("forecastday") or []:
                day.pop("hour", None)
            result.pop("current", None)
            return result
        except Exception as e:
            self.ten_env.log_error(f"Failed to get future weather: {e}")
            return None
| |
|