"""Route optimisation: order a start point and its destinations to minimise cost.

The entry point is :func:`solve_route`. It runs an exact Held-Karp dynamic
program for small inputs and a nearest-neighbour + 2-opt heuristic for larger
ones, then packages the chosen route as a ``RouteSolution``.
"""

from __future__ import annotations

import math
from typing import Any

from .models import Objective, RouteMatrix, RouteSolution

# Largest number of points (start included) solved exactly with Held-Karp.
# Held-Karp keeps one state per (subset, last node) pair, which grows
# exponentially, so larger inputs fall back to the heuristic.
EXACT_SOLVER_MAX_POINTS = 10

# Text shown in a leg row when the matrix has no value for that leg.
MISSING_VALUE_TEXT = "N/A"


class SolverError(RuntimeError):
    """Raised when no acceptable route can be produced for the given input."""


def solve_route(
    matrix: RouteMatrix,
    objective: Objective,
    return_to_start: bool,
    fixed_end_place: str | None = None,
    *,
    exact_limit: int = EXACT_SOLVER_MAX_POINTS,
) -> RouteSolution:
    """Find a visiting order for ``matrix.points`` that minimises the objective.

    Point 0 is always the start. Every other point is visited exactly once.

    Args:
        matrix: Points plus pairwise ``durations`` (seconds) and ``distances``
            (meters). A ``None`` entry marks a leg with no data.
        objective: ``"time"`` optimises durations; any other value optimises
            distances.
        return_to_start: If true, the route ends back at point 0.
        fixed_end_place: Optional name (or name fragment) of the destination
            that must be the last stop. It is ignored for round trips and when
            no point matches.
        exact_limit: Largest point count solved exactly with Held-Karp. Larger
            inputs use the nearest-neighbour + 2-opt heuristic.

    Returns:
        A ``RouteSolution`` with the route, its totals and per-leg rows.

    Raises:
        SolverError: if there are fewer than two points, or if the best route
            found has a non-finite or negative cost (e.g. it must use a leg
            with no data).
    """
    point_count = len(matrix.points)
    if point_count < 2:
        raise SolverError("至少需要起点和 1 个目的地。")

    cost_matrix = matrix.durations if objective == "time" else matrix.distances
    fixed_end_idx = find_fixed_end_index(matrix, fixed_end_place)

    # Choose the solver and its label together so the reported algorithm is
    # always the one that actually ran.
    if point_count <= exact_limit:
        solver = held_karp
        algorithm = "Held-Karp dynamic programming exact TSP solver"
    else:
        solver = nearest_neighbor_two_opt
        algorithm = "Nearest-neighbor + 2-opt heuristic"

    route_indices, objective_cost = solver(
        cost_matrix=cost_matrix,
        return_to_start=return_to_start,
        fixed_end_idx=fixed_end_idx,
    )
    # An infinite cost means even the best route uses a leg with no data
    # (``safe_cost`` maps ``None`` to infinity). Reject it, along with NaN and
    # negative costs, before building any output.
    if not math.isfinite(objective_cost) or objective_cost < 0:
        raise SolverError("求解失败:目标函数成本异常。")

    return RouteSolution(
        route_indices=route_indices,
        route_names=[matrix.points[idx].name for idx in route_indices],
        total_duration_seconds=route_cost(matrix.durations, route_indices),
        total_distance_meters=route_cost(matrix.distances, route_indices),
        objective=objective,
        algorithm=algorithm,
        leg_rows=build_leg_rows(matrix, route_indices),
    )


