# pedutronix — commit 553f789: fix youtube in production and add more recursion limit
from system_prompts import SYSTEM_PROMPT_ATTACH_FILENAME, SYSTEM_PROMPT_AGGREGATOR, SYSTEM_PROMPT_ORQ
from pydantic import BaseModel, Field
from pydantic import ValidationError
from langgraph.types import Command
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import ToolMessage, AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import ToolNode
from typing import Literal, Optional
import time
from tools import download_youtube_video, get_tools
# Gemini 2.5 Pro: used for structured-output extraction (attachment routing and answer aggregation).
llm_pro = ChatVertexAI(model="gemini-2.5-pro")
# Claude 3.5 Sonnet: orchestrator model; max_retries smooths over transient API errors.
llm_claude = ChatAnthropic(model='claude-3-5-sonnet-latest', max_retries=6)
# Claude with the project's tools bound so it can emit tool calls.
llm_tools = llm_claude.bind_tools(get_tools())
class TaskState(MessagesState):  # inherits the standard "messages" list
    # Flag consulted by next_node_router; routes to the aggregator when truthy.
    check_final_answer: bool | None
    # Local filesystem path of an attached code/data file, if any.
    path_filename: str | None
    # GCS URI for an attached video/audio/image, if any.
    gcp_path: str | None
    # Structured result produced by the aggregator node.
    final_answer: str | None
    # Explanation accompanying the final answer, also set by the aggregator.
    explanation: str | None
class RouterFilename(BaseModel):
    """Structured-output schema used by attach_data to classify attachments."""
    # True when the user's request references a file or link to analyse.
    is_filename_attached: bool = Field(..., description="Whether or not there is a file or link associated with data to be analysed at the user's request.")
    # "none" is the catch-all when no recognisable attachment type applies.
    data_type: Literal["code", "data", "youtube", "image", "none"] = Field(..., description="Type of file attached to the task")
    # Only populated when data_type == "youtube".
    youtube_url: Optional[str] = Field(
        default=None,
        description="Youtube URL attached to the user's order, if any."
    )
class Answer(BaseModel):
    """Structured-output schema for the aggregator's final response."""
    # The answer delivered to the user; None if the model could not produce one.
    final_answer: Optional[str] = Field(
        default=None,
        description="Final response for the user"
    )
    # Supporting reasoning for final_answer.
    explanation: Optional[str] = Field(
        default=None,
        description="Explanation of the final response"
    )
def attach_data(state: TaskState) -> dict:
    """Detect an attached file/link in the user's request and surface it.

    Asks Gemini (structured output via RouterFilename) to classify any
    attachment, retrying up to 3 times on invalid JSON. When an attachment
    is found, appends a message describing how downstream nodes can access
    it (inline code, a local path, or a GCS URI).

    Returns:
        A partial state update ({"messages": ...}) when an attachment was
        found, otherwise {} (no state change).

    Raises:
        RuntimeError: if no structured output is obtained after 3 attempts.
    """
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ATTACH_FILENAME}
    ] + state["messages"]
    generator = llm_pro.with_structured_output(RouterFilename)
    for _ in range(3):  # up to 3 logical retries
        try:
            router_decision = generator.invoke(messages)
            if router_decision is not None:
                break
        except ValidationError:
            # Nudge the model to emit valid JSON before the next attempt.
            messages.append({"role": "system", "content":
                "This JSON is not valid! Please, try again."})
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini didn't get the structured output.")
    print(f"Router filename decision: {router_decision}")
    if router_decision.is_filename_attached:
        filename_type = router_decision.data_type
        if filename_type in ("code", "data"):
            path_filename = state["path_filename"]
            if filename_type == 'code':
                # Inline the source so the orchestrator can read it directly.
                with open(path_filename, "r", encoding="utf-8") as f:
                    code = f.read()
                response = f"Code:\n```python\n{code}\n```"
            else:
                response = f"Path of the attached file: {path_filename}"
        elif filename_type == 'youtube':
            if state.get('gcp_path'):
                gcp_path = state["gcp_path"]
            else:
                # Video not mirrored to GCS yet: download it now.
                _, gcp_path = download_youtube_video(router_decision.youtube_url, "video")
            response = f"video GCP uri: {gcp_path}"
        elif filename_type == 'audio':
            gcp_path = state["gcp_path"]
            response = f"audio GCP uri: {gcp_path}"
        else:
            gcp_path = state["gcp_path"]
            response = f"image GCP uri: {gcp_path}"
        return {"messages": state["messages"] + [response]}
    return {}
