Spaces:
Sleeping
Sleeping
File size: 12,201 Bytes
214f910 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 | from __future__ import annotations
import logging
from typing import Dict, List, TypedDict, Optional
from .data_loader import (
load_attractions,
load_wait_history,
get_attraction_by_name,
)
from .models import ItineraryPlan, ItineraryStop, UserRequest
from .routing import greedy_schedule, order_by_nearest_neighbor
from .llm import generate_narrative_logs
# Module logger; _append_log mirrors every workflow log line here at INFO level.
logger = logging.getLogger("qiddiya.agents")
class QiddiyaState(TypedDict, total=False):
    """Shared state dict passed between the planning-workflow nodes.

    ``total=False`` makes every key optional at the type level; the nodes in
    this module fill the keys in as the workflow progresses.

    ``Optional[...]`` is used instead of PEP 604 ``X | None`` so that
    ``typing.get_type_hints`` can resolve these (string, because of the
    module's ``from __future__ import annotations``) annotations on
    Python < 3.10 as well.
    """

    # Raw request payload; nodes validate it with UserRequest.model_validate.
    user_request: Dict
    # Human-readable workflow trace, appended to by _append_log.
    logs: List[str]
    # attraction id -> mean historical wait in minutes (truncated to int).
    wait_time_forecast: Optional[Dict[str, int]]
    # ItineraryPlan.model_dump() produced by the Route Agent or by reflection.
    raw_plan: Optional[Dict]
    # ItineraryPlan.model_dump() including the Guide Agent's narrative logs.
    final_plan: Optional[Dict]
    # "; "-joined Critic violations ("" when the plan was accepted).
    critique: Optional[str]
    # Number of reflection rounds the Critic has triggered so far.
    reflection_round: int
    # After reflection: the only ids the Route Agent should re-schedule
    # (it re-orders them geographically and then clears this list).
    refined_attraction_ids: Optional[List[str]]
def _append_log(state: QiddiyaState, message: str) -> None:
    """Add *message* to ``state["logs"]`` and echo it to the module logger.

    The ``logs`` list is created on first use; afterwards the existing list
    is appended to in place.
    """
    state.setdefault("logs", []).append(message)
    logger.info(message)
def orchestrator_node(state: QiddiyaState) -> QiddiyaState:
    """Orchestrator: record the start of the planning workflow.

    Only appends a log line; the state is otherwise passed through unchanged.
    """
    start_message = "Orchestrator: starting planning workflow."
    _append_log(state, start_message)
    return state
def wait_time_predictor_node(state: QiddiyaState) -> QiddiyaState:
    """Wait-Time Agent: forecast each attraction's wait as its historical mean.

    Averages the ``wait_minutes`` column of ``load_wait_history()`` per
    ``attraction_id`` and stores ``{attraction_id: int(mean)}`` (truncated
    toward zero) in ``state["wait_time_forecast"]``.

    Rows with a missing ``attraction_id`` are ignored, and attractions with no
    usable ``wait_minutes`` value get no forecast entry. The earlier
    per-attraction loop crashed on ``int(NaN)`` in these cases.

    NOTE(review): assumes ``load_wait_history()`` returns a pandas DataFrame
    (inferred from the column/groupby access) — confirm against data_loader.
    """
    _append_log(state, "Wait-Time Agent: computing simple demand-based forecast.")
    history = load_wait_history()
    # One grouped pass instead of re-filtering the whole table per attraction.
    # sort=False keeps first-appearance order, as Series.unique() did.
    mean_waits = (
        history.groupby("attraction_id", sort=False)["wait_minutes"].mean().dropna()
    )
    forecast: Dict[str, int] = {
        attr_id: int(mean_wait) for attr_id, mean_wait in mean_waits.items()
    }
    state["wait_time_forecast"] = forecast
    _append_log(
        state,
        f"Wait-Time Agent: produced forecast for {len(forecast)} attractions.",
    )
    return state
