# LifeFlow-AI / src/optimization/tsptw_solver.py
# Marco310's picture
# Release v1.0.0: Hybrid AI Architecture, Modern UI Overhaul, and Performance Optimizations
# 1a2b0fa
import os
from typing import List, Dict, Any, Optional
from datetime import datetime
from src.infra.logger import get_logger
from src.services.googlemap_api_service import GoogleMapAPIService
from src.optimization.models import (
convert_tasks_to_internal,
convert_location_to_internal,
convert_result_to_dict,
)
from src.optimization.graph import GraphBuilder
from src.optimization.solver import ORToolsSolver, SolutionExtractor
# Module-level logger, created through the project's get_logger helper and keyed by this module's name.
logger = get_logger(__name__)
class TSPTWSolver:
    """
    TSPTW (Traveling Salesman Problem with Time Windows) solver facade.

    Notes carried over from the original author:
        - The external API exchanges plain dicts (backward compatible).
        - Inputs are converted to internal (Pydantic) models for validation.
        - Service durations are expressed in minutes (``service_duration_min``).
        - Time windows are supported at both task level and POI level,
          including multi-segment ``time_windows``.
        - Alternative POIs take time windows into account and never
          recommend the same ``poi_id`` again.

    The class only orchestrates three collaborators: ``GraphBuilder``
    (builds the graph), ``ORToolsSolver`` (runs the search) and
    ``SolutionExtractor`` (turns the solution into a result object).
    """

    def __init__(
        self,
        time_limit_seconds: Optional[int] = None,
        verbose: bool = False,
    ):
        """
        Initialise the solver and its collaborating modules.

        Args:
            time_limit_seconds: Search time limit in seconds. When ``None``,
                the value is read from the ``SOLVER_TIME_LIMIT`` (or
                ``solver_time_limit``) environment variable, defaulting to 1.
            verbose: Forwarded to ``ORToolsSolver``; whether to emit
                detailed logs.

        Raises:
            ValueError: If the environment variable is used and is not a
                valid integer.
        """
        self.time_limit_seconds = self._resolve_time_limit(time_limit_seconds)
        self.verbose = verbose

        # Initialise the individual modules.
        self.graph_builder = GraphBuilder()
        self.ortools_solver = ORToolsSolver(
            time_limit_seconds=self.time_limit_seconds,
            verbose=verbose,
        )
        self.solution_extractor = SolutionExtractor()

    # ------------------------------------------------------------------ #
    # Internal helpers                                                   #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _resolve_time_limit(time_limit_seconds: Optional[int]) -> int:
        """Return the explicit limit, or fall back to the environment (default 1)."""
        if time_limit_seconds is not None:
            return time_limit_seconds
        env_limit = (
            os.getenv("SOLVER_TIME_LIMIT")
            or os.getenv("solver_time_limit")
            or "1"
        )
        return int(env_limit)

    @staticmethod
    def _raw_task_ids(tasks: Optional[List[Any]]) -> List[Any]:
        """Collect ``task_id`` from raw task inputs without ever raising.

        Entries that are not dicts (one possible cause of a validation
        failure) contribute an empty string instead of triggering an
        ``AttributeError`` inside the error handler.
        """
        return [
            t.get("task_id", "") if isinstance(t, dict) else ""
            for t in (tasks or [])
        ]

    @staticmethod
    def _empty_result(
        status: str,
        skipped_tasks: List[Any],
        error: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build the result dict used by every non-success exit of ``solve``."""
        result: Dict[str, Any] = {"status": status}
        if error is not None:
            # Keep "error" right after "status", as in the original layout.
            result["error"] = error
        result.update(
            {
                "total_travel_time_min": 0,
                "total_travel_distance_m": 0,
                "route": [],
                "visited_tasks": [],
                "skipped_tasks": skipped_tasks,
                "tasks_detail": [],
            }
        )
        return result

    # ------------------------------------------------------------------ #
    # Public API - original interface fully preserved                    #
    # ------------------------------------------------------------------ #
    def solve(
        self,
        start_location: Dict[str, float],
        start_time: datetime,
        deadline: datetime,
        tasks: Optional[List[Dict[str, Any]]] = None,
        travel_mode="DRIVE",
        max_wait_time_min: int = 10,
        alt_k: int = 3,
        return_to_start: bool = True,
    ) -> Dict[str, Any]:
        """
        Solve the TSPTW instance.

        Args:
            start_location: ``{"lat": float, "lng": float}``.
            start_time: Start time of the route.
            deadline: Latest allowed time.
            tasks: List of tasks (``None`` is treated as an empty list).
                Each task has the shape:
                {
                    "task_id": str,
                    "priority": "HIGH" | "MEDIUM" | "LOW",
                    "time_window": (datetime, datetime) | None,
                    "service_duration_min": int,
                    "candidates": [
                        {
                            "poi_id": str,
                            "lat": float,
                            "lng": float,
                            "time_window": (datetime, datetime) | None,
                            "time_windows": [(datetime, datetime), ...] | None
                        }
                    ]
                }
            travel_mode: Travel mode used for the travel matrix.
            max_wait_time_min: Maximum waiting time, in minutes.
            alt_k: Number of Top-K alternative POIs to return.
            return_to_start: Whether the route returns to the start point.

        Returns:
            Dict of the form:
            {
                "status": "OK" | "NO_SOLUTION" | "NO_TASKS"
                          | "INVALID_INPUT" | "SOLVER_ERROR",
                "error": str,   # only for INVALID_INPUT / SOLVER_ERROR
                "total_travel_time_min": int,
                "total_travel_distance_m": int,
                "route": [...],
                "visited_tasks": [...],
                "skipped_tasks": [...],
                "tasks_detail": [...]
            }
            On success the dict is produced by ``convert_result_to_dict``.
        """
        # The signature defaults ``tasks`` to None; normalise it so the
        # default does not crash on len() below.
        if tasks is None:
            tasks = []

        logger.info("TSPTWSolver.solve() start, tasks=%d", len(tasks))

        # 1. Validate and convert the input.
        try:
            internal_tasks = convert_tasks_to_internal(tasks)
            internal_start_location = convert_location_to_internal(start_location)
        except Exception as e:
            logger.error(f"Failed to validate input: {e}")
            return self._empty_result(
                "INVALID_INPUT",
                self._raw_task_ids(tasks),
                error=str(e),
            )

        # 2. Build the graph.
        graph = self.graph_builder.build_graph(
            start_location=internal_start_location,
            tasks=internal_tasks,
            travel_mode=travel_mode,
        )

        num_nodes = len(graph.node_meta)
        if num_nodes <= 1:
            # A single node means only the depot is present.
            logger.warning("No POIs to visit, only depot.")
            return self._empty_result(
                "NO_TASKS",
                [t.task_id for t in internal_tasks],
            )

        # 3. Solve. The solver takes the waiting limit in seconds.
        max_wait_time_sec = max_wait_time_min * 60
        try:
            routing, manager, solution = self.ortools_solver.solve(
                graph=graph,
                tasks=internal_tasks,
                start_time=start_time,
                deadline=deadline,
                max_wait_time_sec=max_wait_time_sec,
            )
        except Exception as e:
            logger.error(f"OR-Tools solver failed: {e}")
            return self._empty_result(
                "SOLVER_ERROR",
                [t.task_id for t in internal_tasks],
                error=str(e),
            )

        # 4. Check whether a solution exists.
        if solution is None:
            logger.warning("No solution found")
            return self._empty_result(
                "NO_SOLUTION",
                [t.task_id for t in internal_tasks],
            )

        # 5. Extract the result.
        # NOTE(review): assumes ORToolsSolver registered a dimension named
        # "Time" on the routing model - confirm against the solver module.
        time_dimension = routing.GetDimensionOrDie("Time")
        result = self.solution_extractor.extract(
            routing=routing,
            manager=manager,
            solution=solution,
            time_dimension=time_dimension,
            start_time=start_time,
            graph=graph,
            tasks=internal_tasks,
            alt_k=alt_k,
            return_to_start=return_to_start,
        )

        logger.info("TSPTWSolver.solve() done, status=%s", result.status)

        # 6. Convert to the external dict format.
        return convert_result_to_dict(result)
def test_time_window_handler():
    """Exercise ``TimeWindowHandler.compute_effective_time_window`` on seven scenarios.

    Each scenario prints a header, calls the handler with a task-level and a
    POI-level time window (either may be ``None``), asserts the returned
    ``(start, end)`` pair and prints a pass line. The expected values are
    offsets in seconds from ``start_time`` (10:00 at UTC+8), bounded by an
    8-hour horizon.
    """
    from datetime import datetime, timezone, timedelta
    from src.optimization.graph.time_window_handler import TimeWindowHandler

    handler = TimeWindowHandler()
    utc_plus_8 = timezone(timedelta(hours=8))  # UTC+8

    def at_hour(hour):
        # All scenarios happen on 2025-11-22, on the hour.
        return datetime(2025, 11, 22, hour, 0, 0, tzinfo=utc_plus_8)

    start_time = at_hour(10)
    horizon_sec = 8 * 3600  # 8 hours

    # Task-level window as a dict of datetimes; reused by scenarios 2 and 7.
    task_tw = {
        'earliest_time': at_hour(11),
        'latest_time': at_hour(15)
    }
    # POI-level window as an (opening time, closing time) tuple.
    poi_tw = (at_hour(9), at_hour(17))
    # Same window as task_tw, expressed as ISO-8601 strings.
    task_tw_str = {
        'earliest_time': '2025-11-22T11:00:00+08:00',
        'latest_time': '2025-11-22T15:00:00+08:00'
    }
    only_earliest = {
        'earliest_time': at_hour(12),
        'latest_time': None
    }
    only_latest = {
        'earliest_time': None,
        'latest_time': at_hour(16)
    }
    # POI record whose time_window is None (the real-world Scout case).
    poi_data = {
        'place_id': 'ChIJ...',
        'name': 'Rainbow Village',
        'time_window': None
    }

    # (header, task window, POI window, expected start, expected end)
    scenarios = [
        ("=== Test Case 1: 都沒有時間窗口 ===",
         None, None, 0, horizon_sec),
        ("\n=== Test Case 2: Dict 格式 - 只有 task 有時間窗口 ===",
         task_tw, None, 3600, 18000),  # 1 hour / 5 hours after start
        ("\n=== Test Case 3: Tuple 格式 - 只有 POI 有時間窗口 ===",
         None, poi_tw, 0, 25200),  # POI already open / 7 hours after start
        ("\n=== Test Case 4: 字符串格式 ===",
         task_tw_str, None, 3600, 18000),
        ("\n=== Test Case 5: 部分時間窗口 (只有 earliest) ===",
         only_earliest, None, 7200, horizon_sec),  # 2 hours after start
        ("\n=== Test Case 6: 部分時間窗口 (只有 latest) ===",
         only_latest, None, 0, 21600),  # 6 hours after start
        ("\n=== Test Case 7: 實際場景 - Scout 返回的 POI time_window = None ===",
         task_tw, poi_data.get('time_window'), 3600, 18000),
    ]

    for header, task_window, poi_window, want_start, want_end in scenarios:
        print(header)
        start, end = handler.compute_effective_time_window(
            task_window, poi_window, start_time, horizon_sec
        )
        assert start == want_start
        assert end == want_end
        print(f"✅ Pass: [{start}, {end}]")

    print("\n🎉 All tests passed!")
# Script entry point: run the TimeWindowHandler checks when this file is executed directly.
if __name__ == "__main__":
    test_time_window_handler()