from __future__ import annotations

import logging
from typing import Dict, List, Optional, TypedDict

from .data_loader import (
    load_attractions,
    load_wait_history,
    get_attraction_by_name,
)
from .models import ItineraryPlan, ItineraryStop, UserRequest
from .routing import greedy_schedule, order_by_nearest_neighbor, shortest_path_distance
from .llm import generate_narrative_logs

logger = logging.getLogger("qiddiya.agents")

# Critic: number of reflection rounds allowed before remaining violations are accepted.
_MAX_REFLECTION_ROUNDS = 2
# Critic: plans whose enjoyment_score is below this trigger reflection.
_MIN_ENJOYMENT_SCORE = 3.0
# Route Agent: how many intensity-matched suggestions to schedule when the
# user selected no (resolvable) must-do attractions.
_SUGGESTION_CAP = 8
# Reflection (enjoyment): only boost plans with fewer stops than this,
_ENJOYMENT_BOOST_MAX_STOPS = 6
# aiming for this many stops (5 stops => enjoyment >= 3.0 even at low intensity),
_ENJOYMENT_TARGET_STOPS = 5
# while adding at most this many attractions in one reflection round.
_ENJOYMENT_MAX_ADDITIONS = 3
# Reflection (walking): optional stops with walking_distance_m above this are trimmed.
_LONG_WALK_THRESHOLD_M = 300


class QiddiyaState(TypedDict, total=False):
    """Shared state passed between the planning agents (every key is optional)."""

    user_request: Dict  # serialized UserRequest; each node re-validates it
    logs: List[str]  # human-readable trace, appended to by _append_log
    wait_time_forecast: Dict[str, int] | None  # attraction_id -> forecast wait (minutes)
    raw_plan: Dict | None  # serialized ItineraryPlan produced by the Route Agent
    final_plan: Dict | None  # raw_plan with narrative logs added by the Guide Agent
    critique: str | None  # "; "-joined Critic violations ("" when none)
    reflection_round: int  # number of reflection rounds requested so far
    refined_attraction_ids: List[str] | None  # after reflection: only schedule these (in order)


def _append_log(state: QiddiyaState, message: str) -> None:
    """Append *message* to ``state["logs"]`` (creating the list if needed) and log it."""
    logs = state.get("logs", [])
    logs.append(message)
    state["logs"] = logs
    logger.info(message)


def _hhmm_to_minutes(value: str) -> int:
    """Convert an ``"HH:MM"`` clock string to minutes after midnight.

    Any further ``":SS"`` component is ignored.
    """
    hours, minutes = value.split(":")[:2]
    return int(hours) * 60 + int(minutes)


def _resolve_must_do_ids(req: UserRequest) -> List[str]:
    """Map the user's must-do attraction names to attraction ids.

    Names that ``get_attraction_by_name`` cannot resolve are skipped.
    Duplicates are dropped, keeping the first-requested order.
    """
    ids: List[str] = []
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr and attr.id not in ids:
            ids.append(attr.id)
    return ids


def _matches_intensity(req: UserRequest, attraction) -> bool:
    """Return True if *attraction* suits the user's intensity preference.

    Users with ``intensity_preference`` above 3 may be offered any attraction.
    Everyone else is only offered attractions with ``thrill_level <= 3``.
    """
    return req.intensity_preference > 3 or attraction.thrill_level <= 3


def orchestrator_node(state: QiddiyaState) -> QiddiyaState:
    """Orchestrator: log the start of the planning workflow."""
    _append_log(state, "Orchestrator: starting planning workflow.")
    return state


