File size: 5,706 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import json
from typing import Any, Dict, List, Optional

from pydantic import Field

from ..core.logging import logger
from ..core.module import BaseModule
from ..core.registry import MODEL_REGISTRY, MODULE_REGISTRY
from ..models.model_configs import LLMConfig
from .operators import Operator, AnswerGenerate, QAScEnsemble
class ActionGraph(BaseModule):
    """
    Base class for a graph of operators that share one LLM.

    Subclasses attach ``Operator`` instances as extra (undeclared) attributes and
    implement ``execute`` / ``async_execute``. Operators are discovered through
    ``__pydantic_extra__`` when the graph is described or serialized.
    """

    name: str = Field(description="The name of the ActionGraph.")
    description: str = Field(description="The description of the ActionGraph.")
    llm_config: LLMConfig = Field(description="The config of LLM used to execute the ActionGraph.")

    def init_module(self):
        """Instantiate the LLM class registered for ``llm_config.llm_type`` and store it as ``self._llm``."""
        if self.llm_config:
            llm_cls = MODEL_REGISTRY.get_model(self.llm_config.llm_type)
            self._llm = llm_cls(config=self.llm_config)

    def execute(self, *args, **kwargs) -> dict:
        """Run the graph synchronously. Subclasses must override this."""
        raise NotImplementedError(f"The execute function for {type(self).__name__} is not implemented!")

    async def async_execute(self, *args, **kwargs) -> dict:
        """
        Run the graph asynchronously. Subclasses must override this.

        Declared as a coroutine so its signature matches the ``async def``
        overrides in subclasses (e.g. ``QAActionGraph.async_execute``).
        """
        raise NotImplementedError(f"The async_execute function for {type(self).__name__} is not implemented!")

    def _get_operators(self) -> Dict[str, Operator]:
        """Return the ``Operator`` instances stored as extra attributes, keyed by attribute name."""
        # __pydantic_extra__ is None when the model does not allow extra fields.
        extras = self.__pydantic_extra__ or {}
        return {name: value for name, value in extras.items() if isinstance(value, Operator)}

    def get_graph_info(self, **kwargs) -> dict:
        """
        Get the information of the action graph, including all operators from the instance.

        Returns:
            dict: the graph's ``class_name``, ``name`` and ``description``, plus an
                ``operators`` mapping from attribute name to each operator's
                ``class_name``, ``name``, ``description``, ``interface`` and ``prompt``.
                ``llm_config`` is NOT included (see ``get_config``).
        """
        operators = self._get_operators()
        return {
            "class_name": self.__class__.__name__,
            "name": self.name,
            "description": self.description,
            "operators": {
                operator_name: {
                    "class_name": operator.__class__.__name__,
                    "name": operator.name,
                    "description": operator.description,
                    "interface": operator.interface,
                    "prompt": operator.prompt,
                }
                for operator_name, operator in operators.items()
            },
        }

    @classmethod
    def load_module(cls, path: str, llm_config: Optional[LLMConfig] = None, **kwargs) -> Dict:
        """
        Load the ActionGraph data from a file.

        Files written by ``save_module`` contain only ``get_graph_info`` (no LLM
        configuration), so the caller must supply ``llm_config``.

        Args:
            path: file to read.
            llm_config: LLM configuration to inject into the loaded data.

        Returns:
            Dict: the loaded data with ``llm_config`` set to ``llm_config.to_dict()``.
        """
        assert llm_config is not None, "must provide `llm_config` when using `load_module` or `from_file` to load the ActionGraph from local storage"
        action_graph_data = super().load_module(path, **kwargs)
        action_graph_data["llm_config"] = llm_config.to_dict()
        return action_graph_data

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "ActionGraph":
        """
        Create an ActionGraph from a dictionary.

        If ``data`` has a ``class_name`` entry, the concrete class is looked up in
        ``MODULE_REGISTRY``. The optional ``operators`` entry is not passed to the
        constructor. Instead, each operator attribute of the new instance whose name
        appears in it is updated with ``set_operator``. ``data`` itself is not modified.
        """
        # Shallow copy so popping "operators" does not mutate the caller's dict.
        data = dict(data)
        class_name = data.get("class_name", None)
        module_cls = MODULE_REGISTRY.get_module(class_name) if class_name else cls
        operators_info = data.pop("operators", None)
        module = module_cls._create_instance(data)
        if operators_info:
            extras = module.__pydantic_extra__ or {}
            for extra_name, extra_value in extras.items():
                if isinstance(extra_value, Operator) and extra_name in operators_info:
                    extra_value.set_operator(operators_info[extra_name])
        return module

    def save_module(self, path: str, ignore: Optional[List[str]] = None, **kwargs):
        """
        Save the action graph (the output of ``get_graph_info``) to a JSON file.

        Args:
            path: destination file path.
            ignore: top-level keys to drop from the saved configuration.

        Returns:
            str: ``path``.
        """
        logger.info("Saving {} to {}", self.__class__.__name__, path)
        config = self.get_graph_info()
        for ignore_key in ignore or []:
            config.pop(ignore_key, None)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=4)
        return path

    def get_config(self) -> dict:
        """
        Get a dictionary containing all necessary configuration to recreate this action graph.

        Returns:
            dict: ``get_graph_info()`` plus ``llm_config`` serialized with ``to_dict()``,
                usable to initialize a new ActionGraph with the same properties.
        """
        config = self.get_graph_info()
        config["llm_config"] = self.llm_config.to_dict()
        return config
class QAActionGraph(ActionGraph):
    """
    Question-answering graph based on self-consistency.

    Three candidate answers are sampled with ``AnswerGenerate``. ``QAScEnsemble``
    then selects the final answer from them.
    """

    def __init__(self, llm_config: LLMConfig, **kwargs):
        """
        Args:
            llm_config: configuration of the LLM shared by both operators.
            **kwargs: optional ``name`` / ``description`` overrides; all remaining
                keyword arguments are forwarded to ``ActionGraph``.
        """
        name = kwargs.pop("name", "Simple QA Workflow")
        description = kwargs.pop(
            "description",
            "This is a simple QA workflow that use self-consistency to make predictions.",
        )
        super().__init__(name=name, description=description, llm_config=llm_config, **kwargs)
        # NOTE(review): self._llm is expected to be set by ActionGraph.init_module during
        # super().__init__ — assumes BaseModule invokes init_module; confirm.
        self.answer_generate = AnswerGenerate(self._llm)
        self.sc_ensemble = QAScEnsemble(self._llm)

    def execute(self, problem: str) -> dict:
        """
        Answer ``problem`` synchronously.

        Returns:
            dict: ``{"answer": ...}`` holding the ensemble's ``response``.
        """
        solutions = [self.answer_generate(input=problem)["answer"] for _ in range(3)]
        ensemble_result = self.sc_ensemble(solutions=solutions)
        return {"answer": ensemble_result["response"]}

    async def async_execute(self, problem: str) -> dict:
        """
        Asynchronous counterpart of ``execute``.

        Calling an operator returns its result dict directly (``execute`` subscripts
        it without awaiting), so that call cannot be awaited. Instead, the operators'
        ``async_execute`` coroutines are awaited.
        NOTE(review): assumes ``Operator`` defines ``async_execute`` — confirm in ``.operators``.

        Returns:
            dict: ``{"answer": ...}`` holding the ensemble's ``response``.
        """
        solutions = []
        for _ in range(3):
            response = await self.answer_generate.async_execute(input=problem)
            solutions.append(response["answer"])
        ensemble_result = await self.sc_ensemble.async_execute(solutions=solutions)
        return {"answer": ensemble_result["response"]}
|