Spaces:
Sleeping
Sleeping
File size: 4,004 Bytes
69601d4 1b179fe 69601d4 d7fd1be 69601d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
from ..components.failsafeAgent import FailsafeCodeGenerator
from ..components.queryRephraserAgent import QueryRephaser
from ..components.codeGeneratorAgent import CodeGenerator
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from ..components import replManager
import json
# Runnable chains are built once at import time; the workflow node methods below
# reference these module-level objects directly.
# Failsafe chain: invoked with the query, metadata, failing code and its error
# output to produce replacement code.
failsafeCodeGeneratorChain = FailsafeCodeGenerator().getFailsafeCodeGeneratorChain()
# Rephraser chain: invoked with the raw query and metadata; its result is stored
# as state["rephrasedQuery"].
queryRephraseChain = QueryRephaser().getQueryRephraserChain()
# Code-generator chain: invoked with the rephrased query and metadata; its result
# is expected to contain a fenced code block.
codeGeneratorChain = CodeGenerator().getCodeGeneratorChain()
class State(TypedDict):
    """Graph state shared by the ReportingToolWorkflow nodes (used as the StateGraph schema)."""

    # Identifier substituted as the first argument of generated fetch_data(...) calls.
    projectId: str
    # The user's original query, as passed to the rephraser chain.
    inputQuery: str
    # Context string forwarded to the rephraser, code-generator and failsafe chains.
    metadata: str
    # Rephraser chain output. It is subscripted with ["doubt"], so it is a mapping,
    # not a str. NOTE(review): exact schema comes from QueryRephaser — confirm.
    rephrasedQuery: dict
    # Chain response expected to contain a fenced code block (```...```).
    generatedCode: str
    # Whatever replManager.run returned; treated as success only if it parses as JSON.
    codeOutput: str
    # Parsed sandbox output, an {"error": ...} payload, or {"response": <doubt>}.
    # NOTE(review): json.loads may also yield a list/scalar here despite `dict`.
    finalOutput: dict
