|
|
import json |
|
|
from pydantic import Field |
|
|
from typing import Dict, Any, List |
|
|
|
|
|
from ..core.logging import logger |
|
|
from ..core.module import BaseModule |
|
|
from ..core.registry import MODEL_REGISTRY, MODULE_REGISTRY |
|
|
from ..models.model_configs import LLMConfig |
|
|
from .operators import Operator, AnswerGenerate, QAScEnsemble |
|
|
|
|
|
|
|
|
class ActionGraph(BaseModule):
    """Base class for LLM-backed action graphs.

    An action graph carries a name, a description and an LLM configuration.
    Its operators are discovered at runtime by scanning the instance's
    pydantic extra attributes for ``Operator`` instances (see
    ``get_graph_info``). Subclasses are expected to override ``execute`` and
    ``async_execute``; the implementations here only raise
    ``NotImplementedError``.
    """

    # Human-readable identifier of the graph.
    name: str = Field(description="The name of the ActionGraph.")
    # Free-text summary of what the graph does.
    description: str = Field(description="The description of the ActionGraph.")
    # Configuration used by ``init_module`` to build the LLM instance.
    llm_config: LLMConfig = Field(description="The config of LLM used to execute the ActionGraph.")

    def init_module(self):
        """Build the LLM described by ``llm_config`` and store it on ``self._llm``.

        Nothing is created when ``llm_config`` is falsy.
        NOTE(review): presumably invoked by ``BaseModule`` during construction
        -- confirm against the base class.
        """
        if self.llm_config:
            llm_cls = MODEL_REGISTRY.get_model(self.llm_config.llm_type)
            self._llm = llm_cls(config=self.llm_config)

    def execute(self, *args, **kwargs) -> dict:
        """Run the graph synchronously; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"The execute function for {type(self).__name__} is not implemented!")

    def async_execute(self, *args, **kwargs) -> dict:
        """Run the graph asynchronously; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"The async_execute function for {type(self).__name__} is not implemented!")

    def get_graph_info(self, **kwargs) -> dict:
        """Describe the graph and every operator attached to this instance.

        Returns:
            dict: ``class_name``, ``name`` and ``description`` of the graph,
            plus an ``operators`` mapping from attribute name to that
            operator's ``class_name``, ``name``, ``description``,
            ``interface`` and ``prompt``. The LLM config is not included.
        """
        # `__pydantic_extra__` is None when the model does not allow extra
        # fields; treat that as "no operators" instead of crashing.
        extras = self.__pydantic_extra__ or {}
        operators = {
            extra_name: extra_value
            for extra_name, extra_value in extras.items()
            if isinstance(extra_value, Operator)
        }

        config = {
            "class_name": self.__class__.__name__,
            "name": self.name,
            "description": self.description,
            "operators": {
                operator_name: {
                    "class_name": operator.__class__.__name__,
                    "name": operator.name,
                    "description": operator.description,
                    "interface": operator.interface,
                    "prompt": operator.prompt
                }
                for operator_name, operator in operators.items()
            }
        }
        return config

    @classmethod
    def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> Dict:
        """Load the serialized ActionGraph data from a file.

        Args:
            path: Location of the saved module file, forwarded to the parent
                ``load_module``.
            llm_config: LLM configuration to inject into the loaded data;
                required, because ``save_module`` does not persist it.
            **kwargs: Forwarded to the parent ``load_module``.

        Returns:
            Dict: The loaded data with ``llm_config`` set to
            ``llm_config.to_dict()``.

        Raises:
            AssertionError: If ``llm_config`` is not provided.
        """
        assert llm_config is not None, "must provide `llm_config` when using `load_module` or `from_file` to load the ActionGraph from local storage"
        action_graph_data = super().load_module(path, **kwargs)
        action_graph_data["llm_config"] = llm_config.to_dict()
        return action_graph_data

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "ActionGraph":
        """Create an ActionGraph from a dictionary.

        If ``data`` contains ``class_name``, the concrete class is looked up
        in ``MODULE_REGISTRY``; otherwise ``cls`` is used. After the instance
        is created, each attached operator whose attribute name appears under
        ``data["operators"]`` is updated through ``Operator.set_operator``.

        Args:
            data: Configuration dictionary. It is not modified.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ActionGraph: The constructed instance.
        """
        # Work on a shallow copy so that removing "operators" below does not
        # alter the dictionary owned by the caller.
        data = dict(data)
        class_name = data.get("class_name", None)
        if class_name:
            cls = MODULE_REGISTRY.get_module(class_name)
        operators_info = data.pop("operators", None)
        module = cls._create_instance(data)
        if operators_info:
            # `__pydantic_extra__` may be None when extras are not allowed.
            for extra_name, extra_value in (module.__pydantic_extra__ or {}).items():
                if isinstance(extra_value, Operator) and extra_name in operators_info:
                    extra_value.set_operator(operators_info[extra_name])
        return module

    def save_module(self, path: str, ignore: List[str] = None, **kwargs):
        """Save the action graph description to a JSON file.

        The content written is the result of ``get_graph_info`` (so the LLM
        config is not saved) with the top-level keys listed in ``ignore``
        removed.

        Args:
            path: Destination file path; the file is overwritten.
            ignore: Top-level keys to drop before writing. ``None`` (the
                default) drops nothing.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            str: The ``path`` that was written.
        """
        logger.info("Saving {} to {}", self.__class__.__name__, path)
        config = self.get_graph_info()
        # `ignore or ()` avoids a shared mutable default argument.
        for ignore_key in ignore or ():
            config.pop(ignore_key, None)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=4)

        return path

    def get_config(self) -> dict:
        """
        Get a dictionary containing all necessary configuration to recreate this action graph.

        Returns:
            dict: The output of ``get_graph_info`` extended with an
            ``llm_config`` entry holding ``self.llm_config.to_dict()``.
        """
        config = self.get_graph_info()
        config["llm_config"] = self.llm_config.to_dict()
        return config
