File size: 4,500 Bytes
570b60c
 
 
 
 
 
5b68ef9
 
570b60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b68ef9
570b60c
 
5b68ef9
570b60c
 
 
5b68ef9
570b60c
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import logging
log = logging.getLogger(__name__)

from src.vectorstore import VectorStore
from omegaconf import OmegaConf


class ToolBase(BaseModel, ABC):
    """Base class for tools that are advertised to an OpenAI-style chat model.

    A concrete tool is a Pydantic model: its fields describe the tool's
    arguments, its class docstring is used as the tool description, and its
    ``invoke`` implementation performs the actual work.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool.

        Declared as a classmethod because the first parameter is ``cls`` and
        tools are invoked on the class itself, not on an instance.

        Args:
            input: Mapping of argument names to the values supplied for the call.
        """
        pass

    @classmethod
    def to_openai_tool(cls):
        """Build the OpenAI function-calling description of this tool.

        The function name is the class name (used verbatim), the description
        is the class docstring, and every entry of ``cls.model_fields`` becomes
        one property of the ``parameters`` object.

        Type mapping: ``int`` -> ``"integer"``, ``bool`` -> ``"boolean"``,
        ``float`` -> ``"number"``, anything else -> ``"string"``.

        A field whose default is a list is treated as an enum: the list is
        emitted as ``enum`` and the field is marked required unless its
        annotation allows ``None``.

        Returns:
            dict: ``{"type": "function", "function": {"name", "description",
            "parameters"}}``.

        Raises:
            TypeError: If a field is annotated with an ``X | Y`` union that has
                more than one non-``None`` member, or no non-``None`` member.
        """
        none_type = type(None)
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            nullable = False

            origin = get_origin(annotation)
            if origin is Union:
                # typing.Optional[X] / typing.Union[...]: use the first member
                # that is not NoneType (previously ``__args__[0]`` was taken
                # blindly, which picked NoneType for ``Union[None, X]``).
                # Several non-None members keep the historical behaviour of
                # silently using the first one.
                members = get_args(annotation)
                nullable = none_type in members
                concrete = [arg for arg in members if arg is not none_type]
                if concrete:
                    annotation = concrete[0]
            elif origin is UnionType:
                # PEP 604 ``X | None``: exactly one non-None member is allowed.
                members = get_args(annotation)
                nullable = none_type in members
                concrete = [arg for arg in members if arg is not none_type]
                if len(concrete) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                elif len(concrete) == 0:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = concrete[0]

            # Map the Python type to its JSON Schema name; default to string.
            if annotation is int:
                field_type = "integer"
            elif annotation is bool:
                field_type = "boolean"
            elif annotation is float:
                field_type = "number"
            else:
                field_type = "string"

            properties[field_name] = {
                "type": field_type,
                "description": field_info.description,
            }

            # Fields without a default must be supplied by the model.
            if field_info.is_required():
                required.append(field_name)

            # A list-valued default is interpreted as the set of allowed values.
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                properties[field_name]["enum"] = default
                if not nullable and field_name not in required:
                    required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,  # Function name is the class name, unchanged
                # A tool class without a docstring has ``__doc__ is None``.
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }



# Load the configuration file
# ===========================================================================
# The path is relative, so it is resolved against the current working
# directory of the process at import time.
config_file = "config.yaml"
cfg = OmegaConf.load(config_file)


# Initialize VectorStore, tools and oitools
# ===========================================================================
# The vector store is constructed once, at import time, from the keyword
# arguments found under the ``vdb`` key of the loaded configuration.
vdb = VectorStore(**cfg.vdb)
# Registry filled by ``tool_register``: function name -> tool class.
tools: Dict[str, type[ToolBase]] = {}
# Registry filled by ``tool_register``: one ``to_openai_tool()`` dict per tool.
oitools = []



def tool_register(cls: BaseModel):
    """Class decorator that registers a tool and hands the class back.

    The tool's OpenAI description (``cls.to_openai_tool()``) is appended to
    ``oitools`` and the class itself is stored in ``tools`` under the
    function name from that description.

    Args:
        cls: The tool class being decorated; it must provide ``to_openai_tool``.

    Returns:
        The same class, unchanged.
    """
    oaitool = cls.to_openai_tool()

    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls

    # A decorator's return value replaces the decorated object. Without this
    # return, the decorated class name would be rebound to None.
    return cls


@tool_register
class retrieve_aina_data(ToolBase):
    """Retrieves relevant information from Aina Challenge vectorstore, based on the user's query."""
    # Emitted while the class body executes (import time), i.e. just before
    # the decorator receives the finished class.
    log.info("@tool_register: retrieve_aina_data()")

    # NOTE: no logging of ``query`` here. In the class body the name is bound
    # to the ``Field(...)`` object, not to a user query, so logging it only
    # produced a misleading "query: ..." line at import time.
    query: str = Field(description="The user's input or question, used to search in Aina Challenge vectorstore.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for the given query in the vector store.

        Args:
            input: Tool-call arguments; the ``"query"`` entry is the search text.
                A ``None`` or empty payload is treated as having no query.

        Returns:
            The result of ``vdb.get_context(query)``, or the message
            ``"Missing required argument: query."`` when no non-empty query
            was supplied.
        """
        log.info("retrieve_aina_data.invoke() input: %s", input)

        # ``input or {}`` tolerates a missing payload instead of raising
        # AttributeError on ``None.get``; any mapping type keeps working.
        query = (input or {}).get("query")
        if not query:
            return "Missing required argument: query."

        return vdb.get_context(query)