def held_karp(
    cost_matrix: list[list[float | None]],
    return_to_start: bool,
    fixed_end_idx: int | None,
) -> tuple[list[int], float]:
    """Solve the route exactly with the Held-Karp dynamic program.

    Node 0 is the start, and every other node is visited exactly once.
    ``fixed_end_idx`` is forced to be the final stop, unless
    ``return_to_start`` is true, in which case the tour closes back at node 0.

    Returns:
        ``(route, cost)``: node indices beginning with 0, and the summed
        ``cost_matrix`` value along that route.

    Raises:
        SolverError: if every complete route has infinite (or NaN) cost.
    """
    node_count = len(cost_matrix)
    if node_count == 1:
        return [0], 0.0

    # A round trip ends at the start, so a separate pinned end is dropped.
    if return_to_start:
        fixed_end_idx = None

    visit_nodes = [idx for idx in range(1, node_count) if idx != fixed_end_idx]
    if not visit_nodes:
        if fixed_end_idx is None:
            return ([0, 0] if return_to_start else [0]), 0.0
        return [0, fixed_end_idx], safe_cost(cost_matrix, 0, fixed_end_idx)

    # Each node that must be visited owns one bit of the subset mask.
    bit_for_node = {node: 1 << pos for pos, node in enumerate(visit_nodes)}

    # dp[(mask, last)] = (cheapest cost of leaving node 0, visiting exactly the
    # nodes in ``mask`` and stopping at ``last``; predecessor of ``last``).
    dp: dict[tuple[int, int], tuple[float, int | None]] = {
        (bit_for_node[node], node): (safe_cost(cost_matrix, 0, node), None)
        for node in visit_nodes
    }

    full_mask = (1 << len(visit_nodes)) - 1
    # Transitions only add bits, so ascending mask order finalises every state
    # before it is extended.
    for mask in range(1, full_mask + 1):
        for last in visit_nodes:
            if not mask & bit_for_node[last]:
                continue
            current = dp.get((mask, last))
            if current is None:
                continue
            current_cost = current[0]
            for nxt in visit_nodes:
                nxt_bit = bit_for_node[nxt]
                if mask & nxt_bit:
                    continue
                next_mask = mask | nxt_bit
                next_cost = current_cost + safe_cost(cost_matrix, last, nxt)
                old = dp.get((next_mask, nxt))
                if old is None or next_cost < old[0]:
                    dp[(next_mask, nxt)] = (next_cost, last)

    # Close each full path with the leg to the pinned end (or back to the
    # start) and keep the cheapest. Infinite totals never win, so best_last
    # stays None only when every route is unusable.
    best_last: int | None = None
    best_cost = float("inf")
    for last in visit_nodes:
        state = dp.get((full_mask, last))
        if state is None:
            continue
        total = state[0]
        if fixed_end_idx is not None:
            total += safe_cost(cost_matrix, last, fixed_end_idx)
        elif return_to_start:
            total += safe_cost(cost_matrix, last, 0)
        if total < best_cost:
            best_cost = total
            best_last = last

    if best_last is None:
        raise SolverError("Held-Karp 求解失败。")

    route = [0] + reconstruct_path(dp, bit_for_node, full_mask, best_last)
    if fixed_end_idx is not None:
        route.append(fixed_end_idx)
    elif return_to_start:
        route.append(0)
    return route, best_cost


def nearest_neighbor_two_opt(
    cost_matrix: list[list[float | None]],
    return_to_start: bool,
    fixed_end_idx: int | None,
) -> tuple[list[int], float]:
    """Approximate the route with nearest-neighbour construction plus 2-opt.

    The contract matches :func:`held_karp`: node 0 is the start, and
    ``fixed_end_idx`` is pinned last unless ``return_to_start`` is true.
    Unlike Held-Karp, it runs in polynomial time. The result is a 2-opt local
    optimum, not necessarily the global optimum.

    Returns:
        ``(route, cost)`` in the same shape as :func:`held_karp`.
    """
    if return_to_start:
        fixed_end_idx = None

    remaining = [idx for idx in range(1, len(cost_matrix)) if idx != fixed_end_idx]
    if not remaining:
        # Nothing to order; the exact solver already handles these shapes.
        return held_karp(cost_matrix, return_to_start, fixed_end_idx)

    # Greedy construction: always move to the cheapest unvisited node.
    # Ties resolve to the lowest index because ``min`` keeps the first minimum.
    order: list[int] = []
    current = 0
    while remaining:
        current = min(
            remaining,
            key=lambda node, origin=current: safe_cost(cost_matrix, origin, node),
        )
        remaining.remove(current)
        order.append(current)

    if fixed_end_idx is not None:
        tail = [fixed_end_idx]
    elif return_to_start:
        tail = [0]
    else:
        tail = []
    route = [0, *order, *tail]

    # 2-opt: reverse interior segments while that strictly lowers the total.
    # The start and any pinned final stop never move. The full route cost is
    # recomputed for each candidate because the matrix may be asymmetric, in
    # which case reversing a segment also changes the legs inside it. The loop
    # terminates because every accepted move strictly lowers the cost.
    last_movable = len(route) - 1 - len(tail)
    best_cost = route_cost(cost_matrix, route)
    improved = True
    while improved:
        improved = False
        for i in range(1, last_movable):
            for j in range(i + 1, last_movable + 1):
                candidate = route[:i] + route[i : j + 1][::-1] + route[j + 1 :]
                candidate_cost = route_cost(cost_matrix, candidate)
                if candidate_cost < best_cost:
                    route, best_cost = candidate, candidate_cost
                    improved = True
    return route, best_cost


