| from __future__ import annotations |
|
|
| import json |
| from typing import Any, Callable |
|
|
| from .geo_tools import build_route_matrix, geocode_places |
| from .llm_client import LLMError, chat_completion, extract_json_from_text, extract_tool_arguments |
| from .models import AgentResult, GeoPoint, Objective, RouteMatrix, RouteSolution, RouteTask, ToolEvent |
| from .parsing import dedupe_preserve_order, heuristic_extract_task, normalize_place_lines |
| from .report import generate_pdf_report |
| from .solver import format_km, format_minutes, solve_route |
| from .viz import build_route_svg |
|
|
|
|
# JSON-Schema description of each argument the LLM may fill in when it calls the
# ``submit_route_task`` tool. Key order is significant only for serialization.
_TASK_PROPERTIES: dict[str, Any] = {
    "start_place": {
        "type": "string",
        "description": "The route start place. Keep it specific enough for geocoding.",
    },
    "destination_places": {
        "type": "array",
        "items": {"type": "string"},
        "description": "Places that must be visited after the start.",
    },
    "objective": {
        "type": "string",
        "enum": ["time", "distance"],
        "description": "Optimization objective: time or distance.",
    },
    "return_to_start": {
        "type": "boolean",
        "description": "Whether the route should return to the start place.",
    },
    "fixed_end_place": {
        "type": "string",
        "description": "Optional fixed final destination when not returning to start.",
    },
    "constraints": {
        "type": "array",
        "items": {"type": "string"},
        "description": "Soft constraints or notes mentioned by the user.",
    },
}

# Arguments the model must always supply; the remaining ones are optional.
_TASK_REQUIRED_FIELDS: list[str] = ["start_place", "destination_places", "objective", "return_to_start"]

