Spaces:
Runtime error
Runtime error
| from langchain_core.messages import HumanMessage | |
| from langgraph.graph import END, START, StateGraph | |
| from research_assistant.app_logging import app_logger | |
| from research_assistant.components.agent import Agent | |
| from research_assistant.components.agent_tools import get_arxiv_tool, get_qa_tool | |
| from research_assistant.components.pdfParser import pdf_parser | |
| from research_assistant.components.planner import get_planner | |
| from research_assistant.components.solver import get_solver | |
| from research_assistant.components.state import ResearchSummary | |
| from research_assistant.config.configuration import ConfigurationManager | |
| from research_assistant.utils.state_utils import SummaryStateUtils | |
class ArticleSummarization:
    """Summarize an article with a plan -> tool -> solve LangGraph workflow.

    The planner node produces a list of (tool, argument) steps. The tool node
    executes one step per visit, and ``SummaryStateUtils.route`` decides where
    to go next (presumably back to "tool" until every step has run, then to
    "solve" -- confirm in ``SummaryStateUtils``). The solver node turns the
    collected step results into the final answer.
    """

    def __init__(self, file_path):
        """Store the article path and create the config / state helpers.

        Args:
            file_path: Path of the article; passed to ``pdf_parser`` by
                ``get_summary``.
        """
        self.article_path = file_path
        self.config = ConfigurationManager()
        self.summary_utils = SummaryStateUtils()

    def get_model(self, component: str):
        """Return the LLM configured for one workflow component.

        Args:
            component: One of ``"planner"``, ``"qa_tool"`` or ``"solver"``.

        Returns:
            Whatever ``Agent(config.model_name).get_model()`` returns
            (presumably a LangChain chat model -- confirm in ``Agent``).

        Raises:
            ValueError: If ``component`` is not one of the names above.
        """
        config_getters = {
            "planner": self.config.get_planner_config,
            "qa_tool": self.config.get_qa_tool_config,
            "solver": self.config.get_solver_config,
        }
        if component not in config_getters:
            raise ValueError("Invalid component name for getting the Model")
        config = config_getters[component]()
        return Agent(config.model_name).get_model()

    def get_plan(self, state: ResearchSummary):
        """Planner node: ask the planner LLM for a step-by-step tool plan.

        Args:
            state: Workflow state; ``state["article_text"]`` is sent to the
                planner.

        Returns:
            Partial state update with ``plan_string``, ``dependencies``,
            ``arguments`` and ``tools`` copied from the planner response.

        Raises:
            ValueError: If the planner returned a different number of tools
                than arguments.
        """
        planner = get_planner(llm=self.get_model("planner"))
        response = planner.invoke({"article_text": state["article_text"]})
        # Every step pairs one tool with one argument. A length mismatch means
        # the plan string was not parsed into aligned lists, and
        # tool_execution would index them out of step.
        if len(response.tools) != len(response.arguments):
            raise ValueError("The Plan string is not parsed properly")
        app_logger.info(f"The plan produced is: {response.plan_str}")
        return {
            "plan_string": response.plan_str,
            "dependencies": response.dependencies,
            "arguments": response.arguments,
            "tools": response.tools,
        }

    def tool_execution(self, state: ResearchSummary):
        """Worker node that executes the current step of the plan.

        The step number comes from ``SummaryStateUtils.get_current_task``. It
        is treated as 1-based, since it is used as ``step - 1`` to index
        ``state["tools"]`` and ``state["arguments"]``.

        Args:
            state: Workflow state holding ``tools``, ``arguments`` and,
                optionally, ``results`` from earlier steps.

        Returns:
            ``{"results": ...}``: a new dict with all previous results plus
            ``str(result)`` of this step, keyed by the step number.

        Raises:
            ValueError: If the step's tool is neither ``"Arxiv"`` nor
                ``"LLM"``.
        """
        current_step = self.summary_utils.get_current_task(state)
        step_index = current_step - 1  # 1-based step -> 0-based list index
        tool_name = state["tools"][step_index]
        tool_arg = state["arguments"][step_index]

        if tool_name == "Arxiv":
            result = get_arxiv_tool().run(tool_arg)
        elif tool_name == "LLM":
            result = get_qa_tool(llm=self.get_model("qa_tool")).invoke(
                {
                    "question": tool_arg,
                    "context": self.summary_utils.get_current_dependencies(
                        state, current_step
                    ),
                }
            )
        else:
            raise ValueError(
                f"Unknown tool {tool_name!r} at plan step {current_step}"
            )

        # Copy instead of mutating the dict held by the incoming state. The
        # node should only report its update, not alter the graph's stored
        # state in place.
        results = dict(state.get("results") or {})
        results[current_step] = str(result)
        return {"results": results}

    def solve(self, state: ResearchSummary):
        """Solver node: combine the plan results into the final answer.

        Args:
            state: Workflow state; ``SummaryStateUtils.get_plan_results``
                builds the solver input from it.

        Returns:
            ``{"result": answer}``, where ``answer`` is the ``.answer``
            attribute of the solver response.
        """
        solver = get_solver(llm=self.get_model("solver"))
        solution = solver.invoke(self.summary_utils.get_plan_results(state))
        return {"result": solution.answer}

    def get_graph(self):
        """Build and compile the plan -> tool (loop) -> solve workflow graph.

        Returns:
            The compiled ``StateGraph`` over ``ResearchSummary``.
        """
        graph = StateGraph(ResearchSummary)
        graph.add_node("plan", self.get_plan)
        graph.add_node("tool", self.tool_execution)
        graph.add_node("solve", self.solve)
        graph.add_edge(START, "plan")
        graph.add_edge("plan", "tool")
        # After each tool step, route picks the next node (presumably "tool"
        # for the next step or "solve" once all steps ran -- confirm in
        # SummaryStateUtils.route).
        graph.add_conditional_edges("tool", self.summary_utils.route)
        graph.add_edge("solve", END)
        return graph.compile()

    def get_summary(self):
        """Run the workflow on the article and return the solver's answer.

        Returns:
            The ``result`` value produced by the ``solve`` node.

        Raises:
            RuntimeError: If the workflow produced no output, or its last
                streamed chunk did not come from the ``solve`` node.
        """
        app = self.get_graph()
        initial_state = {"article_text": pdf_parser(self.article_path)}
        # Streamed chunks are keyed by the node that produced them. Only the
        # last one (expected to be the "solve" node's update) is needed.
        final_output = None
        for final_output in app.stream(initial_state):
            pass
        if not final_output or "solve" not in final_output:
            raise RuntimeError(
                "Summarization workflow finished without a 'solve' output"
            )
        return final_output["solve"]["result"]