File size: 5,965 Bytes
ffda755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from __future__ import annotations

from typing import Any

from .models import Objective, RouteMatrix, RouteSolution


class SolverError(RuntimeError):
    """Raised by the route solver when it cannot produce a valid route."""


def solve_route(
    matrix: RouteMatrix,
    objective: Objective,
    return_to_start: bool,
    fixed_end_place: str | None = None,
) -> RouteSolution:
    """Solve the visiting order for ``matrix`` and package the result.

    Point 0 of ``matrix`` is the start. The order is optimised against the
    duration matrix when ``objective == "time"`` and against the distance
    matrix for any other objective value; both totals are reported either way.

    Args:
        matrix: Points plus their pairwise ``durations`` and ``distances``.
        objective: Which cost to minimise (``"time"`` selects durations).
        return_to_start: Whether the route must end back at point 0.
        fixed_end_place: Optional name of the point that must come last;
            resolved to an index with ``find_fixed_end_index``.

    Returns:
        A ``RouteSolution`` with the visiting order, point names, total
        duration and distance, the algorithm label and per-leg rows.

    Raises:
        SolverError: If fewer than two points are supplied, if the solver
            fails, or if the optimised cost is negative, infinite or NaN.
    """
    if len(matrix.points) < 2:
        raise SolverError("至少需要起点和 1 个目的地。")

    cost_matrix = matrix.durations if objective == "time" else matrix.distances
    fixed_end_idx = find_fixed_end_index(matrix, fixed_end_place)
    route_indices, objective_cost = held_karp(
        cost_matrix=cost_matrix,
        return_to_start=return_to_start,
        fixed_end_idx=fixed_end_idx,
    )

    # Validate the optimised cost before deriving anything from the route.
    # The chained comparison is False for negative, infinite and NaN costs.
    if not 0 <= objective_cost < float("inf"):
        raise SolverError("求解失败:目标函数成本异常。")

    return RouteSolution(
        route_indices=route_indices,
        route_names=[matrix.points[idx].name for idx in route_indices],
        total_duration_seconds=route_cost(matrix.durations, route_indices),
        total_distance_meters=route_cost(matrix.distances, route_indices),
        objective=objective,
        # held_karp is the only solver invoked above, so report it for every
        # problem size instead of naming a heuristic that never runs.
        algorithm="Held-Karp dynamic programming exact TSP solver",
        leg_rows=build_leg_rows(matrix, route_indices),
    )


def held_karp(
    cost_matrix: list[list[float]],
    return_to_start: bool,
    fixed_end_idx: int | None,
) -> tuple[list[int], float]:
    """Find the cheapest route from node 0 through every other node.

    Uses a bitmask dynamic programme keyed by ``(visited_mask, last_node)``,
    so time and memory grow exponentially with the number of nodes.

    Args:
        cost_matrix: ``cost_matrix[i][j]`` is the cost of travelling from
            node ``i`` to node ``j``; ``None`` entries count as unreachable.
            The matrix is expected to be square.
        return_to_start: When true, the route is closed by returning to
            node 0 and ``fixed_end_idx`` is ignored.
        fixed_end_idx: Node that must be visited last, or ``None`` to let
            the solver pick the final node.

    Returns:
        ``(route, cost)`` where ``route`` starts at node 0 and ``cost`` is
        the summed cost of its legs. A single-node matrix gives ``([0], 0.0)``.

    Raises:
        SolverError: If the matrix is empty, ``fixed_end_idx`` is outside
            the matrix, or no route with a finite cost exists.
    """
    node_count = len(cost_matrix)
    if node_count == 0:
        # There is no node 0 to start from.
        raise SolverError("Held-Karp 求解失败。")
    if node_count == 1:
        return [0], 0.0

    if return_to_start:
        fixed_end_idx = None
    elif fixed_end_idx is not None and not 0 <= fixed_end_idx < node_count:
        # A negative index would silently wrap around via list indexing.
        raise SolverError("Held-Karp 求解失败。")

    # Nodes the DP is free to order; the fixed end (if any) is appended last.
    visit_nodes = [idx for idx in range(1, node_count) if idx != fixed_end_idx]

    if fixed_end_idx is not None and not visit_nodes:
        # The fixed end is the only destination: the route is a single leg.
        direct_cost = safe_cost(cost_matrix, 0, fixed_end_idx)
        if not direct_cost < float("inf"):
            # Same failure the general path reports for unreachable routes.
            raise SolverError("Held-Karp 求解失败。")
        return [0, fixed_end_idx], direct_cost

    bit_for_node = {node: 1 << pos for pos, node in enumerate(visit_nodes)}
    # dp[(mask, last)] = (cheapest cost to visit `mask` ending at `last`,
    #                     node visited just before `last`, or None if first)
    dp: dict[tuple[int, int], tuple[float, int | None]] = {}

    for node in visit_nodes:
        dp[(bit_for_node[node], node)] = (safe_cost(cost_matrix, 0, node), None)

    full_mask = (1 << len(visit_nodes)) - 1
    # Every transition adds a bit, so all predecessors of a mask have a
    # smaller value and are final by the time the mask is expanded.
    for mask in range(1, full_mask + 1):
        for last in visit_nodes:
            if not mask & bit_for_node[last]:
                continue
            current = dp.get((mask, last))
            if current is None:
                continue
            current_cost, _ = current
            for nxt in visit_nodes:
                nxt_bit = bit_for_node[nxt]
                if mask & nxt_bit:
                    continue
                next_mask = mask | nxt_bit
                next_cost = current_cost + safe_cost(cost_matrix, last, nxt)
                old = dp.get((next_mask, nxt))
                if old is None or next_cost < old[0]:
                    dp[(next_mask, nxt)] = (next_cost, last)

    best_last: int | None = None
    best_cost = float("inf")
    for last in visit_nodes:
        state = dp.get((full_mask, last))
        if state is None:
            continue
        total = state[0]
        if fixed_end_idx is not None:
            total += safe_cost(cost_matrix, last, fixed_end_idx)
        elif return_to_start:
            total += safe_cost(cost_matrix, last, 0)
        # Strict comparison: infinite (unreachable) totals never win.
        if total < best_cost:
            best_cost = total
            best_last = last

    if best_last is None:
        raise SolverError("Held-Karp 求解失败。")

    route = [0] + reconstruct_path(dp, bit_for_node, full_mask, best_last)
    if fixed_end_idx is not None:
        route.append(fixed_end_idx)
    elif return_to_start:
        route.append(0)
    return route, best_cost