class ReportingToolWorkflow:
    """Build the LangGraph workflow that answers a reporting query with generated code.

    Node sequence wired in ``createWorkflow``:

    1. ``rephraseQuery`` - rewrite ``inputQuery`` using ``metadata``. If the
       rephraser returns a non-``None`` ``doubt``, jump straight to
       ``formatJsonResponse`` so the doubt is relayed to the caller.
    2. ``generateCode`` - ask the code-generator chain for code and pin every
       ``fetch_data(`` call to the request's ``projectId``.
    3. ``runInPythonSandbox`` - execute the fenced code via ``replManager``.
    4. If that output is not valid JSON, ``failsafe`` asks the failsafe chain
       to repair the code and ``failsafePythonSandbox`` runs it once more
       (there is no further retry).
    5. ``formatJsonResponse`` - parse the sandbox output into ``finalOutput``.

    Node methods take the current ``State`` and return a partial state update.
    Router methods return the label of the next edge.
    """

    @staticmethod
    def _pin_project_id(code, project_id):
        """Return ``code`` with ``project_id`` inserted as the first argument of each ``fetch_data(`` call.

        The id is rendered with ``json.dumps`` so quotes, backslashes or newlines
        in it are escaped instead of breaking out of (or injecting into) the
        generated source. For ordinary ids the result is byte-identical to a
        plain ``fetch_data("<id>", `` prefix. ``str()`` keeps non-string ids
        quoted, exactly as the f-string interpolation did.
        """
        call_prefix = f"fetch_data({json.dumps(str(project_id), ensure_ascii=False)}, "
        return call_prefix.join(code.split("fetch_data("))

    @staticmethod
    def _extract_code(text):
        """Return the code inside the last ``` fence of ``text``, minus its language-tag line.

        If ``text`` contains no fence at all, it is returned unchanged (treated as
        bare code) instead of raising ``IndexError``. The sandbox then reports
        any problem as ordinary output.
        """
        segments = text.split("```")
        if len(segments) < 2:
            return text
        # segments[-2] is the text before the last fence marker, i.e. the body of
        # the last fenced block. Its first line is the fence tag (e.g. "python").
        return "\n".join(segments[-2].split("\n")[1:])

    def rephraseQuery(self, state: State):
        """Run the rephraser chain on the raw query; store its result as ``rephrasedQuery``."""
        response = queryRephraseChain.invoke({
            "query": state["inputQuery"],
            "metadata": state["metadata"],
        })
        return {"rephrasedQuery": response}

    def generateCode(self, state: State):
        """Generate code for the rephrased query and post-process it into ``generatedCode``."""
        response = codeGeneratorChain.invoke({
            "query": state["rephrasedQuery"],
            "metadata": state["metadata"],
        })
        code = self._pin_project_id(response, state["projectId"])
        # Every literal "indent=4" in the response becomes "default=serializer".
        # NOTE(review): presumably turns json.dumps(..., indent=4) into a call
        # using a `serializer` helper defined in the sandbox — confirm.
        return {"generatedCode": code.replace("indent=4", "default=serializer")}

    def runInPythonSandbox(self, state: State):
        """Execute the fenced code from ``generatedCode`` via ``replManager``; store the result as ``codeOutput``."""
        code = self._extract_code(state["generatedCode"])
        return {"codeOutput": replManager.run(code)}

    def outputEvaluationRouter(self, state: State):
        """Return ``"pass"`` if ``codeOutput`` parses as JSON, otherwise ``"fail"``.

        A non-string output (e.g. ``None``) also counts as ``"fail"``, so it is
        sent to the failsafe branch instead of crashing the graph.
        """
        try:
            json.loads(state["codeOutput"])
        except (json.JSONDecodeError, TypeError):
            return "fail"
        return "pass"

    def failsafe(self, state: State):
        """Ask the failsafe chain to repair the failing code; the reply replaces ``generatedCode``.

        NOTE(review): the repaired code is not re-pinned with ``_pin_project_id``;
        this relies on the chain preserving the already-pinned fetch_data calls.
        """
        response = failsafeCodeGeneratorChain.invoke({
            "user_query": state["rephrasedQuery"],
            "metadata_context": state["metadata"],
            "code_with_errors": state["generatedCode"],
            "error_message": state["codeOutput"],
        })
        return {"generatedCode": response}

    def formatJsonResponse(self, state: State):
        """Produce ``finalOutput`` from the sandbox output, or relay the rephraser's doubt.

        Without ``codeOutput`` (the ``interrupt`` path straight from
        ``rephraseQuery``), the result is ``{"response": <doubt>}``. Otherwise
        ``codeOutput`` is parsed as JSON. Any parse failure is reported as
        ``{"error": "Endpoint says: ..."}`` rather than raised, because this is
        the graph's final node.
        """
        if "codeOutput" not in state:
            return {"finalOutput": {"response": state["rephrasedQuery"].get("doubt")}}
        try:
            response = json.loads(state["codeOutput"])
        except Exception as e:
            response = {"error": f"Endpoint says: {e}"}
        return {"finalOutput": response}

    def router(self, state: State):
        """Return ``"continue"`` when the rephraser raised no doubt, else ``"interrupt"``.

        A missing ``doubt`` key is treated the same as ``None``.
        """
        doubt = state["rephrasedQuery"].get("doubt")
        return "continue" if doubt is None else "interrupt"

    def createWorkflow(self):
        """Assemble the nodes and edges described in the class docstring and return the compiled graph."""
        builder = StateGraph(State)
        builder.add_node("rephraseQuery", self.rephraseQuery)
        builder.add_node("generateCode", self.generateCode)
        builder.add_node("runInPythonSandbox", self.runInPythonSandbox)
        builder.add_node("failsafe", self.failsafe)
        # The same runner is registered under a second name so the repaired code
        # gets its own node that leads straight to the formatter (no retry loop).
        builder.add_node("failsafePythonSandbox", self.runInPythonSandbox)
        builder.add_node("formatJsonResponse", self.formatJsonResponse)

        builder.add_edge(START, "rephraseQuery")
        builder.add_conditional_edges(
            "rephraseQuery",
            self.router,
            {"continue": "generateCode", "interrupt": "formatJsonResponse"},
        )
        builder.add_edge("generateCode", "runInPythonSandbox")
        builder.add_conditional_edges(
            "runInPythonSandbox",
            self.outputEvaluationRouter,
            {"pass": "formatJsonResponse", "fail": "failsafe"},
        )
        builder.add_edge("failsafe", "failsafePythonSandbox")
        builder.add_edge("failsafePythonSandbox", "formatJsonResponse")
        builder.add_edge("formatJsonResponse", END)
        return builder.compile()
# Module-level singletons: the workflow builder and the compiled, runnable graph
# that callers import. (The stray trailing "|" on the original last line was a
# SyntaxError and has been removed.)
graph = ReportingToolWorkflow()
reportingToolWorkflow = graph.createWorkflow()