def route_optimizer_node(state: QiddiyaState) -> QiddiyaState:
    """Route Agent: build a timed itinerary and store it in ``state["raw_plan"]``.

    Candidate selection, in priority order:

    1. ``state["refined_attraction_ids"]`` (set by the reflection node): only
       ids present in the attraction catalogue are kept, and the list is
       cleared afterwards so it is consumed by exactly one pass.
    2. The user's must-do attractions, resolved by name.
    3. When no must-do resolves: up to 8 suggestions filtered by intensity.
       Users with ``intensity_preference <= 3`` are only offered attractions
       with ``thrill_level <= 3``; higher-intensity users may get any
       attraction. This is the same rule the reflection node uses for its
       optional picks. The previous condition was inverted and gave mild-ride
       users every ride.

    Candidates are re-ordered with ``order_by_nearest_neighbor`` (by
    geography, not by the user's selection order) and timed with
    ``greedy_schedule`` from ``req.start_time``.
    """
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}

    # Resolve must-do names once: used for candidate selection and for the
    # coverage score / is_suggested flags below.
    must_do_ids: List[str] = []
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids.append(attr.id)
    must_do_set = set(must_do_ids)

    refined = state.get("refined_attraction_ids") or []
    if refined:
        _append_log(
            state,
            f"Route Agent: re-building schedule from refined list ({len(refined)} attractions) after reflection.",
        )
        candidate = [aid for aid in refined if aid in attractions]
        state["refined_attraction_ids"] = []  # consume once; don't persist for future runs
    else:
        _append_log(state, "Route Agent: building schedule by geography (nearest-neighbor from hub).")
        if must_do_ids:
            # Only schedule what the user asked for: no extra attractions.
            # Pass a copy so the ordering helper cannot disturb must_do_ids.
            candidate = list(must_do_ids)
        else:
            candidate = [
                a.id
                for a in attractions.values()
                if req.intensity_preference > 3 or a.thrill_level <= 3
            ][:8]  # cap suggestions when nothing selected
    # Order by space (minimize walking), not by user selection order.
    ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)

    wait_lookup = state.get("wait_time_forecast") or {}
    # Use only the first two ":"-separated fields, as critic_node does, so
    # both "HH:MM" and "HH:MM:SS" are accepted.
    time_parts = req.start_time.split(":")
    start_minutes = int(time_parts[0]) * 60 + int(time_parts[1])
    stops_raw, total_wait, total_walk = greedy_schedule(
        attractions_order=ordered_ids,
        start_time_minutes=start_minutes,
        # 6 - tolerance: a lower walking tolerance gives a larger walking_weight.
        walking_weight=float(6 - req.walking_tolerance),
        wait_time_lookup=wait_lookup,
        node_lookup=node_lookup,
    )

    stops = [ItineraryStop.model_validate(s).model_dump() for s in stops_raw]
    # NOTE(review): this flag survives the re-validation below only if
    # ItineraryStop declares (or allows) an is_suggested field — confirm.
    for s in stops:
        s["is_suggested"] = s["attraction_id"] not in must_do_set
    enjoyment = min(10.0, float(len(stops)) * (0.5 + 0.1 * req.intensity_preference))
    covered = sum(1 for s in stops if s.get("attraction_id") in must_do_set)
    coverage = float(covered) / max(len(req.must_do_attractions), 1)

    plan = ItineraryPlan(
        visit_date=req.visit_date,
        total_wait_minutes=total_wait,
        total_walking_m=total_walk,
        coverage_score=coverage,
        enjoyment_score=enjoyment,
        stops=[ItineraryStop.model_validate(s) for s in stops],
        logs=state.get("logs", []),
    )
    state["raw_plan"] = plan.model_dump()
    _append_log(
        state,
        f"Route Agent: produced raw plan with {len(stops)} stops, "
        f"wait={total_wait}min, walk={total_walk}m.",
    )
    return state
