File size: 6,603 Bytes
9b916a5 553f789 9b916a5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | from system_prompts import SYSTEM_PROMPT_ATTACH_FILENAME, SYSTEM_PROMPT_AGGREGATOR, SYSTEM_PROMPT_ORQ
from pydantic import BaseModel, Field
from pydantic import ValidationError
from langgraph.types import Command
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import ToolMessage, AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import ToolNode
from typing import Literal, Optional
import time
from tools import download_youtube_video, get_tools
# Module-level model clients, constructed once at import time.
# Gemini model used for structured-output calls (attach_data routing and aggregator).
llm_pro = ChatVertexAI(model="gemini-2.5-pro")
# Claude model; max_retries=6 is the client-side retry budget for API calls.
llm_claude = ChatAnthropic(model='claude-3-5-sonnet-latest', max_retries=6)
# Claude with the project's tools (from tools.get_tools) bound; used by the manager node.
llm_tools = llm_claude.bind_tools(get_tools())
class TaskState(MessagesState): # inherits the standard “messages” list
    """Graph state shared by every node of the task-solving workflow.

    Extends ``MessagesState`` (the conversation ``messages`` list) with the
    routing flag, attachment locations and final-answer fields below.
    """
    # Routing flag read by next_node_router: when truthy, go straight to the aggregator.
    check_final_answer: bool | None
    # Local path of an attached file; read by attach_data for "code"/"data" attachments.
    path_filename: str | None
    # GCP URI of attached media; read by attach_data for youtube/audio/image attachments.
    gcp_path: str | None
    # Final response for the user, written by the aggregator.
    final_answer: str | None
    # Explanation of the final response, written by the aggregator.
    explanation: str | None
class RouterFilename(BaseModel):
    # Structured-output schema that attach_data asks llm_pro to fill.
    # No class docstring on purpose: pydantic/LangChain would likely publish it as
    # the schema description sent to the model, which would change the prompt.
    # True when the user's request references a file or link to analyse.
    is_filename_attached: bool = Field(..., description="Whether or not there is a file or link associated with data to be analysed at the user's request.")
    # Attachment category; attach_data branches on this value.
    data_type: Literal["code", "data", "youtube", "audio", "image", "none"] = Field(..., description="Type of file attached to the task")
    # Used by attach_data's "youtube" branch only when the state has no gcp_path.
    youtube_url: Optional[str] = Field(
        default=None,
        description="Youtube URL attached to the user's order, if any."
    )
class Answer(BaseModel):
    # Structured-output schema that the aggregator asks llm_pro to fill.
    # No class docstring on purpose: pydantic/LangChain would likely publish it as
    # the schema description sent to the model, which would change the prompt.
    # Final response; returned by the aggregator under the "final_answer" state key.
    final_answer: Optional[str] = Field(
        default=None,
        description="Final response for the user"
    )
    # Explanation of the final response; returned under the "explanation" state key.
    explanation: Optional[str] = Field(
        default=None,
        description="Explanation of the final response"
    )
def attach_data(state: TaskState) -> dict:
    """Detect whether the task references an attachment and surface it.

    Asks Gemini (``llm_pro``) for a structured ``RouterFilename`` decision.
    When an attachment is detected, one new message describing it is added
    to the conversation: the file contents for code, a local path for data
    files, or a GCP URI for video/audio/image.

    Args:
        state: Current graph state. Reads ``messages`` and, depending on the
            detected type, ``path_filename`` and/or ``gcp_path``.

    Returns:
        ``{"messages": [<description>]}`` when an attachment was detected,
        otherwise ``{}`` (no state change). The string message is coerced to
        a HumanMessage by MessagesState's ``add_messages`` reducer.

    Raises:
        RuntimeError: If no structured output is obtained after 3 attempts.
        KeyError: If the state lacks the path needed for the detected type.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_ATTACH_FILENAME},
        *state["messages"],
    ]
    generator = llm_pro.with_structured_output(RouterFilename)

    router_decision = None
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            router_decision = generator.invoke(messages)
        except ValidationError:
            # Nudge the model to fix its JSON on the next attempt.
            messages.append({"role": "system", "content":
                             "This JSON is not valid! Please, try again."})
        if router_decision is not None:
            break
        if attempt < max_attempts - 1:
            # Back off before retrying, whether the failure was invalid JSON or an empty result.
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini didn't get the structured output.")

    print(f"Router filename decision: {router_decision}")
    data_type = router_decision.data_type
    # "none" means no usable attachment even if the model flagged one; previously
    # this fell through to the image branch.
    if not router_decision.is_filename_attached or data_type == "none":
        return {}

    if data_type == "code":
        path_filename = state["path_filename"]
        with open(path_filename, "r", encoding="utf-8") as f:
            code = f.read()
        response = f"Code:\n```python\n{code}\n```"
    elif data_type == "data":
        path_filename = state["path_filename"]
        response = f"Path of the attached file: {path_filename}"
    elif data_type == "youtube":
        gcp_path = state.get("gcp_path")
        if not gcp_path:
            # Second element of the returned pair is used as the GCP URI
            # (contract of tools.download_youtube_video — confirm there).
            _, gcp_path = download_youtube_video(router_decision.youtube_url, "video")
        response = f"video GCP uri: {gcp_path}"
    elif data_type == "audio":
        gcp_path = state["gcp_path"]
        response = f"audio GCP uri: {gcp_path}"
    else:  # "image"
        gcp_path = state["gcp_path"]
        response = f"image GCP uri: {gcp_path}"

    # MessagesState merges updates with add_messages, so only the new message is returned.
    return {"messages": [response]}
def manager(state: TaskState) -> dict:
    """Run one orchestration step with the tool-enabled Claude model.

    Prepends ``SYSTEM_PROMPT_ORQ`` to the conversation, invokes ``llm_tools``
    and adds the model response to ``messages``. A response that requests no
    tools and contains the final-answer marker also sets
    ``check_final_answer``.

    Args:
        state: Current graph state; reads ``messages``.

    Returns:
        ``{"messages": [response]}``, plus ``"check_final_answer": True`` when
        the response is considered final.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_ORQ},
        *state["messages"],
    ]
    response = llm_tools.invoke(messages)
    print(f"LLM ORQ response: {response}")

    content = response.content
    if isinstance(content, str):
        text = content
    else:
        # Content may be a list of blocks (str or dicts with a "text" key); a raw
        # `in` on that list would test list membership instead of substrings.
        text = "\n".join(
            block if isinstance(block, str) else str(block.get("text", ""))
            for block in content
            if isinstance(block, (str, dict))
        )

    # MessagesState merges updates with add_messages, so only the new message is returned.
    update = {"messages": [response]}
    # A reply with no tool calls that carries the final-answer marker is treated as final.
    # Both spellings are accepted because SYSTEM_PROMPT_ORQ (defined elsewhere) is
    # not visible here — TODO confirm which marker the prompt asks for.
    if not response.tool_calls and ("FINAL_ANSWER" in text or "FINAL_ANSER" in text):
        # Key must match the TaskState field; the former "check_final_anser" was never stored.
        update["check_final_answer"] = True
    return update