# Function-calling tool definition passed to ``chat_completion(tools=[...])`` so the
# model returns one structured route-optimization task.
TASK_TOOL_SCHEMA: dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "submit_route_task",
        "description": "Extract a route optimization task from user text and structured hints.",
        "parameters": {
            "type": "object",
            "properties": _TASK_PROPERTIES,
            "required": _TASK_REQUIRED_FIELDS,
        },
    },
}
|
|
|
|
class RouteOptAgent:
    """Route-optimization agent combining LLM understanding with deterministic tools.

    A call to :meth:`run` extracts a structured ``RouteTask`` (through an LLM tool
    call when enabled, otherwise with the local heuristic parser), validates it,
    geocodes the places, builds a distance/time matrix, solves the visiting order,
    writes a summary and a PDF report, and records each step in ``self.trace``.
    """

    def __init__(self) -> None:
        # Ordered log of tool / LLM steps; replaced with a fresh list by every run().
        self.trace: list[ToolEvent] = []

    def run(
        self,
        raw_request: str,
        start_hint: str = "",
        destinations_hint: str = "",
        objective_hint: str = "time",
        return_to_start_hint: bool = True,
        fixed_end_hint: str = "",
        city_hint: str = "上海,中国",
        use_llm: bool = True,
    ) -> AgentResult:
        """Run the full extraction -> geocoding -> matrix -> solve -> report pipeline.

        Args:
            raw_request: Free-form natural-language request from the user.
            start_hint: Structured start-place hint.
            destinations_hint: Structured destinations hint text.
            objective_hint: Preferred objective hint (``"time"`` or ``"distance"``).
            return_to_start_hint: Whether the route should return to the start.
            fixed_end_hint: Optional fixed final place hint.
            city_hint: Forwarded to ``geocode_places`` together with all place names.
            use_llm: Try the LLM for task extraction and the summary first; local
                logic is used if the LLM path fails.

        Returns:
            An ``AgentResult`` with the task, points, matrix, solution, summary,
            trace, PDF path, route SVG and all collected warnings.

        Raises:
            ValueError: If the extracted task fails ``_validate_task``.
            Exception: Errors from the geocoding, matrix, solver or PDF tools are
                recorded in the trace by ``_record_tool`` and re-raised.
        """
        self.trace = []
        warnings: list[str] = []

        task = self._extract_task(
            raw_request=raw_request,
            start_hint=start_hint,
            destinations_hint=destinations_hint,
            objective_hint=objective_hint,
            return_to_start_hint=return_to_start_hint,
            fixed_end_hint=fixed_end_hint,
            use_llm=use_llm,
            warnings=warnings,
        )
        self._validate_task(task, warnings)

        # Index 0 is the start; the solver works on indices into this list.
        all_places = [task.start_place] + task.destination_places
        points = self._record_tool(
            "geocode_places",
            {"places": all_places, "city_hint": city_hint},
            lambda: geocode_places(all_places, city_hint),
            lambda result: f"解析 {len(result)} 个地点:{', '.join(point.name for point in result)}",
        )
        matrix = self._record_tool(
            "build_route_matrix",
            {"points": [point.name for point in points]},
            lambda: build_route_matrix(points),
            lambda result: f"获得 {len(result.points)}x{len(result.points)} 距离/时间矩阵,来源:{result.source}",
        )
        solution = self._record_tool(
            "solve_route",
            {
                "objective": task.objective,
                "return_to_start": task.return_to_start,
                "fixed_end_place": task.fixed_end_place,
            },
            lambda: solve_route(matrix, task.objective, task.return_to_start, task.fixed_end_place),
            lambda result: f"路线:{' -> '.join(result.route_names)};距离 {format_km(result.total_distance_meters)};时间 {format_minutes(result.total_duration_seconds)}",
        )

        summary = self._compose_summary(task, points, matrix, solution, use_llm, warnings)
        route_svg = build_route_svg(points, solution.route_indices)
        pdf_path = self._record_tool(
            "generate_pdf_report",
            {"route": solution.route_names},
            lambda: generate_pdf_report(task, points, solution, self.trace, summary),
            lambda result: f"PDF 已生成:{result}",
        )

        return AgentResult(
            task=task,
            points=points,
            matrix=matrix,
            solution=solution,
            summary_markdown=summary,
            trace=self.trace,
            pdf_path=pdf_path,
            route_svg=route_svg,
            warnings=warnings,
        )

    def _extract_task(
        self,
        raw_request: str,
        start_hint: str,
        destinations_hint: str,
        objective_hint: str,
        return_to_start_hint: bool,
        fixed_end_hint: str,
        use_llm: bool,
        warnings: list[str],
    ) -> RouteTask:
        """Build a ``RouteTask``, preferring the LLM and falling back to local parsing.

        When ``use_llm`` is true, any exception from the LLM path is turned into
        a warning (appended to ``warnings``) plus a ``"fallback"`` trace event, and
        ``heuristic_extract_task`` is used instead. Each path that produces the
        returned task appends one ``"ok"`` trace event.
        """
        if use_llm:
            try:
                task = self._extract_task_with_llm(
                    raw_request,
                    start_hint,
                    destinations_hint,
                    objective_hint,
                    return_to_start_hint,
                    fixed_end_hint,
                )
                self._append_event(
                    "LLM tool_call: submit_route_task",
                    {
                        "raw_request": raw_request,
                        "start_hint": start_hint,
                        "destinations_hint": destinations_hint,
                    },
                    "ok",
                    f"抽取任务:起点={task.start_place};目的地={len(task.destination_places)} 个;目标={task.objective}",
                )
                return task
            except Exception as exc:
                # Deliberate best-effort: any LLM failure degrades to the local parser.
                warnings.append(f"LLM 工具调用抽取失败,已切换本地解析:{exc}")
                self._append_event(
                    "LLM tool_call: submit_route_task",
                    {"raw_request": raw_request},
                    "fallback",
                    str(exc),
                )

        task = heuristic_extract_task(
            raw_request,
            start_hint=start_hint,
            destinations_hint=destinations_hint,
            objective_hint=objective_hint,
            return_to_start_hint=return_to_start_hint,
            fixed_end_hint=fixed_end_hint,
        )
        self._append_event(
            "local_parse_route_task",
            {
                "raw_request": raw_request,
                "start_hint": start_hint,
                "destinations_hint": destinations_hint,
            },
            "ok",
            f"本地解析任务:起点={task.start_place};目的地={len(task.destination_places)} 个;目标={task.objective}",
        )
        return task

    def _extract_task_with_llm(
        self,
        raw_request: str,
        start_hint: str,
        destinations_hint: str,
        objective_hint: str,
        return_to_start_hint: bool,
        fixed_end_hint: str,
    ) -> RouteTask:
        """Ask the LLM to call ``submit_route_task`` and convert its arguments.

        The request and all hints are sent as one JSON user message
        (``destinations_hint`` is first split with ``normalize_place_lines``). If
        the reply carries no tool-call arguments, JSON embedded in the message
        text is tried instead.

        Raises:
            LLMError: If neither a tool call nor JSON content yields arguments.
                Errors raised by ``chat_completion`` itself also propagate.
        """
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a route optimization agent controller. "
                    "Use the submit_route_task function to return one structured task. "
                    "Prefer structured hints over ambiguous natural language. "
                    "Keep Chinese place names specific for geocoding. "
                    "If the user says 最快/时间最短 use objective=time; if 距离/少走路 use objective=distance."
                ),
            },
            {
                "role": "user",
                "content": json.dumps(
                    {
                        "raw_request": raw_request,
                        "start_hint": start_hint,
                        "destinations_hint": normalize_place_lines(destinations_hint),
                        "objective_hint": objective_hint,
                        "return_to_start_hint": return_to_start_hint,
                        "fixed_end_hint": fixed_end_hint,
                    },
                    ensure_ascii=False,
                ),
            },
        ]
        message = chat_completion(messages, tools=[TASK_TOOL_SCHEMA], temperature=0.1)
        args = extract_tool_arguments(message, "submit_route_task")
        if args is None:
            # Some models answer with plain JSON text instead of a tool call.
            args = extract_json_from_text(message.get("content") or "")
        if args is None:
            raise LLMError("模型没有返回 submit_route_task 工具参数。")
        return self._task_from_arguments(
            args,
            raw_request,
            start_hint,
            destinations_hint,
            objective_hint,
            return_to_start_hint,
            fixed_end_hint,
        )

    def _task_from_arguments(
        self,
        args: dict[str, Any],
        raw_request: str,
        start_hint: str,
        destinations_hint: str,
        objective_hint: str,
        return_to_start_hint: bool,
        fixed_end_hint: str,
    ) -> RouteTask:
        """Merge LLM tool arguments with the heuristic parse and explicit hints.

        Precedence per field:
        * start / fixed end: non-empty hint, then LLM argument, then heuristic value;
        * destinations / objective / constraints: truthy LLM argument, then heuristic;
        * return_to_start: the hint unless it is ``None``, then LLM, then heuristic.
        """
        # Heuristic parse supplies defaults for anything the LLM left out.
        local_task = heuristic_extract_task(
            raw_request,
            start_hint=start_hint,
            destinations_hint=destinations_hint,
            objective_hint=objective_hint,
            return_to_start_hint=return_to_start_hint,
            fixed_end_hint=fixed_end_hint,
        )

        destinations = args.get("destination_places") or local_task.destination_places
        if isinstance(destinations, str):
            # The model sometimes returns one multi-line string instead of an array.
            destinations = normalize_place_lines(destinations)
        destinations = dedupe_preserve_order([str(item).strip() for item in destinations if str(item).strip()])

        start = start_hint.strip() or str(args.get("start_place") or local_task.start_place).strip()
        objective = normalize_objective(str(args.get("objective") or local_task.objective))
        fixed_end = fixed_end_hint.strip() or str(args.get("fixed_end_place") or local_task.fixed_end_place or "").strip()

        return RouteTask(
            raw_request=raw_request,
            start_place=start,
            destination_places=destinations,
            objective=objective,
            # return_to_start_hint is typed bool, so the structured hint wins unless a
            # caller explicitly passes None.
            return_to_start=return_to_start_hint if return_to_start_hint is not None else bool(args.get("return_to_start", local_task.return_to_start)),
            fixed_end_place=fixed_end or None,
            constraints=self._coerce_constraints(args.get("constraints") or local_task.constraints),
        )

    @staticmethod
    def _coerce_constraints(value: Any) -> list[str]:
        """Return a constraints value as a list of strings.

        A bare string is kept as a single constraint rather than being split into
        characters; lists/tuples are stringified item by item; ``None`` becomes
        an empty list; any other scalar is wrapped as one stringified item.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)]

    def _validate_task(self, task: RouteTask, warnings: list[str]) -> None:
        """Normalize ``task`` in place and reject tasks that cannot be solved.

        Strips and de-duplicates destinations, removes destinations that match
        the start, enforces 1..10 destinations, and for one-way routes appends the
        fixed end place to the destinations when it is not already among them.
        Informational messages are appended to ``warnings``.

        Raises:
            ValueError: If the start is missing, no destinations remain, or there
                are more than 10 destinations.
        """
        task.start_place = task.start_place.strip()
        task.destination_places = [place.strip() for place in task.destination_places if place.strip()]
        original_count = len(task.destination_places)
        task.destination_places = dedupe_preserve_order(task.destination_places)
        if len(task.destination_places) < original_count:
            warnings.append("目的地中存在重复项,系统已自动去重。")

        if not task.start_place:
            raise ValueError("缺少起点。请在起点输入框填写一个地点,或在自然语言需求中写明“从哪里出发”。")
        if not task.destination_places:
            raise ValueError("缺少目的地。请至少填写 1 个目的地。建议每行写一个地点,例如:人民广场、外滩、陆家嘴。")

        # Case-insensitive substring match in either direction.
        # NOTE(review): this also drops distinct places whose name merely contains
        # the start's name (or vice versa) — confirm this fuzziness is intended.
        start_key = task.start_place.lower()
        without_start = [
            place
            for place in task.destination_places
            if place.lower() != start_key and start_key not in place.lower() and place.lower() not in start_key
        ]
        if len(without_start) < len(task.destination_places):
            warnings.append("目的地中包含起点,系统已自动移除该重复访问点。")
        task.destination_places = without_start
        if not task.destination_places:
            raise ValueError("目的地里只有起点本身。请至少增加一个不同于起点的目的地。")

        # NOTE(review): the limit is checked before the fixed end place may be
        # appended below, so up to 11 destinations can reach the solver.
        if len(task.destination_places) > 10:
            raise ValueError(
                "目的地数量过多。当前演示版最多支持 10 个目的地;作业视频推荐 3 到 8 个,"
                "这样 Held-Karp 精确算法和公开路线 API 都更稳定。"
            )

        # A one-way route must actually visit its fixed end, so make sure it is a destination.
        if task.fixed_end_place and not task.return_to_start:
            fixed = task.fixed_end_place.strip().lower()
            if all(fixed not in item.lower() and item.lower() not in fixed for item in task.destination_places):
                task.destination_places.append(task.fixed_end_place)

    def _compose_summary(
        self,
        task: RouteTask,
        points: list[GeoPoint],
        matrix: RouteMatrix,
        solution: RouteSolution,
        use_llm: bool,
        warnings: list[str],
    ) -> str:
        """Return a Markdown explanation of the solved route.

        With ``use_llm`` the LLM is given the task, points, solution figures, tool
        trace and warnings; its non-empty reply is returned. If the LLM call
        raises, a warning and a ``"fallback"`` trace event are recorded. In every
        non-LLM case (disabled, failed or empty reply) ``deterministic_summary``
        is used.
        """
        if use_llm:
            try:
                messages = [
                    {
                        "role": "system",
                        "content": (
                            "你是算法课程作业里的路线优化智能体。"
                            "请用中文写一段简洁但不敷衍的求解说明,强调:"
                            "大模型负责理解和解释,工具负责地理编码/路径矩阵,算法负责最优化。"
                            "不要编造没有出现的数据。"
                        ),
                    },
                    {
                        "role": "user",
                        "content": json.dumps(
                            {
                                "task": task.__dict__,
                                "points": [point.__dict__ for point in points],
                                "matrix_source": matrix.source,
                                "solution": {
                                    "route": solution.route_names,
                                    "total_distance_km": round(solution.total_distance_meters / 1000, 2),
                                    "total_minutes": round(solution.total_duration_seconds / 60, 1),
                                    "algorithm": solution.algorithm,
                                },
                                "tool_trace": [event.__dict__ for event in self.trace],
                                "warnings": warnings,
                            },
                            ensure_ascii=False,
                        ),
                    },
                ]
                message = chat_completion(messages, temperature=0.35)
                content = (message.get("content") or "").strip()
                if content:
                    self._append_event(
                        "LLM compose_summary",
                        {"route": solution.route_names},
                        "ok",
                        "已生成中文解释总结。",
                    )
                    return content
            except Exception as exc:
                # Deliberate best-effort: fall back to the local template below.
                warnings.append(f"LLM 总结失败,已切换本地报告模板:{exc}")
                self._append_event("LLM compose_summary", {}, "fallback", str(exc))

        return deterministic_summary(task, matrix, solution, warnings)

    def _record_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        func: Callable[[], Any],
        summarize: Callable[[Any], str],
    ) -> Any:
        """Run one tool call and log its outcome in ``self.trace``.

        Args:
            name: Tool name written to the trace.
            arguments: Arguments shown in the trace (compacted in ``_append_event``).
            func: Zero-argument callable performing the tool call.
            summarize: Builds a one-line trace summary from ``func``'s result.

        Returns:
            Whatever ``func`` returned.

        Raises:
            Exception: Anything raised by ``func`` is logged with status
                ``"error"`` and re-raised unchanged.
        """
        try:
            result = func()
        except Exception as exc:
            self._append_event(name, arguments, "error", str(exc))
            raise
        # The tool itself succeeded: a failure while formatting its trace line must
        # neither abort the pipeline nor be logged as a tool error.
        try:
            summary = summarize(result)
        except Exception as exc:
            summary = f"执行成功,但生成结果摘要失败:{exc}"
        self._append_event(name, arguments, "ok", summary)
        return result

    def _append_event(self, tool: str, arguments: dict[str, Any], status: str, result: str) -> None:
        """Append a ``ToolEvent`` with a 1-based step number and compacted arguments."""
        self.trace.append(
            ToolEvent(
                step=len(self.trace) + 1,
                tool=tool,
                arguments=compact_arguments(arguments),
                status=status,
                result=result,
            )
        )
|
|
|
|
def normalize_objective(value: str) -> Objective:
    """Map a free-form objective label to ``"distance"`` or ``"time"``.

    Matching is case-insensitive and ignores surrounding whitespace. Only the
    labels ``distance``, ``最短距离`` and ``距离`` select ``"distance"``; every
    other input, including empty or unknown labels, falls back to ``"time"``.
    """
    normalized = value.lower().strip()
    distance_labels = ("distance", "最短距离", "距离")
    return "distance" if normalized in distance_labels else "time"
|
|
|
|
def compact_arguments(arguments: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *arguments* with long lists shortened for the trace.

    Any ``list`` value holding more than 8 items is replaced by its first 8 items
    followed by the marker string ``"..."``. Every other value, including tuples
    and lists of 8 or fewer items, is copied through unchanged.
    """
    limit = 8

    def _shorten(item: Any) -> Any:
        if not isinstance(item, list) or len(item) <= limit:
            return item
        return [*item[:limit], "..."]

    return {name: _shorten(item) for name, item in arguments.items()}
