File size: 1,751 Bytes
214f910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import annotations

from typing import TypedDict, List, Dict

from langgraph.graph import StateGraph, END

from .agents import (
    QiddiyaState,
    orchestrator_node,
    wait_time_predictor_node,
    route_optimizer_node,
    experience_writer_node,
    critic_node,
    reflection_node,
)


class GraphState(TypedDict, total=False):
    """State schema shared by the nodes of the Qiddiya planning graph.

    Declared with ``total=False``, so every key is optional and may be absent
    from a given state mapping. Except where a comment says otherwise, the
    field meanings below are inferred from names and should be confirmed
    against the node implementations in ``.agents``.
    """

    # Presumably the incoming visitor request; schema not visible here — TODO confirm.
    user_request: Dict
    # Accumulated log lines. No reducer is declared, so how concurrent or
    # successive updates combine depends on the graph runtime — verify that
    # nodes extend rather than replace this list.
    logs: List[str]
    # Presumably attraction id -> predicted wait; units unconfirmed.
    wait_time_forecast: Dict[str, int] | None
    # Presumably the route node's unpolished plan — confirm against agents.
    raw_plan: Dict | None
    # Presumably the guide node's finished plan — confirm against agents.
    final_plan: Dict | None
    # Read by ``_should_reflect``: a non-empty value routes to another
    # reflection pass (subject to the round limit).
    critique: str | None
    # Read by ``_should_reflect`` and compared against the reflection round limit.
    reflection_round: int
    # Presumably attraction ids produced by the reflection node — TODO confirm.
    refined_attraction_ids: List[str] | None


# From this line on, ``QiddiyaState`` in this module refers to ``GraphState``.
# NOTE(review): this rebinds the ``QiddiyaState`` imported from ``.agents``
# above, so that imported definition is never used in this module. Confirm
# the two schemas are meant to be identical, or drop one of them.
QiddiyaState = GraphState


def _should_reflect(state: QiddiyaState) -> str:
    critique = state.get("critique") or ""
    reflection_round = int(state.get("reflection_round", 0))
    if critique and reflection_round < 2:
        return "reflect"
    return "finish"


def build_qiddiya_graph() -> StateGraph:
    """Assemble the Qiddiya planning graph and return it uncompiled.

    Topology:
        orchestrator -> wait_time -> route -> guide -> critic
        critic --(_should_reflect == "reflect")--> reflection -> route
        critic --(_should_reflect == "finish")---> END

    The caller is responsible for compiling the returned ``StateGraph``.
    """
    workflow = StateGraph(GraphState)

    # Register every node, in a fixed order, under its graph-level name.
    node_table = (
        ("orchestrator", orchestrator_node),
        ("wait_time", wait_time_predictor_node),
        ("route", route_optimizer_node),
        ("guide", experience_writer_node),
        ("critic", critic_node),
        ("reflection", reflection_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("orchestrator")

    # Chain the straight-line part of the pipeline: each stage feeds the next.
    pipeline = ("orchestrator", "wait_time", "route", "guide", "critic")
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        workflow.add_edge(upstream, downstream)

    # After critique, either loop through reflection or stop.
    branch_targets = {"reflect": "reflection", "finish": END}
    workflow.add_conditional_edges("critic", _should_reflect, branch_targets)
    # A reflection pass re-enters the pipeline at route optimisation.
    workflow.add_edge("reflection", "route")

    return workflow