# NOTE(review): stray page-scrape artifact ("Spaces: / Sleeping / Sleeping")
# removed; this module defines the Qiddiya itinerary-planning agent nodes.
| from __future__ import annotations | |
| import logging | |
| from typing import Dict, List, TypedDict, Optional | |
| from .data_loader import ( | |
| load_attractions, | |
| load_wait_history, | |
| get_attraction_by_name, | |
| ) | |
| from .models import ItineraryPlan, ItineraryStop, UserRequest | |
| from .routing import greedy_schedule, order_by_nearest_neighbor | |
| from .llm import generate_narrative_logs | |
# Module-level logger shared by every agent node in this workflow.
logger = logging.getLogger("qiddiya.agents")
class QiddiyaState(TypedDict, total=False):
    """Mutable state passed between the workflow's agent nodes (all keys optional)."""

    user_request: Dict  # serialized UserRequest payload (validated by each node)
    logs: List[str]  # accumulated human-readable progress messages
    wait_time_forecast: Dict[str, int] | None  # attraction_id -> mean wait minutes
    raw_plan: Dict | None  # ItineraryPlan dump before narrative annotation
    final_plan: Dict | None  # ItineraryPlan dump after the guide agent runs
    critique: str | None  # semicolon-joined violation list from the critic
    reflection_round: int  # number of reflection passes performed so far
    refined_attraction_ids: List[str] | None  # after reflection: only schedule these (in order)
def _append_log(state: QiddiyaState, message: str) -> None:
    """Record *message* in the state's running log list and echo it to the module logger."""
    state.setdefault("logs", []).append(message)
    logger.info(message)
def orchestrator_node(state: QiddiyaState) -> QiddiyaState:
    """Workflow entry point: logs that planning has started and passes state through."""
    _append_log(state, "Orchestrator: starting planning workflow.")
    return state
def wait_time_predictor_node(state: QiddiyaState) -> QiddiyaState:
    """Forecast each attraction's wait as its historical mean, truncated to whole minutes."""
    _append_log(state, "Wait-Time Agent: computing simple demand-based forecast.")
    history = load_wait_history()
    # One entry per attraction, in first-appearance order of the history rows.
    forecast: Dict[str, int] = {
        attraction_id: int(
            history.loc[history["attraction_id"] == attraction_id, "wait_minutes"].mean()
        )
        for attraction_id in history["attraction_id"].unique()
    }
    state["wait_time_forecast"] = forecast
    _append_log(
        state,
        f"Wait-Time Agent: produced forecast for {len(forecast)} attractions.",
    )
    return state
def route_optimizer_node(state: QiddiyaState) -> QiddiyaState:
    """Build a geographically ordered itinerary and store it in ``state["raw_plan"]``.

    Candidate selection:
      * If a reflection pass left ``refined_attraction_ids``, only those are
        scheduled (re-ordered by nearest-neighbor) and the list is cleared.
      * Otherwise the user's must-do attractions are scheduled; with no
        must-dos at all, up to 8 intensity-appropriate attractions are suggested.

    Fixes vs previous revision:
      * Must-do name->id resolution is done once (was duplicated for the
        candidate list and the coverage computation).
      * The intensity filter was inverted (``intensity_preference <= 3`` gave
        mild users every ride and thrill-seekers only mild ones); it now
        matches the filter used in reflection_node.
    """
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}

    # Resolve must-do names to ids once; reused for candidates and coverage.
    must_do_ids: List[str] = []
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids.append(attr.id)

    # If reflection trimmed the plan, only schedule those attractions (in geographic order) this pass
    refined = state.get("refined_attraction_ids") or []
    if refined:
        _append_log(
            state,
            f"Route Agent: re-building schedule from refined list ({len(refined)} attractions) after reflection.",
        )
        candidate = [aid for aid in refined if aid in attractions]
        ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)
        state["refined_attraction_ids"] = []  # clear so we don't persist for future runs
    else:
        _append_log(state, "Route Agent: building schedule by geography (nearest-neighbor from hub).")
        # Only schedule what the user asked for: no extra attractions
        if must_do_ids:
            candidate = must_do_ids
        else:
            # No must-dos selected: suggest a few by intensity. Thrill-seekers
            # (preference > 3) may get anything; otherwise only mild rides
            # (thrill_level <= 3). Mirrors the optional-stop filter in
            # reflection_node.
            other_ids = [
                a.id for a in attractions.values()
                if req.intensity_preference > 3 or a.thrill_level <= 3
            ]
            candidate = other_ids[:8]  # cap suggestions when nothing selected
        # Order by space (minimize walking), not by user selection order
        ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)

    wait_lookup = state.get("wait_time_forecast") or {}
    # start_time is "HH:MM"; convert to minutes past midnight for the scheduler.
    start_hour, start_minute = [int(p) for p in req.start_time.split(":")]
    start_minutes = start_hour * 60 + start_minute
    stops_raw, total_wait, total_walk = greedy_schedule(
        attractions_order=ordered_ids,
        start_time_minutes=start_minutes,
        # Lower walking tolerance -> heavier penalty on walking edges.
        walking_weight=float(6 - req.walking_tolerance),
        wait_time_lookup=wait_lookup,
        node_lookup=node_lookup,
    )
    stops = [ItineraryStop.model_validate(s).model_dump() for s in stops_raw]
    # Simple heuristic: more stops and higher intensity preference -> more fun, capped at 10.
    enjoyment = min(10.0, float(len(stops)) * (0.5 + 0.1 * req.intensity_preference))

    must_do_set = set(must_do_ids)
    for s in stops:
        # Mark stops the user did not explicitly request.
        s["is_suggested"] = s["attraction_id"] not in must_do_set
    # Fraction of requested must-dos that made it into the schedule.
    coverage = float(len([s for s in stops if s.get("attraction_id") in must_do_set])) / max(len(req.must_do_attractions), 1)
    plan = ItineraryPlan(
        visit_date=req.visit_date,
        total_wait_minutes=total_wait,
        total_walking_m=total_walk,
        coverage_score=coverage,
        enjoyment_score=enjoyment,
        stops=[ItineraryStop.model_validate(s) for s in stops],
        logs=state.get("logs", []),
    )
    state["raw_plan"] = plan.model_dump()
    _append_log(
        state,
        f"Route Agent: produced raw plan with {len(stops)} stops, "
        f"wait={total_wait}min, walk={total_walk}m.",
    )
    return state