def wait_time_predictor_node(state: QiddiyaState) -> QiddiyaState:
    """Wait-Time Agent: forecast each attraction's wait as its historical mean.

    Stores ``{attraction_id: int(mean wait_minutes)}`` in
    ``state["wait_time_forecast"]``.
    """
    _append_log(state, "Wait-Time Agent: computing simple demand-based forecast.")
    # NOTE(review): assumes load_wait_history() returns a pandas DataFrame with
    # "attraction_id" and "wait_minutes" columns — confirm in data_loader.
    history = load_wait_history()
    # One grouped pass instead of re-filtering the whole table per attraction;
    # sort=False keeps first-appearance order like Series.unique().
    means = history.groupby("attraction_id", sort=False)["wait_minutes"].mean()
    forecast: Dict[str, int] = {attr_id: int(mean) for attr_id, mean in means.items()}
    state["wait_time_forecast"] = forecast
    _append_log(
        state,
        f"Wait-Time Agent: produced forecast for {len(forecast)} attractions.",
    )
    return state


def _select_attraction_ids(
    state: QiddiyaState,
    req: UserRequest,
    attractions,
    node_lookup: Dict,
    must_do_ids: List[str],
) -> List[str]:
    """Choose which attractions to schedule and order them geographically.

    Sources, in priority order:
      1. A non-empty ``state["refined_attraction_ids"]`` from reflection. It is
         consumed here by resetting it to ``[]``.
      2. The user's resolved must-do attractions.
      3. Up to ``_SUGGESTION_CAP`` attractions matching the intensity preference.

    The chosen ids are ordered with ``order_by_nearest_neighbor`` rather than in
    selection order, to reduce walking.
    """
    refined = state.get("refined_attraction_ids") or []
    if refined:
        _append_log(
            state,
            f"Route Agent: re-building schedule from refined list ({len(refined)} attractions) after reflection.",
        )
        candidate = [aid for aid in refined if aid in attractions]
        # Clear so the refined list does not persist into future runs.
        state["refined_attraction_ids"] = []
    else:
        _append_log(state, "Route Agent: building schedule by geography (nearest-neighbor from hub).")
        if must_do_ids:
            # Only schedule what the user asked for: no extra attractions.
            candidate = list(must_do_ids)
        else:
            # Nothing selected: suggest a capped set suited to the user's intensity
            # (same rule reflection uses when adding optional stops).
            candidate = [
                a.id for a in attractions.values() if _matches_intensity(req, a)
            ][:_SUGGESTION_CAP]
    return order_by_nearest_neighbor(candidate, node_lookup)


def route_optimizer_node(state: QiddiyaState) -> QiddiyaState:
    """Route Agent: build a timed itinerary and store it in ``state["raw_plan"]``.

    Attractions come from ``_select_attraction_ids``. They are scheduled with
    ``greedy_schedule`` from ``req.start_time``, using the wait-time forecast
    (empty if missing) and a walking weight of ``6 - walking_tolerance``.

    Each stop is flagged ``is_suggested`` when it is not a resolved must-do.
    ``coverage_score`` is the number of scheduled must-do stops divided by the
    number of requested must-do names (at least 1).
    """
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    must_do_ids = _resolve_must_do_ids(req)
    must_do_set = set(must_do_ids)

    ordered_ids = _select_attraction_ids(state, req, attractions, node_lookup, must_do_ids)

    stops_raw, total_wait, total_walk = greedy_schedule(
        attractions_order=ordered_ids,
        start_time_minutes=_hhmm_to_minutes(req.start_time),
        walking_weight=float(6 - req.walking_tolerance),
        wait_time_lookup=state.get("wait_time_forecast") or {},
        node_lookup=node_lookup,
    )
    stops = [ItineraryStop.model_validate(s).model_dump() for s in stops_raw]
    for stop in stops:
        stop["is_suggested"] = stop["attraction_id"] not in must_do_set

    # Enjoyment grows with the number of stops, weighted by intensity, capped at 10.
    enjoyment = min(10.0, float(len(stops)) * (0.5 + 0.1 * req.intensity_preference))
    covered_must_dos = sum(1 for s in stops if s.get("attraction_id") in must_do_set)
    coverage = float(covered_must_dos) / max(len(req.must_do_attractions), 1)

    plan = ItineraryPlan(
        visit_date=req.visit_date,
        total_wait_minutes=total_wait,
        total_walking_m=total_walk,
        coverage_score=coverage,
        enjoyment_score=enjoyment,
        stops=[ItineraryStop.model_validate(s) for s in stops],
        logs=state.get("logs", []),
    )
    state["raw_plan"] = plan.model_dump()
    _append_log(
        state,
        f"Route Agent: produced raw plan with {len(stops)} stops, "
        f"wait={total_wait}min, walk={total_walk}m.",
    )
    return state