def reconstruct_path(
    dp: dict[tuple[int, int], tuple[float, int | None]],
    bit_for_node: dict[int, int],
    mask: int,
    last: int,
) -> list[int]:
    """Walk Held-Karp predecessor links back from ``(mask, last)``.

    Returns the visited nodes in travel order. The start node 0 is excluded,
    because the first state's predecessor is ``None``.
    """
    reversed_path: list[int] = []
    current_last: int | None = last
    current_mask = mask
    while current_last is not None:
        reversed_path.append(current_last)
        _, prev = dp[(current_mask, current_last)]
        # Remove the node just emitted to get the predecessor state's mask.
        current_mask ^= bit_for_node[current_last]
        current_last = prev
    return list(reversed(reversed_path))


def find_fixed_end_index(matrix: RouteMatrix, fixed_end_place: str | None) -> int | None:
    """Return the index of the destination named by ``fixed_end_place``.

    Matching is case-insensitive and ignores surrounding whitespace. An exact
    name match wins; otherwise the first destination whose name contains, or
    is contained in, the query is used. The start point (index 0) is never a
    candidate. Returns ``None`` for an empty or blank query, or when nothing
    matches.
    """
    if not fixed_end_place:
        return None
    normalized = fixed_end_place.strip().lower()
    if not normalized:
        # A blank query would be a substring of every name; treat it as unset.
        return None

    candidates = [
        (idx, point.name.strip().lower())
        for idx, point in enumerate(matrix.points)
        if idx != 0
    ]
    for idx, name in candidates:
        if name == normalized:
            return idx
    for idx, name in candidates:
        # Skip empty names: "" is contained in every query.
        if name and (normalized in name or name in normalized):
            return idx
    return None


def route_cost(cost_matrix: list[list[float | None]], route_indices: list[int]) -> float:
    """Sum the leg costs along ``route_indices``. Missing legs count as infinity."""
    return sum(
        (safe_cost(cost_matrix, origin, dest) for origin, dest in zip(route_indices, route_indices[1:])),
        0.0,
    )


def safe_cost(cost_matrix: list[list[float | None]], origin: int, dest: int) -> float:
    """Return the cost of ``origin -> dest`` as a float, mapping ``None`` to infinity."""
    value = cost_matrix[origin][dest]
    if value is None:
        return float("inf")
    return float(value)


def build_leg_rows(matrix: RouteMatrix, route_indices: list[int]) -> list[list[Any]]:
    """Build one display row per leg of the route.

    Each row is ``[step, origin name, destination name, distance text,
    duration text]``, with ``step`` counting from 1.
    """
    rows: list[list[Any]] = []
    for step, (origin, dest) in enumerate(zip(route_indices, route_indices[1:]), start=1):
        rows.append(
            [
                step,
                matrix.points[origin].name,
                matrix.points[dest].name,
                format_km(matrix.distances[origin][dest]),
                format_minutes(matrix.durations[origin][dest]),
            ]
        )
    return rows


def format_km(meters: float | None) -> str:
    """Format a distance in meters as kilometres with two decimals."""
    if meters is None:
        return MISSING_VALUE_TEXT
    return f"{meters / 1000:.2f} km"


def format_minutes(seconds: float | None) -> str:
    """Format a duration in seconds as minutes with one decimal."""
    if seconds is None:
        return MISSING_VALUE_TEXT
    return f"{seconds / 60:.1f} min"