| | from system_prompts import SYSTEM_PROMPT_ATTACH_FILENAME, SYSTEM_PROMPT_AGGREGATOR, SYSTEM_PROMPT_ORQ |
| |
|
| | from pydantic import BaseModel, Field |
| | from pydantic import ValidationError |
| |
|
| | from langgraph.types import Command |
| | from langgraph.graph import StateGraph, MessagesState, START, END |
| | from langchain_core.messages import ToolMessage, AIMessage, HumanMessage |
| | from langchain_google_vertexai import ChatVertexAI |
| | from langchain_anthropic import ChatAnthropic |
| | from langgraph.prebuilt import ToolNode |
| |
|
| | from typing import Literal, Optional |
| | import time |
| |
|
| | from tools import download_youtube_video, get_tools |
| |
|
# Module-level model singletons, shared by all graph nodes:
# - Gemini 2.5 Pro is used wherever structured (pydantic) output is required
#   (routing in `attach_data`, final aggregation in `aggregator`).
# - Claude 3.5 Sonnet drives the tool-calling orchestrator (`manager`);
#   max_retries=6 retries transient API failures at the client level.
llm_pro = ChatVertexAI(model="gemini-2.5-pro")
llm_claude = ChatAnthropic(model='claude-3-5-sonnet-latest', max_retries=6)
# Claude bound to the project's tool set so it can emit tool_calls handled by ToolNode.
llm_tools = llm_claude.bind_tools(get_tools())
| |
|
class TaskState(MessagesState):
    """Graph state: the running message list (from MessagesState) plus task metadata."""
    check_final_answer: bool | None  # set True by `manager` once the model signals a final answer
    path_filename: str | None        # local path of an attached code/data file, if any
    gcp_path: str | None             # GCS URI of attached media (video/audio/image), if any
    final_answer: str | None         # concise answer produced by `aggregator`
    explanation: str | None          # explanation of the answer produced by `aggregator`
| |
|
class RouterFilename(BaseModel):
    """Structured routing decision: does the request carry an attachment, and of what kind?

    Produced by `llm_pro.with_structured_output(RouterFilename)` in `attach_data`.
    """
    is_filename_attached: bool = Field(..., description="Whether or not there is a file or link associated with data to be analysed at the user's request.")
    data_type: Literal["code", "data", "youtube", "audio", "image", "none"] = Field(..., description="Type of file attached to the task")
    # Only expected to be populated when data_type == "youtube".
    youtube_url: Optional[str] = Field(
        default=None,
        description="Youtube URL attached to the user's order, if any."
    )
| |
|
class Answer(BaseModel):
    """Structured final output produced by the `aggregator` node."""
    final_answer: Optional[str] = Field(
        default=None,
        description="Final response for the user"
    )

    explanation: Optional[str] = Field(
        default=None,
        description="Explanation of the final response"
    )
| |
|
def attach_data(state: TaskState) -> dict:
    """Detect an attachment (file/link) in the user's request and surface it.

    Asks Gemini (structured output) whether the task references a file or
    link; if so, appends a message telling downstream nodes how to access
    the data (inlined code, local path, or GCS URI).

    Returns:
        dict: ``{"messages": ...}`` with one extra message describing the
        attachment, or ``{}`` when nothing is attached.

    Raises:
        RuntimeError: if the model fails to produce valid structured output
            after three attempts.
        ValueError: if the router classifies the task as "youtube" but
            extracts no URL and no ``gcp_path`` is already in state.
    """
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ATTACH_FILENAME}
    ] + state["messages"]

    generator = llm_pro.with_structured_output(RouterFilename)

    # Retry structured output up to 3 times; invalid JSON gets a corrective
    # system message, a None result is retried silently.
    for _ in range(3):
        try:
            router_decision = generator.invoke(messages)
            if router_decision is not None:
                break
        except ValidationError:
            messages.append({"role": "system", "content":
                             "This JSON is not valid! Please, try again."})
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini didn't get the structured output.")

    print(f"Router filename decision: {router_decision}")

    # Guard clause: nothing attached -> no state change.
    if not router_decision.is_filename_attached:
        return {}

    filename_type = router_decision.data_type
    if filename_type in ("code", "data"):
        path_filename = state["path_filename"]
        if filename_type == 'code':
            # Inline the source so the orchestrator can read it directly.
            with open(path_filename, "r", encoding="utf-8") as f:
                code = f.read()
            response = f"Code:\n```python\n{code}\n```"
        else:
            response = f"Path of the attached file: {path_filename}"
    elif filename_type == 'youtube':
        if state.get('gcp_path'):
            gcp_path = state["gcp_path"]
        else:
            # Bug fix: previously a missing URL was passed straight to the
            # downloader; fail fast with a clear error instead.
            if router_decision.youtube_url is None:
                raise ValueError(
                    "Router classified the task as 'youtube' but extracted no URL.")
            _, gcp_path = download_youtube_video(router_decision.youtube_url, "video")
        response = f"video GCP uri: {gcp_path}"
    elif filename_type == 'audio':
        response = f"audio GCP uri: {state['gcp_path']}"
    else:
        # "image" (and any residual type) is assumed to already be uploaded.
        response = f"image GCP uri: {state['gcp_path']}"

    return {"messages": state["messages"] + [response]}