def experience_writer_node(state: QiddiyaState) -> QiddiyaState:
    """Guide Agent: add narrative guidance to the raw plan.

    ``state["final_plan"]`` becomes ``raw_plan`` whose logs are the current
    state logs followed by ``generate_narrative_logs(raw_plan, req)``.
    """
    _append_log(state, "Guide Agent: generating visitor-friendly annotations (Groq if available).")
    raw_plan = ItineraryPlan.model_validate(state["raw_plan"])
    req = UserRequest.model_validate(state["user_request"])
    narrative_logs = generate_narrative_logs(raw_plan, req)
    annotated_logs = list(state.get("logs", [])) + narrative_logs
    final_plan = raw_plan.model_copy(update={"logs": annotated_logs})
    state["final_plan"] = final_plan.model_dump()
    _append_log(
        state,
        "Guide Agent: added narrative guidance to itinerary logs.",
    )
    return state


def critic_node(state: QiddiyaState) -> QiddiyaState:
    """
    Critic Agent: validates the plan against fixed criteria.

    ``state["critique"]`` is set to the violations joined with "; " ("" when
    there are none). When there is at least one violation and
    ``reflection_round < 2``, ``reflection_round`` is incremented to request
    another reflection round.

    Criteria (violations):
    1. Must-do coverage: every resolvable must-do attraction must appear in the plan.
    2. Time window: total_wait_minutes must not exceed (end_time - start_time).
    3. Walking tolerance: total_walking_m must not exceed max_walk
       (max_walk = 3500 + 800 * (walking_tolerance - 1)).
    4. Enjoyment: enjoyment_score must be at least 3.0 (triggers reflection to add nearby stops).
    """
    _append_log(state, "Critic Agent: validating constraints and safety.")
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    violations: List[str] = []

    covered = {s.attraction_id for s in plan.stops}
    missing_must_do = [mid for mid in _resolve_must_do_ids(req) if mid not in covered]
    if missing_must_do:
        n = len(missing_must_do)
        violations.append(
            f"{n} must-do attraction(s) are not scheduled."
            if n > 1
            else "One must-do attraction is not scheduled."
        )

    window_minutes = _hhmm_to_minutes(req.end_time) - _hhmm_to_minutes(req.start_time)
    if plan.total_wait_minutes > window_minutes:
        violations.append("Total wait time exceeds visit window.")

    # Walking cap: total walking must not exceed the tolerance-based limit.
    max_walk = 3500 + 800 * (req.walking_tolerance - 1)
    if plan.total_walking_m > max_walk:
        violations.append("Planned walking distance exceeds user tolerance.")

    # Low enjoyment: trigger reflection so nearby optional stops can be added.
    if plan.enjoyment_score < _MIN_ENJOYMENT_SCORE:
        violations.append("Enjoyment score is low.")

    reflection_round = int(state.get("reflection_round", 0))
    if violations and reflection_round < _MAX_REFLECTION_ROUNDS:
        _append_log(
            state,
            f"Critic Agent: found violations, triggering reflection round {reflection_round + 1}.",
        )
        state["reflection_round"] = reflection_round + 1
    elif violations:
        _append_log(
            state,
            "Critic Agent: violations remain but reflection limit reached.",
        )
    else:
        _append_log(state, "Critic Agent: plan accepted with no violations.")
    state["critique"] = "; ".join(violations)
    return state