def experience_writer_node(state: QiddiyaState) -> QiddiyaState:
    """Augment the raw plan's logs with visitor-facing narrative text and store the result."""
    _append_log(state, "Guide Agent: generating visitor-friendly annotations (Groq if available).")
    plan = ItineraryPlan.model_validate(state["raw_plan"])
    request = UserRequest.model_validate(state["user_request"])
    # Existing workflow logs first, then the generated narrative entries.
    combined_logs = list(state.get("logs", [])) + generate_narrative_logs(plan, request)
    state["final_plan"] = plan.model_copy(update={"logs": combined_logs}).model_dump()
    _append_log(
        state,
        "Guide Agent: added narrative guidance to itinerary logs.",
    )
    return state
def critic_node(state: QiddiyaState) -> QiddiyaState:
    """
    Critic Agent: checks the final plan against fixed acceptance criteria.

    A reflection pass is triggered only when at least one violation exists
    and fewer than two reflection rounds have already run.

    Checked criteria (each failure appends one violation string):
      1. Coverage  - every requested must-do attraction appears in the plan.
      2. Time      - total_wait_minutes fits inside (end_time - start_time).
      3. Walking   - total_walking_m stays under 3500 + 800*(walking_tolerance-1).
      4. Enjoyment - enjoyment_score >= 3.0 (low scores prompt adding nearby stops).
    """
    _append_log(state, "Critic Agent: validating constraints and safety.")
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    violations: List[str] = []

    # 1. Every resolvable must-do must be among the scheduled stops.
    must_do_ids = [
        attr.id
        for attr in (get_attraction_by_name(name) for name in req.must_do_attractions)
        if attr
    ]
    scheduled = {stop.attraction_id for stop in plan.stops}
    missing_must_do = [mid for mid in must_do_ids if mid not in scheduled]
    if missing_must_do:
        n = len(missing_must_do)
        if n > 1:
            violations.append(f"{n} must-do attraction(s) are not scheduled.")
        else:
            violations.append("One must-do attraction is not scheduled.")

    # 2. Total waiting must fit inside the visit window ("HH:MM" strings).
    end_parts = req.end_time.split(":")
    start_parts = req.start_time.split(":")
    window_minutes = (
        int(end_parts[0]) * 60 + int(end_parts[1])
        - int(start_parts[0]) * 60 - int(start_parts[1])
    )
    if plan.total_wait_minutes > window_minutes:
        violations.append("Total wait time exceeds visit window.")

    # 3. Walking cap derived from the user's tolerance setting.
    max_walk = 3500 + 800 * (req.walking_tolerance - 1)
    if plan.total_walking_m > max_walk:
        violations.append("Planned walking distance exceeds user tolerance.")

    # 4. Low enjoyment triggers reflection so nearby optional stops can be added.
    if plan.enjoyment_score < 3.0:
        violations.append("Enjoyment score is low.")

    reflection_round = int(state.get("reflection_round", 0))
    if violations and reflection_round < 2:
        _append_log(
            state,
            f"Critic Agent: found violations, triggering reflection round {reflection_round + 1}.",
        )
        state["critique"] = "; ".join(violations)
        state["reflection_round"] = reflection_round + 1
    elif violations:
        # Violations remain but the reflection budget is spent: report, don't loop.
        _append_log(
            state,
            "Critic Agent: violations remain but reflection limit reached.",
        )
        state["critique"] = "; ".join(violations)
    else:
        _append_log(state, "Critic Agent: plan accepted with no violations.")
        state["critique"] = ""
    return state