def manager(state: TaskState) -> dict:
    """Orchestrator node: ask the tool-bound Claude model for the next step.

    When the model answers without tool calls and its content contains the
    final-answer sentinel, set check_final_answer so next_node_router sends
    the flow to the aggregator.

    Returns:
        A partial state update with the model's response appended, plus the
        check_final_answer flag when the answer looks final.
    """
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ORQ}
    ] + state["messages"]
    response = llm_tools.invoke(messages)
    print(f"LLM ORQ response: {response}")
    # We assume this has to be the final answer.
    # NOTE(review): "FINAL_ANSER" looks misspelled — confirm it matches the
    # sentinel token SYSTEM_PROMPT_ORQ instructs the model to emit.
    if not response.tool_calls and "FINAL_ANSER" in response.content:
        # BUG FIX: the state key was previously misspelled "check_final_anser",
        # so the flag declared in TaskState was never actually set and
        # next_node_router could never see it.
        return {"messages": state["messages"] + [response], "check_final_answer": True}
    return {"messages": state["messages"] + [response]}
def next_node_router(state: TaskState) -> Literal[
    "tool_node", "aggregator"
]:
    """Route after the manager: execute requested tools, else aggregate.

    Returns:
        "tool_node" when the last AI message carries tool calls,
        "aggregator" otherwise (including when the final answer is flagged).
    """
    # BUG FIX: use .get() — "check_final_answer" is absent from the state
    # until the manager sets it, and state["check_final_answer"] raised
    # KeyError in that case.
    if state.get("check_final_answer"):
        return "aggregator"
    # Inspect the last message in the history.
    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tool_node"
    return "aggregator"
def aggregator(state: TaskState) -> dict:
    """Condense the conversation into a structured final Answer.

    Combines the original task (first message) with the manager's last
    answer, then asks Gemini for a structured Answer object, retrying up
    to 3 times on invalid JSON.

    Returns:
        A partial state update with "final_answer" and "explanation".

    Raises:
        RuntimeError: if no structured output is obtained after 3 attempts.
    """
    task = state["messages"][0].content
    last_model_answer = state["messages"][-1].content
    content = f"""
Task: {task}
{last_model_answer}
"""
    message_last = HumanMessage(content=content)
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_AGGREGATOR}
    ] + [message_last]
    generator = llm_pro.with_structured_output(Answer)
    for _ in range(3):  # up to 3 logical retries
        try:
            response = generator.invoke(messages)
            if response is not None:  # got a parsed Answer, stop retrying
                break
        except ValidationError:
            # Nudge the model to emit valid JSON before the next attempt.
            messages.append({"role": "system", "content":
                "This JSON is not valid! Please, try again."})
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini didn't get the structured output.")
    return {"final_answer": response.final_answer, "explanation": response.explanation}
def generate_graph():
    """Build and compile the LangGraph workflow.

    Flow: START -> attach_data -> manager -> (tool_node <-> manager)* -> aggregator.
    The manager picks the next step through a conditional router.
    """
    builder = StateGraph(TaskState)
    # Register every node, including the prebuilt tool executor.
    builder.add_node("attach_data", attach_data)
    builder.add_node("manager", manager)
    builder.add_node("tool_node", ToolNode(get_tools()))
    builder.add_node("aggregator", aggregator)
    # Entry point: resolve any attached file/link before orchestrating.
    builder.add_edge(START, "attach_data")
    builder.add_edge("attach_data", "manager")
    # After a tool runs, control returns to the manager with the result.
    builder.add_edge("tool_node", "manager")
    # The manager delegates the routing decision to next_node_router.
    builder.add_conditional_edges(
        "manager",
        next_node_router,
        {"tool_node": "tool_node", "aggregator": "aggregator"},
    )
    return builder.compile()