|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Union, get_origin, get_args |
|
|
from pydantic import BaseModel, Field |
|
|
from types import UnionType |
|
|
import os |
|
|
import logging |
|
|
|
|
|
from src.vectorstore import VectorStore |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
# Python type -> JSON Schema "type" used in the OpenAI tool schema.
# Any annotation not listed here is described as "string".
_JSON_SCHEMA_TYPES = {int: "integer", bool: "boolean", float: "number"}


def _unwrap_optional(annotation):
    """Strip ``None`` from a union annotation.

    Accepts both ``typing.Optional[X]`` / ``typing.Union[X, None]`` and the
    PEP 604 form ``X | None``. A non-union annotation is returned unchanged.

    Returns:
        ``(inner_type, is_optional)`` where ``is_optional`` is True when the
        union contained ``None``.

    Raises:
        TypeError: if the union has more than one non-None member, or only None.
    """
    if get_origin(annotation) not in (Union, UnionType):
        return annotation, False
    members = get_args(annotation)
    non_none = [arg for arg in members if arg is not type(None)]
    if len(non_none) > 1:
        raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
    if not non_none:
        raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
    return non_none[0], len(non_none) != len(members)


class ToolBase(BaseModel, ABC):
    """Base class for tools exposed to an OpenAI-style function-calling model.

    Subclasses declare their arguments as pydantic fields (the field
    ``description`` becomes the parameter description) and use the class
    docstring as the tool description.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool.

        Args:
            input: Mapping of argument name to value — presumably the parsed
                arguments of a model tool call; confirm against the caller.

        Returns:
            The tool's result, as defined by the subclass.
        """

    @classmethod
    def to_openai_tool(cls) -> Dict:
        """Build an OpenAI "function" tool schema from this pydantic class.

        Mapping rules:
          * name        -> ``cls.__name__``
          * description -> stripped class docstring ("" if there is none)
          * each field  -> a property whose JSON type is "integer" for ``int``,
            "boolean" for ``bool``, "number" for ``float`` and "string" for
            anything else; ``X | None`` / ``Optional[X]`` map like ``X``.
          * a field whose default is a list is exposed as an ``enum`` of
            those values.

        A field is listed in ``required`` when pydantic considers it required,
        or when it is an enum field whose annotation does not allow ``None``.

        Raises:
            TypeError: if a field is annotated with a union of several
                non-None types, or with ``None`` alone.
        """
        properties: Dict[str, Dict] = {}
        required: list[str] = []

        for field_name, field_info in cls.model_fields.items():
            annotation, is_optional = _unwrap_optional(field_info.annotation)

            prop = {
                "type": _JSON_SCHEMA_TYPES.get(annotation, "string"),
                "description": field_info.description,
            }

            # Convention of this module: a list default enumerates allowed values.
            is_enum = isinstance(field_info.default, list)
            if is_enum:
                # Copy so later edits to the schema cannot mutate the field default.
                prop["enum"] = list(field_info.default)

            properties[field_name] = prop

            if field_info.is_required() or (is_enum and not is_optional):
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
|
|
|
|
|
# Registered tool classes (not instances), keyed by the function name in their
# OpenAI tool schema.
tools: Dict[str, type[ToolBase]] = {}

# OpenAI-format tool schemas, appended in registration order.
oitools: list[dict] = []

# Shared vector store used by retrieval tools. It is built at import time from
# environment variables; any unset variable is passed through as None.
# RETRIEVE_NUM_CONTEXTS defaults to 3 and must parse as an int.
vector_store = VectorStore(
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
    number_of_contexts=int(os.environ.get("RETRIEVE_NUM_CONTEXTS", 3)),
)
|
|
|
|
|
|
|
|
def tool_register(cls: type[ToolBase]) -> type[ToolBase]:
    """Decorator to register a tool class and convert it to OpenAI tool format.

    Appends the schema from ``cls.to_openai_tool()`` to ``oitools`` and maps
    the schema's function name to ``cls`` in ``tools``.

    Returns:
        ``cls`` unchanged, so the decorated module-level name stays bound to
        the class instead of being rebound to ``None``.
    """
    logger.info("@tool_register: %s()", cls.__name__)
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    return cls
|
|
|
|
|
|
|
|
@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""

    # NOTE: the class docstring above is sent to the model as the tool
    # description, and the Field description as the parameter description.
    query: str = Field(description="The user's input or question, used to search Wikipedia.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Return vector-store context for ``input["query"]``.

        Args:
            input: Tool-call arguments; only the ``"query"`` key is read.

        Returns:
            The result of ``vector_store.get_context(query)``, or an error
            message string when ``query`` is missing or empty.
        """
        logger.info("retrieve_wiki_data.invoke() input: %s", input)

        query = input.get("query")
        if not query:
            # Reported back as the tool result rather than raised, so the model
            # can see what went wrong.
            return "Missing required argument: query."

        return vector_store.get_context(query)
|
|
|
|
|
|
|
|
|