def reflection_node(state: QiddiyaState) -> QiddiyaState:
    """
    Reflection: adjusts the plan based on the Critic's critique.

    - If the critique mentions "enjoyment" (and the plan has fewer than 6
      stops): pick up to 3 optional attractions nearest the last stop and
      hand them to the route agent via ``refined_attraction_ids``.
    - Otherwise, if the critique mentions "walking"/"distance": trim optional
      long-walk stops (must-do stops are always kept).

    The route agent re-runs afterwards and rebuilds the schedule from
    ``refined_attraction_ids``.
    """
    critique: Optional[str] = state.get("critique")
    if not critique:
        _append_log(state, "Orchestrator: no critique to address, skipping reflection.")
        return state
    _append_log(
        state,
        f"Orchestrator: applying heuristic reflection based on critique: {critique}",
    )
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    c_lower = critique.lower()
    # Resolve must-do names to ids; these stops are never trimmed below.
    must_do_ids_set: set = set()
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids_set.add(attr.id)
    # --- Enjoyment: append nearby optional stops to boost the score ---
    if "enjoyment" in c_lower and len(plan.stops) < 6:
        # Imported lazily: only needed on this branch.
        from .routing import shortest_path_distance
        current_ids = [s.attraction_id for s in plan.stops]
        # Search outward from the last scheduled stop (or the hub when empty).
        last_node = plan.stops[-1].node_id if plan.stops else "HUB"
        # Optional candidates: not already planned and intensity-appropriate
        # (thrill-seekers get anything; others only thrill_level <= 3).
        optional_ids = [
            a.id for a in attractions.values()
            if a.id not in current_ids
            and (req.intensity_preference > 3 or a.thrill_level <= 3)
        ]
        # Sort optional candidates by walking distance from the last stop.
        by_dist = sorted(
            optional_ids,
            key=lambda aid: shortest_path_distance(last_node, node_lookup[aid]),
        )
        # Add enough so we reach at least 5 stops (enjoyment >= 3.0 for low intensity)
        need = max(0, 5 - len(plan.stops))
        add_count = min(3, need)  # never add more than 3 in one reflection pass
        to_add = by_dist[: add_count] if add_count > 0 else []
        if to_add:
            refined_ids = current_ids + to_add
            state["refined_attraction_ids"] = refined_ids
            state["raw_plan"] = plan.model_dump()
            _append_log(
                state,
                f"Orchestrator: reflection adding {len(to_add)} nearby attraction(s) to boost enjoyment.",
            )
        else:
            # Nothing suitable to add: re-schedule the current stops unchanged.
            state["refined_attraction_ids"] = current_ids
            _append_log(state, "Orchestrator: no optional attractions to add for enjoyment.")
        return state
    # --- Walking: trim optional long-walk stops ---
    filtered_stops: List[ItineraryStop] = []
    for stop in plan.stops:
        if stop.attraction_id in must_do_ids_set:
            # Must-do stops survive every trim.
            filtered_stops.append(stop)
            continue
        # Drop optional stops with long walks, but only for low-tolerance users.
        if (
            ("walking" in c_lower or "distance" in c_lower)
            and req.walking_tolerance <= 2
            and stop.walking_distance_m > 300
        ):
            continue
        filtered_stops.append(stop)
    final_plan = plan.model_copy(update={"stops": filtered_stops})
    state["raw_plan"] = final_plan.model_dump()
    state["refined_attraction_ids"] = [s.attraction_id for s in filtered_stops]
    _append_log(
        state,
        f"Orchestrator: reflection adjusted plan to {len(filtered_stops)} stops (removed long-walk stops).",
    )
    return state