|
|
|
|
|
class QAActionGraph(ActionGraph):
    """Question-answering graph: sample three answers, then ensemble them.

    Uses an ``AnswerGenerate`` operator to produce candidate answers and a
    ``QAScEnsemble`` operator to pick the final one.
    """

    def __init__(self, llm_config: LLMConfig, **kwargs):
        """Create the graph and attach its two operators.

        Args:
            llm_config: LLM configuration passed to the parent constructor.
            **kwargs: Optional ``name`` / ``description`` overrides; anything
                else is forwarded to the parent constructor.
        """
        # Remove name/description from kwargs so they are not passed twice.
        graph_name = kwargs.pop("name", "Simple QA Workflow")
        graph_description = kwargs.pop(
            "description",
            "This is a simple QA workflow that use self-consistency to make predictions.",
        )
        super().__init__(
            name=graph_name,
            description=graph_description,
            llm_config=llm_config,
            **kwargs,
        )
        # Both operators share the LLM stored on `self._llm`.
        self.answer_generate = AnswerGenerate(self._llm)
        self.sc_ensemble = QAScEnsemble(self._llm)

    def execute(self, problem: str) -> dict:
        """Answer ``problem`` synchronously.

        Calls the answer generator three times, one call after another, and
        feeds the collected ``"answer"`` values to the ensemble operator.

        Returns:
            dict: ``{"answer": ...}`` holding the ensemble's ``"response"``.
        """
        candidates = [
            self.answer_generate(input=problem)["answer"]
            for _ in range(3)
        ]
        verdict = self.sc_ensemble(solutions=candidates)
        return {"answer": verdict["response"]}

    async def async_execute(self, problem: str) -> dict:
        """Answer ``problem`` from a coroutine; same flow as ``execute``.

        The three generation calls are awaited one after another.
        NOTE(review): this awaits the value returned by calling each operator,
        so it assumes that call yields an awaitable -- confirm against
        ``Operator``.

        Returns:
            dict: ``{"answer": ...}`` holding the ensemble's ``"response"``.
        """
        candidates = [
            (await self.answer_generate(input=problem))["answer"]
            for _ in range(3)
        ]
        verdict = await self.sc_ensemble(solutions=candidates)
        return {"answer": verdict["response"]}
|
|
|