File size: 2,407 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import yaml
from pydantic import BaseModel, model_validator
from typing import Optional, Type, Union, List, Any
# from .core.base_config import BaseConfig
from .models.model_configs import LLMConfig
from .core.registry import MODEL_REGISTRY


class Config(BaseModel):
    """Top-level config: one LLM config plus optional agent definitions.

    During validation, ``llm_config`` is normalised by :meth:`process_llm_config`
    (its ``class_name`` is checked against, or filled in from, ``MODEL_REGISTRY``),
    and every agent dict that has no ``llm_config`` of its own inherits a copy of
    the top-level one. Unknown top-level keys are kept (``extra="allow"``) and
    are exposed through :attr:`kwargs`.
    """

    llm_config: dict
    # Either a list of agent dicts or a string.
    # NOTE(review): what the string form refers to is not visible here — confirm.
    agents: Optional[Union[str, List[dict]]] = []
    # ``protected_namespaces=()`` silences pydantic's warning for field names
    # starting with ``model_`` (extra keys are allowed, so such names may appear).
    model_config = {"arbitrary_types_allowed": True, "extra": "allow", "protected_namespaces": ()}

    @classmethod
    def from_file(cls, path: str):
        """Load a YAML file at ``path`` and validate it into a ``Config``.

        ``yaml.safe_load`` is used, so the file cannot instantiate arbitrary
        Python objects.
        """
        with open(path, mode="r", encoding="utf-8") as file:
            data = yaml.safe_load(file.read())
        config = cls.model_validate(data)
        return config

    @property
    def kwargs(self) -> Optional[dict]:
        """Extra top-level keys that are not declared fields (pydantic ``model_extra``)."""
        return self.model_extra

    @model_validator(mode="before")
    @classmethod
    def validate_config_data(cls, data: Any) -> Any:
        """Normalise raw input before field validation.

        Checks and completes ``llm_config`` and propagates it to agents that
        lack their own. The caller's input is never mutated.

        Raises:
            ValueError: if ``llm_config`` is missing/empty, or
                :meth:`process_llm_config` rejects it.
        """
        # A "before" validator sees raw input. Anything that is not a dict
        # (e.g. None from an empty YAML file) is handed to pydantic, which
        # rejects it with a ValidationError instead of an AttributeError here.
        if not isinstance(data, dict):
            return data
        # Shallow-copy so validating a dict does not rewrite the caller's data.
        data = dict(data)

        # process llm config
        llm_config_data = data.get("llm_config", None)
        if not llm_config_data:
            raise ValueError("config file must contain 'llm_config'")
        if not isinstance(llm_config_data, dict):
            # Let field validation report the wrong type for ``llm_config``.
            return data
        data["llm_config"] = cls.process_llm_config(data=dict(llm_config_data))

        # process agent data -- only a list of agents can inherit llm_config;
        # the string form of ``agents`` is passed through untouched.
        agents_data = data.get("agents", None)
        if agents_data and isinstance(agents_data, list):
            agents_copy = [dict(agent) if isinstance(agent, dict) else agent for agent in agents_data]
            data["agents"] = cls.process_agents_data(agents=agents_copy, llm_config=data["llm_config"])

        return data

    @classmethod
    def process_llm_config(cls, data: dict) -> dict:
        """Validate ``llm_type`` and reconcile ``class_name`` with the registry.

        Looks up the LLMConfig class registered for ``data["llm_type"]`` via
        ``MODEL_REGISTRY.get_model_config``. If ``class_name`` is present it
        must equal that class's name; otherwise it is filled in.
        ``data`` is updated in place and returned.

        Raises:
            ValueError: if ``llm_type`` is missing/empty, or ``class_name``
                disagrees with the registered class.
        """
        llm_type = data.get("llm_type", None)
        if not llm_type:
            raise ValueError("must specify `llm_type` in `llm_config`!")
        # NOTE(review): behaviour for an unregistered llm_type is defined by
        # MODEL_REGISTRY, not visible here.
        llm_config_cls: Type[LLMConfig] = MODEL_REGISTRY.get_model_config(llm_type)
        expected_name = llm_config_cls.__name__
        if "class_name" not in data:
            data["class_name"] = expected_name
        elif data["class_name"] != expected_name:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise ValueError(
                "the 'class_name' specified in 'llm_config' ({}) doesn't match the LLMConfig class ({}) registered for {} model. You should either remove 'class_name' or set it to {}.".format(
                    data["class_name"], expected_name, llm_type, expected_name
                )
            )

        return data

    @classmethod
    def process_agents_data(cls, agents: List[dict], llm_config: Optional[dict] = None) -> List[dict]:
        """Give every agent without its own ``llm_config`` the shared one.

        ``agents`` is updated in place and the same list is returned. Each
        agent receives its own shallow copy of ``llm_config``, so agents do not
        alias one dict. Non-dict entries are skipped, leaving them for pydantic
        to reject. When ``llm_config`` is None nothing is injected.
        """
        if llm_config is None:
            return agents
        for agent in agents:
            if isinstance(agent, dict) and "llm_config" not in agent:
                agent["llm_config"] = dict(llm_config)
        return agents