from system_prompts import SYSTEM_PROMPT_ATTACH_FILENAME, SYSTEM_PROMPT_AGGREGATOR, SYSTEM_PROMPT_ORQ

from pydantic import BaseModel, Field, ValidationError

from langgraph.types import Command
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import ToolMessage, AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import ToolNode

from typing import Literal, Optional
import time 

from tools import download_youtube_video, get_tools

llm_pro = ChatVertexAI(model="gemini-2.5-pro")
llm_claude = ChatAnthropic(model='claude-3-5-sonnet-latest', max_retries=6)
llm_tools = llm_claude.bind_tools(get_tools())

class TaskState(MessagesState):  # inherits the standard "messages" list
    check_final_answer: bool | None
    path_filename: str | None
    gcp_path: str | None
    final_answer: str | None
    explanation: str | None
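    # Field provenance (inferred from the nodes below, as a hedged note):
    #   - path_filename and gcp_path are presumably supplied by the caller in the
    #     initial state (a local file path and a GCS URI for media attachments).
    #   - check_final_answer is set by the manager node; final_answer and
    #     explanation are produced by the aggregator node.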

class RouterFilename(BaseModel):
    is_filename_attached: bool = Field(..., description="Whether a file or link with data to be analysed is attached to the user's request.")
    data_type: Literal["code", "data", "youtube", "audio", "image", "none"] = Field(..., description="Type of file attached to the task")
    youtube_url: Optional[str] = Field(
        default=None, 
        description="Youtube URL attached to the user's order, if any."
    )
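
    # Illustrative example (assumed, not taken from the repo): a task that links a
    # YouTube video should parse to something like
    #   RouterFilename(is_filename_attached=True, data_type="youtube",
    #                  youtube_url="https://www.youtube.com/watch?v=...")
    # while a plain text question yields is_filename_attached=False, data_type="none".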

class Answer(BaseModel):
    final_answer: Optional[str] = Field(
        default=None, 
        description="Final response for the user"
    )

    explanation: Optional[str] = Field(
        default=None, 
        description="Explanation of the final response"
    )

def attach_data(state: TaskState) -> dict:
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ATTACH_FILENAME}
    ] + state["messages"]

    generator = llm_pro.with_structured_output(RouterFilename)

    for _ in range(3):          # up to 3 retries for a valid structured output
        try:
            router_decision = generator.invoke(messages)
            if router_decision is not None:
                break
        except ValidationError:
            messages.append({"role": "system",
                             "content": "This JSON is not valid! Please try again."})
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini did not return a valid structured output.")
        
    print(f"Router filename decision: {router_decision}")
    if router_decision.is_filename_attached:
        filename_type = router_decision.data_type
        if filename_type in ("code", "data"):
            path_filename = state["path_filename"]
            if filename_type == "code":
                with open(path_filename, "r", encoding="utf-8") as f:
                    code = f.read()

                response = f"Code:\n```python\n{code}\n```"
            else:
                response = f"Path of the attached file: {path_filename}"
                
        elif filename_type == 'youtube':
            if state.get('gcp_path'):
                gcp_path = state["gcp_path"]
            else:
                _, gcp_path = download_youtube_video(router_decision.youtube_url, "video")
                
            response = f"video GCP uri: {gcp_path}"

        elif filename_type == 'audio':
            gcp_path = state["gcp_path"]
            response = f"audio GCP uri: {gcp_path}"

        else:
            gcp_path = state["gcp_path"]
            response = f"image GCP uri: {gcp_path}"      
            

        #pdb.set_trace()
        return {"messages": state["messages"] + [response]}

    return {}

def manager(state: TaskState) -> dict:
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_ORQ}
    ] + state["messages"]
    
    response = llm_tools.invoke(messages)
    print(f"LLM ORQ response: {response}")

    # If no tool is called and the sentinel string appears, treat this as the final answer.
    # Note: "FINAL_ANSER" must match the sentinel that SYSTEM_PROMPT_ORQ asks the model to emit.
    if not response.tool_calls and "FINAL_ANSER" in response.content:
        return {"messages": state["messages"] + [response], "check_final_answer": True}

    return {"messages": state["messages"] + [response]}

def next_node_router(state: TaskState) -> Literal[
    "tool_node", "aggregator"
]:
    if state["check_final_answer"]:
        return "aggregator"

    # Inspeccionamos el último mensaje del historial
    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tool_node"

    return "aggregator"

def aggregator(state: TaskState) -> dict:
    task = state["messages"][0].content
    last_model_answer = state["messages"][-1].content
    
    content = f"""
    Task: {task}
    {last_model_answer}
    """
    message_last = HumanMessage(content=content)
    
    messages = [
        {"role": "system",
         "content": SYSTEM_PROMPT_AGGREGATOR}
    ] + [message_last]

    generator = llm_pro.with_structured_output(Answer)

    for _ in range(3):          # up to 3 retries for a valid structured output
        try:
            response = generator.invoke(messages)
            if response is not None:            # got a parsed Answer object
                break
        except ValidationError:
            messages.append({"role": "system",
                             "content": "This JSON is not valid! Please try again."})
            time.sleep(2.0)
    else:
        raise RuntimeError("Gemini did not return a valid structured output.")

    return {"final_answer": response.final_answer, "explanation": response.explanation}


def generate_graph():
    tool_node = ToolNode(get_tools()) 

    builder = StateGraph(TaskState)

    # Add all nodes, including the tool node
    builder.add_node("attach_data", attach_data)
    builder.add_node("manager", manager)
    builder.add_node("tool_node", tool_node)
    builder.add_node("aggregator", aggregator)

    # The graph starts by resolving any attached data, then hands off to the manager
    builder.add_edge(START, "attach_data")
    builder.add_edge("attach_data", "manager")

    # After a tool runs, control returns to the manager with the result
    builder.add_edge("tool_node", "manager")

    # The manager uses a conditional router to decide the next step
    builder.add_conditional_edges(
        "manager",
        next_node_router,
        # The mapping is simple because the branching logic lives in next_node_router
        {
            "tool_node": "tool_node",
            "aggregator": "aggregator"
        }
    )

    # The aggregator produces the final answer, so the graph ends after it
    builder.add_edge("aggregator", END)

    graph = builder.compile()

    return graph
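

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): builds the
# graph and runs a single task. The file path and question below are
# hypothetical; real runs also need Vertex AI / Anthropic credentials and the
# prompts imported from system_prompts.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    graph = generate_graph()

    initial_state = {
        "messages": [HumanMessage(content="How many rows does the attached CSV have?")],
        "path_filename": "/tmp/example.csv",   # hypothetical local attachment
        "gcp_path": None,
    }

    result = graph.invoke(initial_state)
    print("Final answer:", result.get("final_answer"))
    print("Explanation:", result.get("explanation"))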