def experience_writer_node(state: QiddiyaState) -> QiddiyaState:
    """Guide Agent: attach narrative guidance to the Route Agent's plan.

    Validates ``state["raw_plan"]`` and the user request, asks
    ``generate_narrative_logs`` for narrative lines, and stores a copy of the
    plan in ``state["final_plan"]``. The copy's ``logs`` are the current
    workflow logs followed by the narrative lines.
    """
    _append_log(state, "Guide Agent: generating visitor-friendly annotations (Groq if available).")
    draft_plan = ItineraryPlan.model_validate(state["raw_plan"])
    user_req = UserRequest.model_validate(state["user_request"])
    narration = generate_narrative_logs(draft_plan, user_req)

    existing_logs = list(state.get("logs", []))
    merged_logs = existing_logs + narration
    annotated_plan = draft_plan.model_copy(update={"logs": merged_logs})
    state["final_plan"] = annotated_plan.model_dump()

    # Logged after the snapshot above, so this line is not in final_plan's logs.
    _append_log(
        state,
        "Guide Agent: added narrative guidance to itinerary logs.",
    )
    return state
def critic_node(state: QiddiyaState) -> QiddiyaState:
    """Critic Agent: check ``state["final_plan"]`` against fixed criteria.

    Violations are collected in this order:

    1. Must-do coverage: a resolved must-do attraction id is missing from
       the plan's stops.
    2. Time window: ``total_wait_minutes`` exceeds ``end_time - start_time``
       (both parsed as "HH:MM"; only the first two fields are read).
    3. Walking: ``total_walking_m`` exceeds
       ``3500 + 800 * (walking_tolerance - 1)`` metres.
    4. Enjoyment: ``enjoyment_score`` is below 3.0.

    When there is at least one violation and ``reflection_round < 2``, the
    joined violations go into ``state["critique"]`` and the round counter is
    incremented. Otherwise ``critique`` holds the joined violations, or ""
    when there are none, and the counter is left alone.
    """
    _append_log(state, "Critic Agent: validating constraints and safety.")
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])

    resolved = (get_attraction_by_name(name) for name in req.must_do_attractions)
    required_ids = [found.id for found in resolved if found]
    scheduled = {stop.attraction_id for stop in plan.stops}
    missing_count = sum(1 for rid in required_ids if rid not in scheduled)

    problems: List[str] = []
    if missing_count == 1:
        problems.append("One must-do attraction is not scheduled.")
    elif missing_count > 1:
        problems.append(f"{missing_count} must-do attraction(s) are not scheduled.")

    def _clock_minutes(clock: str) -> int:
        fields = clock.split(":")
        return int(fields[0]) * 60 + int(fields[1])

    visit_window = _clock_minutes(req.end_time) - _clock_minutes(req.start_time)
    if plan.total_wait_minutes > visit_window:
        problems.append("Total wait time exceeds visit window.")

    # Walking cap grows by 800 m per tolerance step above 1.
    walking_cap = 3500 + 800 * (req.walking_tolerance - 1)
    if plan.total_walking_m > walking_cap:
        problems.append("Planned walking distance exceeds user tolerance.")

    # Low enjoyment also triggers reflection (which may add optional stops).
    if plan.enjoyment_score < 3.0:
        problems.append("Enjoyment score is low.")

    current_round = int(state.get("reflection_round", 0))
    if problems and current_round < 2:
        _append_log(
            state,
            f"Critic Agent: found violations, triggering reflection round {current_round + 1}.",
        )
        state["critique"] = "; ".join(problems)
        state["reflection_round"] = current_round + 1
        return state

    if problems:
        _append_log(
            state,
            "Critic Agent: violations remain but reflection limit reached.",
        )
    else:
        _append_log(state, "Critic Agent: plan accepted with no violations.")
    state["critique"] = "; ".join(problems)
    return state