| |
|
def manager(state: TaskState) -> dict:
    """Orchestrator node: lets the tool-bound LLM decide the next action.

    Invokes Claude with the orchestrator system prompt. If the reply
    contains the final-answer sentinel and requests no tools, the
    ``check_final_answer`` flag is raised so `next_node_router` routes to
    the aggregator; otherwise only the reply is appended.

    Returns:
        dict: state update with the model reply appended to ``messages``
        and, on a final answer, ``check_final_answer=True``.
    """
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ORQ}
    ] + state["messages"]

    response = llm_tools.invoke(messages)
    print(f"LLM ORQ response: {response}")

    # NOTE(review): "FINAL_ANSER" is presumably the sentinel defined in
    # SYSTEM_PROMPT_ORQ — confirm the spelling matches the prompt text.
    if not response.tool_calls and "FINAL_ANSER" in response.content:
        # Bug fix: the state key was misspelled "check_final_anser", so
        # `next_node_router` (which reads "check_final_answer") never saw
        # the flag and the undeclared channel write could fail in LangGraph.
        return {"messages": state["messages"] + [response],
                "check_final_answer": True}

    return {"messages": state["messages"] + [response]}
| |
|
def next_node_router(state: TaskState) -> Literal[
    "tool_node", "aggregator"
]:
    """Route from `manager`: run requested tools, or aggregate the answer.

    Returns:
        "aggregator" when the final-answer flag is set or the last message
        requests no tools; "tool_node" when the last AI message carries
        tool calls.
    """
    # .get(): robust if the flag was never written into state by `manager`.
    if state.get("check_final_answer"):
        return "aggregator"

    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tool_node"

    return "aggregator"
| |
|
def aggregator(state: TaskState) -> dict:
    """Condense the conversation into a structured final answer.

    Pairs the original task (first message) with the latest model output
    (last message) and asks Gemini for a structured `Answer`, retrying up
    to three times on invalid or empty structured output.

    Returns:
        dict: ``{"final_answer": ..., "explanation": ...}``.

    Raises:
        RuntimeError: when no valid structured output is obtained in
            three attempts.
    """
    task = state["messages"][0].content
    last_model_answer = state["messages"][-1].content

    content = f"""
    Task: {task}
    {last_model_answer}
    """
    message_last = HumanMessage(content=content)

    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_AGGREGATOR},
        message_last,
    ]

    generator = llm_pro.with_structured_output(Answer)

    attempt = 0
    while attempt < 3:
        attempt += 1
        try:
            response = generator.invoke(messages)
        except ValidationError:
            # Nudge the model toward valid JSON, back off, then retry.
            messages.append({"role": "system", "content":
                             "This JSON is not valid! Please, try again."})
            time.sleep(2.0)
            continue
        if response is not None:
            break
    else:
        raise RuntimeError("Gemini didn't get the structured output.")

    return {"final_answer": response.final_answer, "explanation": response.explanation}
| |
|
| |
|
def generate_graph():
    """Build and compile the task-solving graph.

    Topology:
        START -> attach_data -> manager
        manager --(tool calls)--> tool_node -> manager
        manager --(final answer / no tools)--> aggregator
    """
    builder = StateGraph(TaskState)

    # Nodes.
    builder.add_node("attach_data", attach_data)
    builder.add_node("manager", manager)
    builder.add_node("tool_node", ToolNode(get_tools()))
    builder.add_node("aggregator", aggregator)

    # Linear lead-in, plus the tool loop back to the orchestrator.
    builder.add_edge(START, "attach_data")
    builder.add_edge("attach_data", "manager")
    builder.add_edge("tool_node", "manager")

    # The orchestrator decides between running tools and aggregating.
    builder.add_conditional_edges(
        "manager",
        next_node_router,
        {"tool_node": "tool_node", "aggregator": "aggregator"},
    )

    return builder.compile()