def _add_nearby_stops(state: QiddiyaState, plan: ItineraryPlan, req: UserRequest) -> None:
    """Queue extra nearby optional attractions for the next scheduling pass.

    Candidates are attractions not already in the plan that match the user's
    intensity preference. They are ranked by ``shortest_path_distance`` from
    the last stop's node, or ``"HUB"`` when the plan is empty.

    Enough are taken to reach ``_ENJOYMENT_TARGET_STOPS`` stops, but no more
    than ``_ENJOYMENT_MAX_ADDITIONS`` per round. The new ids are appended after
    the existing stops in ``state["refined_attraction_ids"]``.
    """
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    current_ids = [s.attraction_id for s in plan.stops]
    current_set = set(current_ids)
    last_node = plan.stops[-1].node_id if plan.stops else "HUB"

    need = max(0, _ENJOYMENT_TARGET_STOPS - len(plan.stops))
    add_count = min(_ENJOYMENT_MAX_ADDITIONS, need)
    to_add: List[str] = []
    if add_count > 0:
        optional_ids = [
            a.id
            for a in attractions.values()
            if a.id not in current_set and _matches_intensity(req, a)
        ]
        to_add = sorted(
            optional_ids,
            key=lambda aid: shortest_path_distance(last_node, node_lookup[aid]),
        )[:add_count]

    if to_add:
        state["refined_attraction_ids"] = current_ids + to_add
        state["raw_plan"] = plan.model_dump()
        _append_log(
            state,
            f"Orchestrator: reflection adding {len(to_add)} nearby attraction(s) to boost enjoyment.",
        )
    else:
        state["refined_attraction_ids"] = current_ids
        _append_log(state, "Orchestrator: no optional attractions to add for enjoyment.")


def _trim_long_walk_stops(
    state: QiddiyaState, plan: ItineraryPlan, req: UserRequest, critique_lower: str
) -> None:
    """Drop optional long-walk stops and queue the rest for re-scheduling.

    Must-do stops are always kept. Optional stops are removed only when all of
    the following hold:
      - the critique mentions "walking" or "distance";
      - ``walking_tolerance <= 2``;
      - the stop's ``walking_distance_m`` exceeds ``_LONG_WALK_THRESHOLD_M``.
    """
    must_do_ids = set(_resolve_must_do_ids(req))
    trim_long_walks = (
        ("walking" in critique_lower or "distance" in critique_lower)
        and req.walking_tolerance <= 2
    )
    filtered_stops: List[ItineraryStop] = [
        stop
        for stop in plan.stops
        if stop.attraction_id in must_do_ids
        or not (trim_long_walks and stop.walking_distance_m > _LONG_WALK_THRESHOLD_M)
    ]
    final_plan = plan.model_copy(update={"stops": filtered_stops})
    state["raw_plan"] = final_plan.model_dump()
    state["refined_attraction_ids"] = [s.attraction_id for s in filtered_stops]
    _append_log(
        state,
        f"Orchestrator: reflection adjusted plan to {len(filtered_stops)} stops (removed long-walk stops).",
    )


def reflection_node(state: QiddiyaState) -> QiddiyaState:
    """
    Reflection: adjusts the plan based on the Critic's critique.

    - If the critique mentions "enjoyment" and the plan has fewer than 6 stops:
      queue up to 3 nearby optional attractions (aiming for 5 stops) and stop there.
    - Otherwise: if the critique mentions "walking"/"distance" and
      walking_tolerance <= 2, trim optional stops walking more than 300 m.

    Both paths set ``state["refined_attraction_ids"]``, the list that
    ``route_optimizer_node`` schedules on its next pass.
    """
    critique: Optional[str] = state.get("critique")
    if not critique:
        _append_log(state, "Orchestrator: no critique to address, skipping reflection.")
        return state

    _append_log(
        state,
        f"Orchestrator: applying heuristic reflection based on critique: {critique}",
    )
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    critique_lower = critique.lower()

    if "enjoyment" in critique_lower and len(plan.stops) < _ENJOYMENT_BOOST_MAX_STOPS:
        _add_nearby_stops(state, plan, req)
        return state

    _trim_long_walk_stops(state, plan, req, critique_lower)
    return state