File size: 12,201 Bytes
214f910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
from __future__ import annotations

import logging
from typing import Dict, List, TypedDict, Optional

from .data_loader import (
    load_attractions,
    load_wait_history,
    get_attraction_by_name,
)
from .models import ItineraryPlan, ItineraryStop, UserRequest
from .routing import greedy_schedule, order_by_nearest_neighbor
from .llm import generate_narrative_logs


# Module logger: _append_log mirrors every workflow log message here at INFO level.
logger = logging.getLogger("qiddiya.agents")


class QiddiyaState(TypedDict, total=False):
    """Mutable state dict shared by the Qiddiya planning nodes in this module.

    ``total=False`` makes every key optional; nodes read keys with ``.get`` or
    index them directly once an earlier node is expected to have set them.
    """

    # Raw request payload; each node validates it into a UserRequest.
    user_request: Dict
    # Progress messages appended by _append_log.
    logs: List[str]
    # attraction_id -> mean historical wait in minutes (int-truncated),
    # written by wait_time_predictor_node.
    wait_time_forecast: Dict[str, int] | None
    # ItineraryPlan.model_dump() from route_optimizer_node (reflection_node may overwrite it).
    raw_plan: Dict | None
    # ItineraryPlan.model_dump() with narrative logs, written by experience_writer_node.
    final_plan: Dict | None
    # "; "-joined violation messages from critic_node ("" when the plan passes).
    critique: str | None
    # Reflection rounds triggered so far; critic_node only increments it while it is below 2.
    reflection_round: int
    # Set by reflection_node: the next route_optimizer_node pass schedules only these ids
    # (re-ordered geographically there) and then clears the list.
    refined_attraction_ids: List[str] | None


def _append_log(state: QiddiyaState, message: str) -> None:
    """Append ``message`` to ``state["logs"]`` and echo it to the module logger.

    The ``logs`` list is created on first use; an existing list is extended in place.
    """
    state.setdefault("logs", []).append(message)
    logger.info(message)


def orchestrator_node(state: QiddiyaState) -> QiddiyaState:
    """Entry node: record the start of planning and hand the same state onward."""
    opening = "Orchestrator: starting planning workflow."
    _append_log(state, opening)
    return state


def wait_time_predictor_node(state: QiddiyaState) -> QiddiyaState:
    """Forecast each attraction's wait as its mean historical wait.

    Stores ``{attraction_id: int(mean wait_minutes)}`` in
    ``state["wait_time_forecast"]``; the mean is truncated by ``int``.
    Rows with a missing ``attraction_id`` and attractions with no usable
    ``wait_minutes`` values are left out rather than crashing on ``int(NaN)``.

    Args:
        state: Shared workflow state.

    Returns:
        The same ``state`` object, mutated in place.
    """
    _append_log(state, "Wait-Time Agent: computing simple demand-based forecast.")
    # NOTE(review): presumably a pandas DataFrame with "attraction_id" and
    # "wait_minutes" columns -- confirm against data_loader.load_wait_history.
    history = load_wait_history()
    # One grouped pass instead of re-filtering the whole frame per attraction.
    # sort=False keeps first-appearance order (as iterating unique() did);
    # groupby drops NaN ids by default and dropna() removes all-NaN means.
    mean_waits = (
        history.groupby("attraction_id", sort=False)["wait_minutes"].mean().dropna()
    )
    forecast: Dict[str, int] = {
        attr_id: int(mean_wait) for attr_id, mean_wait in mean_waits.items()
    }
    state["wait_time_forecast"] = forecast
    _append_log(
        state,
        f"Wait-Time Agent: produced forecast for {len(forecast)} attractions.",
    )
    return state


