| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import importlib |
| import json |
| import traceback |
| from abc import ABC |
| from copy import deepcopy |
| from functools import partial |
|
|
| import pandas as pd |
|
|
| from agent.component import component_class |
| from agent.component.base import ComponentBase |
| from agent.settings import flow_logger, DEBUG |
|
|
|
|
class Canvas(ABC):
    """Executable graph of agent components described by a JSON DSL.

    The DSL is a JSON document holding the component graph together with
    the conversation state that is carried from one ``run`` call to the
    next.  Example::

        dsl = {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {},
                    },
                    "downstream": ["answer_0"],
                    "upstream": [],
                },
                "answer_0": {
                    "obj": {
                        "component_name": "Answer",
                        "params": {}
                    },
                    "downstream": ["retrieval_0"],
                    "upstream": ["begin", "generate_0"],
                },
                "retrieval_0": {
                    "obj": {
                        "component_name": "Retrieval",
                        "params": {}
                    },
                    "downstream": ["generate_0"],
                    "upstream": ["answer_0"],
                },
                "generate_0": {
                    "obj": {
                        "component_name": "Generate",
                        "params": {}
                    },
                    "downstream": ["answer_0"],
                    "upstream": ["retrieval_0"],
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [["begin"]],
            "answer": []
        }
    """

    def __init__(self, dsl: str, tenant_id=None):
        """Parse ``dsl`` (a JSON string) and instantiate its components.

        Args:
            dsl: JSON text of the canvas.  When falsy, a built-in DSL that
                contains only a ``Begin`` component is used instead.
            tenant_id: Opaque tenant identifier, returned unchanged by
                ``get_tenant_id``.
        """
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        # Initialised here for symmetry with reset(); load() overwrites it.
        self.reference = []
        self.components = {}
        # NOTE(review): the fallback DSL below has no "Answer" component, so
        # load() rejects it with its own assertion -- confirm whether callers
        # are expected to always pass a DSL.
        self.dsl = json.loads(dsl) if dsl else {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {
                            "prologue": "Hi there!"
                        }
                    },
                    "downstream": [],
                    "upstream": []
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [],
            "answer": []
        }
        self._tenant_id = tenant_id
        self._embed_id = ""
        self.load()

    def load(self):
        """Turn the parsed DSL into live component objects and restore state.

        Each ``components[id]["obj"]`` description is replaced in place by an
        instance built through ``component_class``.  The conversation state
        (``path``, ``history``, ``messages``, ``answer``, ``reference`` and
        the optional ``embed_id``) is then read back from the DSL.

        Raises:
            AssertionError: if the DSL has no ``Begin`` or no ``Answer``
                component.
        """
        self.components = self.dsl["components"]
        cpn_nms = {cpn["obj"]["component_name"] for cpn in self.components.values()}

        assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
        assert "Answer" in cpn_nms, "There have to be an 'Answer' component."

        for cpn_id, cpn in self.components.items():
            name = cpn["obj"]["component_name"]
            param = component_class(name + "Param")()
            param.update(cpn["obj"]["params"])
            param.check()
            cpn["obj"] = component_class(name)(self, cpn_id, param)
            if cpn["obj"].component_name == "Categorize":
                # Every category target must be reachable as a downstream node.
                for desc in param.category_description.values():
                    if desc["to"] not in cpn["downstream"]:
                        cpn["downstream"].append(desc["to"])

        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.messages = self.dsl["messages"]
        self.answer = self.dsl["answer"]
        self.reference = self.dsl["reference"]
        self._embed_id = self.dsl.get("embed_id", "")

    def __str__(self):
        """Serialise the canvas, including its current state, to JSON text.

        The live state is first written back into ``self.dsl``; component
        objects are serialised through their own ``str()`` form, which is
        parsed back as JSON before being embedded.
        """
        self.dsl["path"] = self.path
        self.dsl["history"] = self.history
        self.dsl["messages"] = self.messages
        self.dsl["answer"] = self.answer
        self.dsl["reference"] = self.reference
        self.dsl["embed_id"] = self._embed_id

        dsl = {"components": {}}
        for key, value in self.dsl.items():
            if key == "components":
                continue
            dsl[key] = deepcopy(value)

        for cpn_id, cpn in self.components.items():
            dsl["components"][cpn_id] = {
                field: json.loads(str(value)) if field == "obj" else deepcopy(value)
                for field, value in cpn.items()
            }
        return json.dumps(dsl, ensure_ascii=False)

    def reset(self):
        """Clear the conversation state and reset every component."""
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        self.reference = []
        for cpn in self.components.values():
            cpn["obj"].reset()
        self._embed_id = ""

    def run(self, **kwargs):
        """Advance the canvas until the next ``Answer`` component has run.

        If an ``Answer`` component is already queued in ``self.answer`` it is
        run immediately.  Otherwise the graph is walked from the last
        component of the previous round, running each downstream component
        and queueing any ``Answer`` component that is reached; the first
        queued ``Answer`` is then run.

        Args:
            **kwargs: Forwarded to every component's ``run``.  When
                ``stream`` is truthy the answer is asserted to be a
                ``functools.partial`` and returned without being recorded in
                the history.

        Returns:
            The output of the last component that ran (``""`` if none did).

        Raises:
            OverflowError: if ``_find_loop`` reports a repeating pattern.
            AssertionError: if a switch-like component names an unknown
                target, or a streamed answer is not a ``partial``.
        """
        ans = ""
        if self.answer:
            cpn_id = self.answer.pop(0)
            try:
                ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            except Exception as e:
                ans = ComponentBase.be_output(str(e))
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                # NOTE(review): after the except branch above `ans` comes
                # from be_output(); if that is not a partial this assertion
                # fails -- confirm the intended streaming error path.
                assert isinstance(ans, partial)
                return ans
            self.history.append(("assistant", ans.to_dict("records")))
            return ans

        if not self.path:
            self.components["begin"]["obj"].run(self.history, **kwargs)
            self.path.append(["begin"])

        # Open a new round; `ran` indexes the component in this round whose
        # downstream is expanded next.
        self.path.append([])
        ran = -1

        def prepare2run(cpns):
            nonlocal ran, ans
            for c in cpns:
                # Do not re-run the component that has just run.
                if self.path[-1] and c == self.path[-1][-1]:
                    continue
                cpn = self.components[c]["obj"]
                if cpn.component_name == "Answer":
                    self.answer.append(c)
                else:
                    if DEBUG:
                        print("RUN: ", c)
                    if cpn.component_name == "Generate":
                        cpids = cpn.get_dependent_components()
                        # Skip until all dependencies have run in this round.
                        if any(dep not in self.path[-1] for dep in cpids):
                            continue
                    ans = cpn.run(self.history, **kwargs)
                    self.path[-1].append(c)
            ran += 1

        def route_exception_to_answer(exc):
            # Hand the failure to the most recently visited answer component
            # and queue it so the error is what gets answered.
            visited = [c for segment in self.path for c in segment]
            for p in reversed(visited):
                if "answer" in p.lower():
                    self.get_component(p)["obj"].set_exception(exc)
                    prepare2run([p])
                    break
            traceback.print_exc()

        prepare2run(self.components[self.path[-2][-1]]["downstream"])
        while 0 <= ran < len(self.path[-1]):
            if DEBUG:
                print(ran, self.path)
            cpn_id = self.path[-1][ran]
            cpn = self.get_component(cpn_id)
            if not cpn["downstream"]:
                break

            loop = self._find_loop()
            if loop:
                raise OverflowError(f"Too much loops: {loop}")

            if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
                # Branching components name their single successor in the
                # first cell of their output frame.
                switch_out = cpn["obj"].output()[1].iloc[0, 0]
                assert switch_out in self.components, \
                    "{}'s output: {} not valid.".format(cpn_id, switch_out)
                targets = [switch_out]
            else:
                targets = cpn["downstream"]

            try:
                prepare2run(targets)
            except Exception as e:
                route_exception_to_answer(e)
                break

        if self.answer:
            cpn_id = self.answer.pop(0)
            ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                assert isinstance(ans, partial)
                return ans

            self.history.append(("assistant", ans.to_dict("records")))

        return ans

    def get_component(self, cpn_id):
        """Return the component entry (a dict with ``obj``, ``downstream``, ...) for ``cpn_id``."""
        return self.components[cpn_id]

    def get_tenant_id(self):
        """Return the tenant id given at construction time."""
        return self._tenant_id

    def get_history(self, window_size):
        """Return the last ``window_size + 1`` history entries as chat messages.

        User entries are passed through unchanged; for every other role the
        stored records are loaded into a DataFrame and their ``content``
        column is joined with newlines.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
        """
        convs = []
        for role, obj in self.history[(window_size + 1) * -1:]:
            if role == "user":
                content = obj
            else:
                content = '\n'.join(pd.DataFrame(obj)['content'])
            convs.append({"role": role, "content": content})
        return convs

    def add_user_input(self, question):
        """Append ``question`` to the history as a ``("user", question)`` entry."""
        self.history.append(("user", question))

    def set_embedding_model(self, embed_id):
        """Remember the embedding model id used by this canvas."""
        self._embed_id = embed_id

    def get_embedding_model(self):
        """Return the embedding model id, or ``""`` when none is set."""
        return self._embed_id

    def _find_loop(self, max_loops=6):
        """Detect a repeating run of components at the end of the current round.

        The current round is read newest-first and cut at the first id that
        contains ``"answer"``.  For each candidate length from 2 up to (but
        excluding) half of the remaining length, the newest ids of that
        length form a pattern; a loop is reported when the pattern repeats
        back-to-back more than ``max_loops`` times and ids still remain
        after those repeats.

        Args:
            max_loops: Number of consecutive repeats that is still tolerated.

        Returns:
            ``False`` when no loop is found, otherwise a description of the
            form ``"a => b => a => b"`` built from the ids with everything
            after the first ``":"`` removed.
        """
        path = self.path[-1][::-1]
        if len(path) < 2:
            return False

        # Only the components that ran after the latest answer are relevant.
        for i, cpn_id in enumerate(path):
            if "answer" in cpn_id.lower():
                path = path[:i]
                break

        if len(path) < 2:
            return False

        joined = ",".join(path)
        for size in range(2, len(path) // 2):
            pat = ",".join(path[:size])
            path_str = joined
            if len(pat) >= len(path_str):
                return False
            loop = max_loops
            while path_str.startswith(pat) and loop >= 0:
                loop -= 1
                if len(pat) + 1 >= len(path_str):
                    return False
                # Drop one occurrence of the pattern plus its separator.
                path_str = path_str[len(pat) + 1:]
            if loop < 0:
                pat = " => ".join([p.split(":")[0] for p in path[:size]])
                return pat + " => " + pat

        return False

    def get_prologue(self):
        """Return the ``prologue`` parameter of the ``begin`` component."""
        return self.components["begin"]["obj"]._param.prologue
|
|