def next_node_router(state: TaskState) -> Literal[
    "tool_node", "aggregator"
]:
    """Choose the node that runs after the manager.

    Args:
        state: Current graph state; reads ``check_final_answer`` (optional)
            and ``messages``.

    Returns:
        ``"aggregator"`` when ``check_final_answer`` is truthy or when the last
        message is not an ``AIMessage`` with tool calls; otherwise
        ``"tool_node"``.
    """
    # .get(): the flag may be absent if no node wrote it and it was not part of
    # the graph input; indexing would then raise KeyError.
    if state.get("check_final_answer"):
        return "aggregator"
    # Inspect the most recent message in the history.
    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tool_node"
    return "aggregator"
def aggregator(state: TaskState) -> dict:
    """Turn the conversation outcome into a structured final answer.

    Sends Gemini (``llm_pro``) one prompt built from the original task (the
    first message) and the last model output, then parses the reply into
    ``Answer``.

    Args:
        state: Current graph state; reads ``messages``.

    Returns:
        ``{"final_answer": ..., "explanation": ...}`` from the parsed ``Answer``.

    Raises:
        RuntimeError: If no structured output is obtained after 3 attempts.
    """
    def _as_text(message) -> str:
        # Content may be a plain string or a list of blocks (str or dicts with a
        # "text" key); flatten lists so the prompt gets text, not a list repr.
        content = message.content
        if isinstance(content, str):
            return content
        return "\n".join(
            block if isinstance(block, str) else str(block.get("text", ""))
            for block in content
            if isinstance(block, (str, dict))
        )

    task = _as_text(state["messages"][0])
    last_model_answer = _as_text(state["messages"][-1])
    message_last = HumanMessage(content=f"\nTask: {task}\n{last_model_answer}\n")
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_AGGREGATOR},
        message_last,
    ]
    generator = llm_pro.with_structured_output(Answer)

    response = None
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            response = generator.invoke(messages)
        except ValidationError:
            # Nudge the model to fix its JSON on the next attempt.
            messages.append({"role": "system", "content":
                             "This JSON is not valid! Please, try again."})
        if response is not None:
            break
        if attempt < max_attempts - 1:
            # Back off before retrying, whether the failure was invalid JSON or an empty result.
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini didn't get the structured output.")

    return {"final_answer": response.final_answer, "explanation": response.explanation}
def generate_graph():
    """Build and compile the task-solving LangGraph workflow.

    Topology::

        START -> attach_data -> manager
        manager --next_node_router--> tool_node | aggregator
        tool_node -> manager
        aggregator -> END

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    tool_node = ToolNode(get_tools())
    builder = StateGraph(TaskState)

    # Register every node, including the ToolNode that executes the manager's tool calls.
    builder.add_node("attach_data", attach_data)
    builder.add_node("manager", manager)
    builder.add_node("tool_node", tool_node)
    builder.add_node("aggregator", aggregator)

    # The run starts by detecting attachments, then hands over to the manager.
    builder.add_edge(START, "attach_data")
    builder.add_edge("attach_data", "manager")
    # After a tool runs, return to the manager with its result.
    builder.add_edge("tool_node", "manager")
    # The manager routes conditionally: execute requested tools or finish via the aggregator.
    builder.add_conditional_edges(
        "manager",
        next_node_router,
        {
            "tool_node": "tool_node",
            "aggregator": "aggregator",
        },
    )
    # Make termination explicit: the aggregator is the last step.
    builder.add_edge("aggregator", END)

    graph = builder.compile()
    return graph