|
|
|
|
def deterministic_summary(
    task: RouteTask,
    matrix: RouteMatrix,
    solution: RouteSolution,
    warnings: list[str],
) -> str:
    """Build the template-based Markdown summary used when no LLM summary is available.

    The text names the matrix source, the optimization objective and the solver
    algorithm, then reports the route with its total distance and duration
    (formatted by ``format_km`` / ``format_minutes``). Any ``warnings`` are
    appended as a bulleted "运行提示" section.
    """
    if task.objective == "time":
        objective_text = "预计驾驶时间"
    else:
        objective_text = "路线距离"

    route_text = " → ".join(solution.route_names)
    distance_text = format_km(solution.total_distance_meters)
    duration_text = format_minutes(solution.total_duration_seconds)

    sections = [
        "本次求解把用户需求先转成结构化路线优化任务,然后调用地理编码工具获得经纬度,",
        f"再通过 `{matrix.source}` 得到点对点距离/时间矩阵。最后,本地优化器以“{objective_text}”为目标,",
        f"使用 `{solution.algorithm}` 搜索访问顺序。\n\n",
        f"最终路线为:**{route_text}**。",
        f"总距离约 **{distance_text}**,预计驾驶时间约 **{duration_text}**。",
        "\n\n这个设计中,大模型不直接猜最优路线,而是负责理解自然语言、组织工具调用并解释结果;",
        "确定性工具负责拿真实数据和完成可验证的算法计算。",
    ]
    if warnings:
        sections.append("\n\n**运行提示**:\n")
        sections.append("\n".join(f"- {item}" for item in warnings))
    return "".join(sections)
|
|