"""Tool definitions and registry exposing Pydantic-modelled tools as OpenAI function metadata."""

import inspect
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from types import UnionType
from typing import Any, Dict, Union, get_args, get_origin

from omegaconf import OmegaConf
from pydantic import BaseModel, Field

from src.vectorstore import VectorStore

log = logging.getLogger(__name__)

# JSON-schema type names for the Python annotations we know how to map.
# Any annotation not listed here is advertised as "string".
_JSON_TYPE_BY_ANNOTATION: Dict[Any, str] = {
    int: "integer",
    bool: "boolean",
    float: "number",
}


def _unwrap_optional(annotation: Any) -> tuple[Any, bool]:
    """Strip ``None`` from a union annotation.

    ``typing.Union`` / ``typing.Optional`` and PEP 604 ``X | None`` are handled
    identically. Non-union annotations are returned unchanged.

    Returns:
        ``(base_type, nullable)`` where ``nullable`` is True when ``None`` was a
        member of the union.

    Raises:
        TypeError: if the union has more than one non-None member, or only ``None``.
    """
    if get_origin(annotation) not in (Union, UnionType):
        return annotation, False
    args = get_args(annotation)
    non_none = [arg for arg in args if arg is not type(None)]
    if len(non_none) > 1:
        raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
    if not non_none:
        raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
    return non_none[0], len(non_none) != len(args)


class ToolBase(BaseModel, ABC):
    """Base class for tools: Pydantic fields declare the arguments, ``invoke`` runs the tool."""

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict) -> Any:
        """Run the tool with the argument dictionary ``input``."""

    @classmethod
    def to_openai_tool(cls) -> Dict[str, Any]:
        """Build OpenAI-style function metadata from this Pydantic class.

        The function name is the class name (unchanged) and the description is the
        class docstring (empty if there is none). Each model field becomes one
        parameter property:

        - ``type`` comes from the field annotation (``Optional``/``| None`` stripped);
          ``int`` -> "integer", ``bool`` -> "boolean", ``float`` -> "number",
          everything else -> "string".
        - ``description`` is the field description, omitted when unset.
        - ``enum`` is added when the field default is a list.

        A field is listed in ``required`` when Pydantic requires it, or when its
        annotation is not nullable.

        Raises:
            TypeError: if a field annotation is a union of more than one non-None type.
        """
        properties: Dict[str, Dict[str, Any]] = {}
        required: list[str] = []

        for field_name, field_info in cls.model_fields.items():
            annotation, nullable = _unwrap_optional(field_info.annotation)

            prop: Dict[str, Any] = {
                "type": _JSON_TYPE_BY_ANNOTATION.get(annotation, "string"),
            }
            # JSON schema requires description to be a string, so omit it rather than send null.
            if field_info.description is not None:
                prop["description"] = field_info.description
            # A list-valued default is interpreted as the set of allowed values.
            if isinstance(field_info.default, list):
                prop["enum"] = field_info.default
            properties[field_name] = prop

            # Only nullable fields with a default are left out of "required".
            if field_info.is_required() or not nullable:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,
                # cleandoc tolerates a missing docstring (via "") and removes indentation.
                "description": inspect.cleandoc(cls.__doc__ or ""),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }


# Load the configuration file
# ===========================================================================
config_file = "config.yaml"
cfg = OmegaConf.load(config_file)

# Initialize VectorStore, tools and oitools
# ===========================================================================
vdb = VectorStore(**cfg.vdb)
# Registered tool classes, keyed by their OpenAI function name.
tools: Dict[str, type[ToolBase]] = {}
# OpenAI function metadata for every registered tool, in registration order.
oitools = []


def tool_register(cls: type[ToolBase]) -> type[ToolBase]:
    """Class decorator: add *cls* to ``oitools`` and ``tools``, then return it unchanged."""
    oaitool = cls.to_openai_tool()
    name = oaitool["function"]["name"]
    log.info("@tool_register: %s()", name)
    oitools.append(oaitool)
    tools[name] = cls
    # Returning cls keeps the decorated class name bound to the class, not None.
    return cls


@tool_register
class retrieve_aina_data(ToolBase):
    """Retrieves relevant information from Aina Challenge vectorstore, based on the user's query."""

    query: str = Field(description="The user's input or question, used to search in Aina Challenge vectorstore.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up ``input["query"]`` in the vector store and return the context found.

        Returns an error message string instead of raising when ``input`` is not a
        mapping or has no non-empty ``query``.
        """
        log.info("retrieve_aina_data.invoke() input: %s", input)
        # Check that the input is a dictionary-like mapping before reading from it.
        if not isinstance(input, Mapping):
            return "Invalid input: expected a dictionary of arguments."
        query = input.get("query", None)
        if not query:
            return "Missing required argument: query."
        return vdb.get_context(query)