def reconstruct_path(
    dp: dict[tuple[int, int], tuple[float, int | None]],
    bit_for_node: dict[int, int],
    mask: int,
    last: int,
) -> list[int]:
    """Rebuild the visiting order by following predecessor links in ``dp``.

    Starting from state ``(mask, last)``, each step records the current node,
    reads its stored predecessor, clears the node's bit from the mask and
    moves to the predecessor. The walk stops at a ``None`` predecessor.

    Returns:
        The nodes in visiting order (first visited node first).
    """
    trail: list[int] = []
    node: int | None = last
    remaining = mask
    while node is not None:
        trail.append(node)
        predecessor = dp[(remaining, node)][1]
        # Drop the node just recorded so the key matches the earlier state.
        remaining ^= bit_for_node[node]
        node = predecessor
    trail.reverse()
    return trail


def find_fixed_end_index(matrix: RouteMatrix, fixed_end_place: str | None) -> int | None:
    """Resolve a requested end place name to an index in ``matrix.points``.

    Matching ignores case and surrounding whitespace. An exact name match
    wins; otherwise the first point whose name contains the query, or is
    contained in it, is used. Point 0 (the start) is never returned.

    Args:
        matrix: Route matrix whose ``points`` carry a ``name`` attribute.
        fixed_end_place: Requested end place name; ``None``, empty or
            whitespace-only means no fixed end.

    Returns:
        The index of the matching point, or ``None`` when nothing matches.
    """
    if not fixed_end_place:
        return None
    wanted = fixed_end_place.strip().lower()
    if not wanted:
        # An empty string is a substring of every name; treat it as "no end".
        return None

    partial_idx: int | None = None
    for idx, point in enumerate(matrix.points):
        if idx == 0:
            continue
        name = point.name.strip().lower()
        if not name:
            # An empty name would be "contained" in any query.
            continue
        if name == wanted:
            return idx
        if partial_idx is None and (wanted in name or name in wanted):
            # Remember the first loose match but keep looking for an exact one.
            partial_idx = idx
    return partial_idx


def route_cost(cost_matrix: list[list[float]], route_indices: list[int]) -> float:
    """Add up the cost of every consecutive leg along ``route_indices``.

    Each leg is looked up with ``safe_cost``. A route with fewer than two
    stops has no legs and costs ``0.0``.
    """
    legs = zip(route_indices[:-1], route_indices[1:])
    running = 0.0
    for start, end in legs:
        running = running + safe_cost(cost_matrix, start, end)
    return running


def safe_cost(cost_matrix: list[list[float]], origin: int, dest: int) -> float:
    """Return the cost of going from ``origin`` to ``dest`` as a float.

    A ``None`` entry in the matrix is reported as positive infinity.
    """
    entry = cost_matrix[origin][dest]
    return float("inf") if entry is None else float(entry)


def build_leg_rows(matrix: RouteMatrix, route_indices: list[int]) -> list[list[Any]]:
    """Build one display row per consecutive leg of the route.

    Each row is ``[step, origin name, destination name, distance, duration]``
    with ``step`` counting from 1 and the last two entries formatted by
    ``format_km`` and ``format_minutes``.

    Matrix entries are read through ``safe_cost`` so a ``None`` entry is
    rendered as an infinite value instead of raising ``TypeError``, in line
    with how ``route_cost`` totals the same legs.

    Args:
        matrix: Route matrix providing ``points``, ``distances`` and
            ``durations``.
        route_indices: Visiting order as indices into ``matrix.points``.

    Returns:
        The rows in travel order; empty when the route has fewer than two
        stops.
    """
    legs = zip(route_indices, route_indices[1:])
    return [
        [
            step,
            matrix.points[origin].name,
            matrix.points[dest].name,
            format_km(safe_cost(matrix.distances, origin, dest)),
            format_minutes(safe_cost(matrix.durations, origin, dest)),
        ]
        for step, (origin, dest) in enumerate(legs, start=1)
    ]


def format_km(meters: float) -> str:
    """Render a distance given in meters as kilometers with two decimals."""
    kilometers = meters / 1000
    return f"{kilometers:.2f} km"


def format_minutes(seconds: float) -> str:
    """Render a duration given in seconds as minutes with one decimal."""
    minutes = seconds / 60
    return f"{minutes:.1f} min"