|
|
import yaml |
|
|
from pydantic import BaseModel, model_validator |
|
|
from typing import Optional, Type, Union, List, Any |
|
|
|
|
|
from .models.model_configs import LLMConfig |
|
|
from .core.registry import MODEL_REGISTRY |
|
|
|
|
|
|
|
|
class Config(BaseModel):
    """Top-level configuration, loaded from a YAML file or a plain dict.

    ``llm_config`` is required. Before field validation it is annotated with the
    ``class_name`` of the ``LLMConfig`` subclass that ``MODEL_REGISTRY`` maps to
    its ``llm_type``. Agents in a list-valued ``agents`` entry that do not declare
    their own ``llm_config`` inherit (a copy of) the top-level one. Any other
    top-level keys are kept as extra fields and exposed through :attr:`kwargs`.
    """

    # Raw LLM configuration; must contain ``llm_type`` (see ``process_llm_config``).
    llm_config: dict

    # Either a list of agent definitions or a string; the string form is not
    # interpreted here and is passed through as-is.
    agents: Optional[Union[str, List[dict]]] = []

    # "extra": "allow" keeps undeclared top-level keys (exposed via ``kwargs``);
    # protected_namespaces=() lifts pydantic's reserved "model_" field-name prefix.
    model_config = {"arbitrary_types_allowed": True, "extra": "allow", "protected_namespaces": ()}

    @classmethod
    def from_file(cls, path: str) -> "Config":
        """Load and validate a configuration from a YAML file.

        Args:
            path: Path to a UTF-8 encoded YAML file.

        Returns:
            The validated configuration instance.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            pydantic.ValidationError: If the content is not a valid configuration
                (including an empty file, which YAML parses to ``None``).
        """
        with open(path, mode="r", encoding="utf-8") as file:
            # safe_load only builds plain Python objects; do not switch to yaml.load.
            data = yaml.safe_load(file)
        return cls.model_validate(data)

    @property
    def kwargs(self) -> Optional[dict]:
        """Undeclared top-level keys captured because ``extra`` is ``"allow"``."""
        return self.model_extra

    @model_validator(mode="before")
    @classmethod
    def validate_config_data(cls, data: Any) -> Any:
        """Normalize raw input before pydantic validates the fields.

        Completes ``llm_config`` (via ``process_llm_config``) and propagates it to
        agents lacking their own (via ``process_agents_data``). Works on copies,
        so the caller's dict is not mutated.

        Raises:
            ValueError: If ``llm_config`` is missing, empty or invalid; pydantic
                surfaces this as a ``ValidationError``.
        """
        if not isinstance(data, dict):
            # Not a dict (e.g. None from an empty YAML file): leave it to pydantic,
            # which reports a proper ValidationError instead of an AttributeError.
            return data

        data = dict(data)  # shallow copy: do not mutate the caller's mapping

        llm_config_data = data.get("llm_config", None)
        if not llm_config_data:
            raise ValueError("config file must contain 'llm_config'")
        data["llm_config"] = cls.process_llm_config(data=llm_config_data)

        agents_data = data.get("agents", None)
        # Only a sequence of agent definitions can inherit llm_config; a string
        # value would otherwise be iterated character by character.
        if agents_data and isinstance(agents_data, (list, tuple)):
            data["agents"] = cls.process_agents_data(agents=agents_data, llm_config=data["llm_config"])

        return data

    @classmethod
    def process_llm_config(cls, data: dict) -> dict:
        """Return a copy of ``data`` whose ``class_name`` matches its ``llm_type``.

        Args:
            data: Raw ``llm_config`` mapping; must contain a non-empty ``llm_type``.

        Returns:
            A new dict equal to ``data`` with ``class_name`` set to the name of the
            ``LLMConfig`` class registered for ``llm_type``.

        Raises:
            ValueError: If ``data`` is not a dict, ``llm_type`` is missing, or an
                explicit ``class_name`` disagrees with the registered class.
                Errors from ``MODEL_REGISTRY.get_model_config`` (e.g. for an
                unknown ``llm_type``) propagate unchanged.
        """
        if not isinstance(data, dict):
            raise ValueError("`llm_config` must be a mapping, got {}".format(type(data).__name__))

        llm_type = data.get("llm_type", None)
        if not llm_type:
            raise ValueError("must specify `llm_type` in `llm_config`!")

        llm_config_cls: Type[LLMConfig] = MODEL_REGISTRY.get_model_config(llm_type)
        expected_name = llm_config_cls.__name__

        data = dict(data)  # do not mutate the caller's llm_config
        if "class_name" not in data:
            data["class_name"] = expected_name
        elif data["class_name"] != expected_name:
            # Raised rather than asserted so the check survives `python -O`.
            raise ValueError(
                "the 'class_name' specified in 'llm_config' ({}) doesn't match the LLMConfig class ({}) registered for {} model. You should either remove 'class_name' or set it to {}.".format(
                    data["class_name"], expected_name, llm_type, expected_name
                )
            )
        return data

    @classmethod
    def process_agents_data(cls, agents: List[dict], llm_config: Optional[dict] = None) -> List[dict]:
        """Return the agents with ``llm_config`` filled in where it is missing.

        Args:
            agents: Agent definitions. Entries that are not dicts are passed
                through untouched so pydantic's ``List[dict]`` validation can
                reject them with a clear error.
            llm_config: Configuration inherited by agents without their own. When
                None, no configuration is injected.

        Returns:
            A new list; agents that needed a config are new dicts holding a
            shallow copy of ``llm_config``, so agents do not alias one another.
        """
        if llm_config is None:
            return list(agents)
        return [
            {**agent, "llm_config": dict(llm_config)}
            if isinstance(agent, dict) and "llm_config" not in agent
            else agent
            for agent in agents
        ]
|
|
|
|
|
|