def reflection_node(state: QiddiyaState) -> QiddiyaState:
    """Reflection: adjust the plan based on the Critic's critique.

    Reads ``state["critique"]`` and ``state["final_plan"]``. Writes
    ``state["raw_plan"]`` and ``state["refined_attraction_ids"]``, the ids the
    Route Agent re-schedules on its next pass.

    - Critique mentions "enjoyment", does NOT also flag walking/distance, and
      the plan has fewer than 6 stops: append the optional attractions nearest
      to the last stop. At most 3 are added, and only as many as needed to
      reach 5 stops.
    - Otherwise: keep every must-do stop. When the critique mentions
      "walking"/"distance" and ``walking_tolerance <= 2``, also drop optional
      stops whose walking leg exceeds 300 m.

    Walking takes priority over enjoyment: adding stops can only increase
    walking, so boosting enjoyment while the walking limit is already exceeded
    would make that violation worse.
    """
    critique: Optional[str] = state.get("critique")
    if not critique:
        _append_log(state, "Orchestrator: no critique to address, skipping reflection.")
        return state
    _append_log(
        state,
        f"Orchestrator: applying heuristic reflection based on critique: {critique}",
    )
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    c_lower = critique.lower()
    # Either keyword counts as a walking complaint (critic_node's walking
    # message contains both).
    walking_flagged = "walking" in c_lower or "distance" in c_lower
    must_do_ids_set: set = set()
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids_set.add(attr.id)

    # --- Enjoyment: add nearby optional stops to boost the plan ---
    if "enjoyment" in c_lower and not walking_flagged and len(plan.stops) < 6:
        from .routing import shortest_path_distance

        current_ids = [s.attraction_id for s in plan.stops]
        current_set = set(current_ids)
        last_node = plan.stops[-1].node_id if plan.stops else "HUB"
        # Low-intensity users (preference <= 3) are only offered attractions
        # with thrill_level <= 3.
        optional_ids = [
            a.id
            for a in attractions.values()
            if a.id not in current_set
            and (req.intensity_preference > 3 or a.thrill_level <= 3)
        ]
        # Nearest optional attractions from the last stop first.
        by_dist = sorted(
            optional_ids,
            key=lambda aid: shortest_path_distance(last_node, node_lookup[aid]),
        )
        # Add just enough to reach 5 stops (at intensity_preference=1, 5 stops
        # give enjoyment 3.0), never more than 3 at once.
        need = max(0, 5 - len(plan.stops))
        to_add = by_dist[: min(3, need)]
        if to_add:
            state["refined_attraction_ids"] = current_ids + to_add
            state["raw_plan"] = plan.model_dump()
            _append_log(
                state,
                f"Orchestrator: reflection adding {len(to_add)} nearby attraction(s) to boost enjoyment.",
            )
        else:
            state["refined_attraction_ids"] = current_ids
            _append_log(state, "Orchestrator: no optional attractions to add for enjoyment.")
        return state

    # --- Walking: trim optional long-walk stops ---
    # Loop-invariant: trimming only applies to walking complaints from
    # low-tolerance users; must-do stops are always kept.
    trim_long_walks = walking_flagged and req.walking_tolerance <= 2
    filtered_stops: List[ItineraryStop] = [
        stop
        for stop in plan.stops
        if stop.attraction_id in must_do_ids_set
        or not (trim_long_walks and stop.walking_distance_m > 300)
    ]
    final_plan = plan.model_copy(update={"stops": filtered_stops})
    state["raw_plan"] = final_plan.model_dump()
    state["refined_attraction_ids"] = [s.attraction_id for s in filtered_stops]
    # NOTE: this message is emitted even when no stop was actually removed.
    _append_log(
        state,
        f"Orchestrator: reflection adjusted plan to {len(filtered_stops)} stops (removed long-walk stops).",
    )
    return state
|