File size: 4,658 Bytes
d519be4 cb3dcae d519be4 cb3dcae d519be4 cb3dcae d519be4 7e44889 cb3dcae d519be4 3f5217a d519be4 cb3dcae 7e44889 cb3dcae 7e44889 cb3dcae aa6ef3d 4701923 d519be4 7e44889 d519be4 cb34a9e 7e44889 cb3dcae d519be4 7e44889 d519be4 cb3dcae 7e44889 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import os
import logging
from src.vectorstore import VectorStore
# Logger named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
def _split_optional(annotation):
    """Strip ``None`` from a union annotation.

    ``typing.Union[X, None]`` / ``typing.Optional[X]`` and the PEP 604 form
    ``X | None`` are treated identically.

    Args:
        annotation: A field annotation as stored on a pydantic ``FieldInfo``.

    Returns:
        A ``(base_annotation, nullable)`` tuple. Non-union annotations are
        returned unchanged with ``nullable=False``.

    Raises:
        TypeError: If the union has more than one non-``None`` member, or
            contains only ``None``.
    """
    if get_origin(annotation) not in (Union, UnionType):
        return annotation, False
    args = get_args(annotation)
    nullable = type(None) in args
    members = [arg for arg in args if arg is not type(None)]
    if len(members) > 1:
        raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
    if not members:
        raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
    return members[0], nullable


def _json_schema_type(annotation) -> str:
    """Map a (non-union) Python annotation to a JSON Schema type name.

    Anything not recognised falls back to ``"string"``.
    """
    if annotation is bool:
        return "boolean"
    if annotation is int:
        return "integer"
    if annotation is float:
        return "number"
    return "string"


class ToolBase(BaseModel, ABC):
    """Base class for tools exposed to an OpenAI-style function-calling model.

    Subclasses declare their arguments as pydantic fields (with a
    ``description``), use their class docstring as the tool description, and
    implement :meth:`invoke`.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool with the model-supplied arguments in ``input``."""

    @classmethod
    def to_openai_tool(cls):
        """Build an OpenAI function-tool description from this class.

        The function name is the class name verbatim and the description is
        the stripped class docstring (empty if the class has none). Each
        pydantic field becomes a parameter whose type is derived from its
        annotation (``X | None`` / ``Optional[X]`` are unwrapped to ``X``).

        A field is listed as required when pydantic considers it required,
        or when its annotation does not admit ``None``. A list-valued field
        default is emitted as that parameter's ``enum``.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``.

        Raises:
            TypeError: If a field is a union of several non-``None`` types,
                or of ``None`` only.
        """
        properties = {}
        required = []
        for field_name, field_info in cls.model_fields.items():
            annotation, nullable = _split_optional(field_info.annotation)
            prop = {
                "type": _json_schema_type(annotation),
                "description": field_info.description,
            }
            # A list default is advertised as the set of allowed values.
            if isinstance(field_info.default, list):
                prop["enum"] = field_info.default
            properties[field_name] = prop
            # Only nullable fields that also have a default are optional.
            if field_info.is_required() or not nullable:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,  # class name used as-is (not lowercased)
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
# Registered tool classes (not instances), keyed by their OpenAI function
# name; populated by @tool_register.
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-format tool descriptions, in registration order; populated by
# @tool_register.
oitools: list[dict] = []
# Vector store queried by retrieve_wiki_data, configured from the environment.
# NOTE(review): unset variables are passed through as None — confirm that
# VectorStore handles (or rejects) that.
vector_store = VectorStore(
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
    number_of_contexts=int(os.environ.get("RETRIEVE_NUM_CONTEXTS", 3)),
)
def tool_register(cls: type[ToolBase]) -> type[ToolBase]:
    """Class decorator: register a tool and publish its OpenAI description.

    Appends ``cls.to_openai_tool()`` to ``oitools`` and stores ``cls`` in
    ``tools`` under the description's function name.

    Args:
        cls: The ToolBase subclass being decorated.

    Returns:
        ``cls`` itself, so the decorated name stays bound to the class.
    """
    logger.info("@tool_register: %s()", cls.__name__)
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    # A decorator that returns None would rebind the class name to None.
    return cls
@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""

    # NOTE: the class docstring and the field description below are sent to
    # the model as the tool/parameter descriptions — edit them deliberately.
    query: str = Field(description="The user's input or question, used to search Wikipedia.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for ``input["query"]`` in the module's vector store.

        Args:
            input: Tool arguments supplied by the model; expected to be a
                mapping with a ``"query"`` entry.

        Returns:
            The result of ``vector_store.get_context(query)``. If the
            arguments are not a mapping, or the query is missing or blank,
            an explanatory message for the model is returned instead.
        """
        from collections.abc import Mapping

        logger.info("retrieve_wiki_data.invoke() input: %s", input)
        # Malformed tool-call arguments (e.g. a bare string or list) would
        # otherwise crash on .get().
        if not isinstance(input, Mapping):
            return "Invalid input: expected an object with a 'query' field."
        query = input.get("query", None)
        if isinstance(query, str):
            query = query.strip()
        if not query:
            return "Missing required argument: query."
        return vector_store.get_context(query)
|