def route_optimizer_node(state: QiddiyaState) -> QiddiyaState:
    """Build the raw itinerary for the request and store it in ``state["raw_plan"]``.

    Candidate attractions are chosen in priority order:

    1. ``state["refined_attraction_ids"]`` left by a reflection pass. Unknown ids
       are dropped, and the list is cleared after use so it affects one pass only.
    2. The user's must-do attractions, resolved by name. Only what was asked
       for is scheduled.
    3. Otherwise, the first 8 attractions compatible with the user's intensity
       preference. Rides with thrill_level > 3 are only suggested when
       intensity_preference > 3.

    Candidates are ordered with ``order_by_nearest_neighbor`` (by geography,
    not by the user's selection order). They are timed with ``greedy_schedule``
    using the wait-time forecast when one is present. Each stop is tagged
    ``is_suggested`` unless it is a must-do. ``coverage_score`` is the share of
    requested must-dos that were scheduled.

    Args:
        state: Shared workflow state; must contain ``user_request``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}

    # Resolve must-do names once: used for candidate selection and for coverage/tagging.
    must_do_ids: List[str] = []
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids.append(attr.id)
    must_do_set = set(must_do_ids)

    # If reflection trimmed/extended the plan, only schedule those attractions this pass
    refined = state.get("refined_attraction_ids") or []
    if refined:
        _append_log(
            state,
            f"Route Agent: re-building schedule from refined list ({len(refined)} attractions) after reflection.",
        )
        candidate = [aid for aid in refined if aid in attractions]
        ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)
        state["refined_attraction_ids"] = []  # consume: applies to this pass only
    else:
        _append_log(state, "Route Agent: building schedule by geography (nearest-neighbor from hub).")
        if must_do_ids:
            candidate = must_do_ids
        else:
            # No must-dos: suggest a few. Mild rides (thrill_level <= 3) are always
            # eligible; thrill rides only for high-intensity visitors -- the same
            # rule reflection_node applies when adding optional stops.
            candidate = [
                a.id
                for a in attractions.values()
                if req.intensity_preference > 3 or a.thrill_level <= 3
            ][:8]  # cap suggestions when nothing selected
        # Order by space (minimize walking), not by user selection order
        ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)

    wait_lookup = state.get("wait_time_forecast") or {}

    # Only hours and minutes are used, so "HH:MM" and "HH:MM:SS" are both accepted.
    start_hour, start_minute = (int(p) for p in req.start_time.split(":")[:2])
    start_minutes = start_hour * 60 + start_minute

    stops_raw, total_wait, total_walk = greedy_schedule(
        attractions_order=ordered_ids,
        start_time_minutes=start_minutes,
        # Higher walking tolerance -> lower walking penalty (tolerance presumably 1-5).
        walking_weight=float(6 - req.walking_tolerance),
        wait_time_lookup=wait_lookup,
        node_lookup=node_lookup,
    )

    stops = [ItineraryStop.model_validate(s).model_dump() for s in stops_raw]
    for s in stops:
        s["is_suggested"] = s["attraction_id"] not in must_do_set

    enjoyment = min(10.0, float(len(stops)) * (0.5 + 0.1 * req.intensity_preference))
    scheduled_must_dos = sum(1 for s in stops if s.get("attraction_id") in must_do_set)
    # Must-do names that could not be resolved still count in the denominator (as uncovered).
    coverage = float(scheduled_must_dos) / max(len(req.must_do_attractions), 1)

    plan = ItineraryPlan(
        visit_date=req.visit_date,
        total_wait_minutes=total_wait,
        total_walking_m=total_walk,
        coverage_score=coverage,
        enjoyment_score=enjoyment,
        stops=[ItineraryStop.model_validate(s) for s in stops],
        logs=state.get("logs", []),
    )
    state["raw_plan"] = plan.model_dump()
    _append_log(
        state,
        f"Route Agent: produced raw plan with {len(stops)} stops, "
        f"wait={total_wait}min, walk={total_walk}m.",
    )
    return state


def experience_writer_node(state: QiddiyaState) -> QiddiyaState:
    """Turn the raw plan into the final plan by adding visitor-facing narrative.

    The final plan's ``logs`` are the workflow logs gathered so far, followed by
    the messages returned from ``generate_narrative_logs``. The result is stored
    in ``state["final_plan"]`` as a plain dict.
    """
    _append_log(state, "Guide Agent: generating visitor-friendly annotations (Groq if available).")
    plan = ItineraryPlan.model_validate(state["raw_plan"])
    request = UserRequest.model_validate(state["user_request"])

    guide_messages = generate_narrative_logs(plan, request)
    combined_logs = list(state.get("logs", [])) + guide_messages
    state["final_plan"] = plan.model_copy(update={"logs": combined_logs}).model_dump()

    _append_log(state, "Guide Agent: added narrative guidance to itinerary logs.")
    return state


def critic_node(state: QiddiyaState) -> QiddiyaState:
    """
    Critic Agent: check the final plan against a fixed list of rules.

    One violation message is collected for each failing rule:

    1. Must-do coverage: a requested must-do (resolved by name) is missing from the stops.

    2. Time window: total_wait_minutes is longer than end_time - start_time.

    3. Walking: total_walking_m is above 3500 + 800 * (walking_tolerance - 1).

    4. Enjoyment: enjoyment_score is below 3.0 (lets reflection add nearby stops).

    While reflection_round < 2, any violation stores the "; "-joined critique and
    increments reflection_round. Otherwise the critique is stored unchanged ("" when
    the plan is clean) and reflection_round is not touched.
    """
    _append_log(state, "Critic Agent: validating constraints and safety.")
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])

    def clock_minutes(clock: str) -> int:
        pieces = clock.split(":")
        return int(pieces[0]) * 60 + int(pieces[1])

    problems: List[str] = []

    # Rule 1: every resolvable must-do appears among the scheduled stops.
    requested = [
        found.id
        for found in map(get_attraction_by_name, req.must_do_attractions)
        if found
    ]
    scheduled = {stop.attraction_id for stop in plan.stops}
    unscheduled = sum(1 for aid in requested if aid not in scheduled)
    if unscheduled:
        problems.append(
            "One must-do attraction is not scheduled."
            if unscheduled == 1
            else f"{unscheduled} must-do attraction(s) are not scheduled."
        )

    # Rule 2: waiting alone must fit inside the visit window.
    window = clock_minutes(req.end_time) - clock_minutes(req.start_time)
    if plan.total_wait_minutes > window:
        problems.append("Total wait time exceeds visit window.")

    # Rule 3: walking cap grows by 800 per tolerance step above 1.
    if plan.total_walking_m > 3500 + 800 * (req.walking_tolerance - 1):
        problems.append("Planned walking distance exceeds user tolerance.")

    # Rule 4: low enjoyment is flagged so reflection can add nearby optional stops.
    if plan.enjoyment_score < 3.0:
        problems.append("Enjoyment score is low.")

    critique = "; ".join(problems)
    rounds_done = int(state.get("reflection_round", 0))
    if problems and rounds_done < 2:
        _append_log(
            state,
            f"Critic Agent: found violations, triggering reflection round {rounds_done + 1}.",
        )
        state["critique"] = critique
        state["reflection_round"] = rounds_done + 1
        return state

    _append_log(
        state,
        "Critic Agent: violations remain but reflection limit reached."
        if problems
        else "Critic Agent: plan accepted with no violations.",
    )
    state["critique"] = critique
    return state


def reflection_node(state: QiddiyaState) -> QiddiyaState:
    """
    Reflection: adjust the plan in response to the Critic's critique.

    At most one strategy runs per call, checked in this order:

    - Enjoyment: applies if the critique mentions "enjoyment" and the plan has
      fewer than 6 stops. It appends the optional attractions nearest (by
      shortest path) to the last stop: enough to reach 5 stops, but at most 3
      per round. Optional means not already planned and, unless
      intensity_preference > 3, thrill_level <= 3. This branch returns early,
      so a walking complaint in the same critique is not handled this round.

    - Walking: otherwise, if the critique mentions "walking"/"distance" and
      walking_tolerance <= 2, drop non-must-do stops whose walking_distance_m
      exceeds 300. Must-do stops are always kept.

    Whenever a critique is present, ``state["refined_attraction_ids"]`` is set.
    route_optimizer_node reads that key to restrict its next pass to those
    attractions.

    Args:
        state: Shared workflow state; reads ``critique``, ``final_plan`` and
            ``user_request``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    critique: Optional[str] = state.get("critique")
    if not critique:
        _append_log(state, "Orchestrator: no critique to address, skipping reflection.")
        return state
    _append_log(
        state,
        f"Orchestrator: applying heuristic reflection based on critique: {critique}",
    )

    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    c_lower = critique.lower()

    must_do_ids_set: set = set()
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids_set.add(attr.id)

    # --- Enjoyment: top the plan up with nearby optional stops ---
    if "enjoyment" in c_lower and len(plan.stops) < 6:
        from .routing import shortest_path_distance

        current_ids = [s.attraction_id for s in plan.stops]
        # Per route_optimizer_node's enjoyment formula, 5 stops score at least 3.0
        # once intensity_preference >= 1; add at most 3 stops per round.
        add_count = min(3, max(0, 5 - len(plan.stops)))
        to_add: List[str] = []
        if add_count:
            planned = set(current_ids)
            last_node = plan.stops[-1].node_id if plan.stops else "HUB"
            optional_ids = [
                a.id
                for a in attractions.values()
                if a.id not in planned
                and (req.intensity_preference > 3 or a.thrill_level <= 3)
            ]
            # Nearest to the last stop first; sorted() is stable, so ties keep catalogue order.
            to_add = sorted(
                optional_ids,
                key=lambda aid: shortest_path_distance(last_node, node_lookup[aid]),
            )[:add_count]
        if to_add:
            state["refined_attraction_ids"] = current_ids + to_add
            state["raw_plan"] = plan.model_dump()
            _append_log(
                state,
                f"Orchestrator: reflection adding {len(to_add)} nearby attraction(s) to boost enjoyment.",
            )
        else:
            state["refined_attraction_ids"] = current_ids
            _append_log(state, "Orchestrator: no optional attractions to add for enjoyment.")
        return state

    # --- Walking: trim optional long-walk stops ---
    # Loop-invariant: only walking complaints from low-tolerance users trigger trimming.
    trim_long_walks = (
        ("walking" in c_lower or "distance" in c_lower)
        and req.walking_tolerance <= 2
    )
    filtered_stops: List[ItineraryStop] = [
        stop
        for stop in plan.stops
        if stop.attraction_id in must_do_ids_set
        or not (trim_long_walks and stop.walking_distance_m > 300)
    ]
    removed = len(plan.stops) - len(filtered_stops)

    final_plan = plan.model_copy(update={"stops": filtered_stops})
    state["raw_plan"] = final_plan.model_dump()
    state["refined_attraction_ids"] = [s.attraction_id for s in filtered_stops]
    # Report the actual number removed (previously claimed removal even when none happened).
    _append_log(
        state,
        f"Orchestrator: reflection adjusted plan to {len(filtered_stops)} stops "
        f"(removed {removed} long-walk stop(s)).",
    )
    return state