File size: 8,325 Bytes
b7d08cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""
OR-Tools 求解器 - 封裝 OR-Tools API

完全保留原始 tsptw_solver_old.py 的 OR-Tools 設置邏輯
"""
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta

import numpy as np
from ortools.constraint_solver import routing_enums_pb2
from ortools.constraint_solver import pywrapcp

from src.infra.logger import get_logger

from src.optimization.models.internal_models import _Task, _Graph
from src.optimization.graph.time_window_handler import TimeWindowHandler

logger = get_logger(__name__)


class ORToolsSolver:
    """
    OR-Tools solver wrapper for a single-vehicle TSP with time windows.

    Responsibilities:
    - Create the RoutingModel and RoutingIndexManager.
    - Add the "Time" dimension and per-node time-window constraints.
    - Add per-task priority disjunctions (droppable visits with penalties).
    - Run the search.

    The model setup follows the OR-Tools configuration of the original
    ``tsptw_solver_old.py``.
    """

    def __init__(
            self,
            time_limit_seconds: int = 30,
            verbose: bool = False,
    ):
        """
        Args:
            time_limit_seconds: Search time limit handed to OR-Tools.
            verbose: If True, log when the search starts and finishes.
        """
        self.time_limit_seconds = time_limit_seconds
        self.verbose = verbose
        # Combines a task's time window with its POI's time window into the
        # effective window applied to the node's cumul variable.
        self.tw_handler = TimeWindowHandler()

    def solve(
            self,
            graph: _Graph,
            tasks: List[_Task],
            start_time: datetime,
            deadline: Optional[datetime],
            max_wait_time_sec: int,
    ) -> Tuple[pywrapcp.RoutingModel, pywrapcp.RoutingIndexManager, pywrapcp.Assignment]:
        """
        Build the TSPTW routing model and solve it.

        Node 0 of ``graph.node_meta`` is the depot of the single vehicle; each
        node whose meta ``type`` is ``"poi"`` is tied to
        ``tasks[meta["task_idx"]]``.

        Args:
            graph: Provides ``node_meta`` and ``duration_matrix`` (travel
                durations indexed ``[from_node, to_node]``; assumed to be in
                seconds like the horizon — TODO confirm).
            tasks: Tasks referenced by the POI nodes.
            start_time: Absolute time corresponding to cumul value 0.
            deadline: Absolute end of the planning horizon; ``None`` means
                ``start_time + 3 days``.
            max_wait_time_sec: Maximum slack (waiting time) allowed at a node.

        Returns:
            (routing, manager, solution): the OR-Tools model, index manager and
            the result of ``SolveWithParameters`` (which may be ``None`` when
            no solution was found).

        Raises:
            ValueError: If ``deadline`` is earlier than ``start_time``.
        """
        num_nodes = len(graph.node_meta)

        # Per-node service time in seconds; non-POI nodes get 0.
        service_time = self._build_service_time_per_node(tasks, graph.node_meta)

        # One vehicle that starts and ends at node 0 (the depot).
        manager = pywrapcp.RoutingIndexManager(num_nodes, 1, 0)
        routing = pywrapcp.RoutingModel(manager)

        # The same travel+service callback drives both the arc cost and the
        # "Time" dimension.
        transit_cb_index = self._register_transit_callback(
            routing, manager, graph.duration_matrix, service_time
        )
        routing.SetArcCostEvaluatorOfAllVehicles(transit_cb_index)

        self._add_time_dimension(
            routing,
            manager,
            transit_cb_index,
            tasks,
            graph.node_meta,
            start_time,
            deadline,
            max_wait_time_sec,
        )

        self._add_priority_disjunctions(routing, manager, tasks, graph.node_meta)

        search_parameters = self._create_search_parameters()

        if self.verbose:
            logger.info(
                "Starting OR-Tools search with time limit = %ds",
                self.time_limit_seconds,
            )

        solution = routing.SolveWithParameters(search_parameters)

        if self.verbose:
            logger.info("OR-Tools search completed")

        return routing, manager, solution

    @staticmethod
    def _build_service_time_per_node(
            tasks: List[_Task],
            node_meta: List[Dict[str, Any]],
    ) -> List[int]:
        """
        Return the service time (seconds) of every node.

        POI nodes take ``service_duration_sec`` from their task; every other
        node (e.g. the depot) gets 0.
        """
        service_time = [0] * len(node_meta)

        for node, meta in enumerate(node_meta):
            if meta["type"] == "poi":
                task = tasks[meta["task_idx"]]
                service_time[node] = task.service_duration_sec

        return service_time

    @staticmethod
    def _register_transit_callback(
            routing: pywrapcp.RoutingModel,
            manager: pywrapcp.RoutingIndexManager,
            duration_matrix: np.ndarray,
            service_time: List[int],
    ) -> int:
        """
        Register the transit callback and return its index.

        The transit of an arc is the travel duration plus the service time of
        the *origin* node, so a node's cumul is the time service starts there
        and leaving it accounts for the service already performed.
        """

        def transit_callback(from_index: int, to_index: int) -> int:
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            travel = duration_matrix[from_node, to_node]
            service = service_time[from_node]
            # OR-Tools needs integers; int() truncates fractional durations.
            return int(travel + service)

        return routing.RegisterTransitCallback(transit_callback)

    def _add_time_dimension(
            self,
            routing: pywrapcp.RoutingModel,
            manager: pywrapcp.RoutingIndexManager,
            transit_cb_index: int,
            tasks: List[_Task],
            node_meta: List[Dict[str, Any]],
            start_time: datetime,
            deadline: Optional[datetime],
            max_wait_time_sec: int,
    ) -> pywrapcp.RoutingDimension:
        """
        Add the "Time" dimension and restrict each POI node to its time window.

        Cumul values are seconds since ``start_time``; the horizon runs up to
        ``deadline`` (``start_time + 3 days`` when ``deadline`` is None).
        A POI whose task and POI windows do not overlap is made inactive, so
        the solver must drop it (paying its disjunction penalty).

        Raises:
            ValueError: If ``deadline`` is earlier than ``start_time``.
        """
        if deadline is None:
            deadline = start_time + timedelta(days=3)

        horizon_sec = int((deadline - start_time).total_seconds())
        if horizon_sec < 0:
            # A negative horizon would become a negative dimension capacity
            # and an empty cumul range; fail early with a clear message.
            raise ValueError(
                f"deadline ({deadline}) is earlier than start_time ({start_time})"
            )

        # slack_max = allowed waiting time, capacity = horizon,
        # fix_start_cumul_to_zero = False (departure may be delayed).
        routing.AddDimension(
            transit_cb_index,
            max_wait_time_sec,
            horizon_sec,
            False,
            "Time",
        )

        time_dimension = routing.GetDimensionOrDie("Time")

        # Depot start: departure allowed anywhere in [0, horizon].
        start_index = routing.Start(0)
        time_dimension.CumulVar(start_index).SetRange(0, horizon_sec)

        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue

            index = manager.NodeToIndex(node)
            task = tasks[meta["task_idx"]]

            start_sec, end_sec = self.tw_handler.compute_effective_time_window(
                task.time_window, meta.get("poi_time_window"), start_time, horizon_sec
            )

            if start_sec > end_sec:
                # No overlap between task and POI windows.  Pinning the cumul
                # to [0, 0] alone can still be satisfied when the node is
                # reachable at t=0 (e.g. zero travel from the depot), so the
                # node is also deactivated explicitly.  Every POI node belongs
                # to a task disjunction (see _add_priority_disjunctions), so
                # the solver can drop it at the task's penalty.
                logger.warning(
                    "Node(%s) has infeasible time window, forcing tiny 0 range.",
                    meta,
                )
                start_sec = end_sec = 0
                routing.ActiveVar(index).SetValue(0)

            time_dimension.CumulVar(index).SetRange(start_sec, end_sec)

        end_index = routing.End(0)
        time_dimension.CumulVar(end_index).SetRange(0, horizon_sec)

        return time_dimension

    @staticmethod
    def _add_priority_disjunctions(
            routing: pywrapcp.RoutingModel,
            manager: pywrapcp.RoutingIndexManager,
            tasks: List[_Task],
            node_meta: List[Dict[str, Any]],
    ) -> None:
        """
        Make every task optional with a priority-dependent drop penalty.

        All POI nodes of one task form a single disjunction (OR-Tools default
        max cardinality 1: at most one of them is visited). Skipping the task
        costs 10,000,000 for "HIGH", 100,000 for "MEDIUM" and 10,000 for any
        other priority value.
        """
        # Pre-seed every task index so disjunctions are added in task order.
        task_nodes: Dict[int, List[int]] = {i: [] for i in range(len(tasks))}

        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue
            task_nodes[meta["task_idx"]].append(node)

        for task_idx, nodes in task_nodes.items():
            if not nodes:
                continue

            priority = tasks[task_idx].priority

            if priority == "HIGH":
                penalty = 10_000_000
            elif priority == "MEDIUM":
                penalty = 100_000
            else:
                penalty = 10_000

            routing_indices = [manager.NodeToIndex(n) for n in nodes]
            routing.AddDisjunction(routing_indices, penalty)

    def _create_search_parameters(self) -> pywrapcp.DefaultRoutingSearchParameters:
        """
        Return search parameters built from ``DefaultRoutingSearchParameters()``.

        Uses PATH_CHEAPEST_ARC for the first solution, GUIDED_LOCAL_SEARCH as
        the metaheuristic, and ``self.time_limit_seconds`` as the time limit.
        """
        search_parameters = pywrapcp.DefaultRoutingSearchParameters()
        search_parameters.first_solution_strategy = (
            routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
        )
        search_parameters.local_search_metaheuristic = (
            routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
        )
        search_parameters.time_limit.FromSeconds(self.time_limit_seconds)

        return search_parameters