Spaces:
Build error
Build error
| """ | |
| 修复循环引擎 | |
| 控制修复循环的启动、暂停、停止,包含条件判断和超时机制 | |
| """ | |
| import asyncio | |
| import logging | |
| from typing import Dict, List, Optional, Any, Callable, Set, Tuple | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta | |
| from enum import Enum | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from data_models import SpaceInfo, ErrorInfo, RepairStrategy, SpaceStatus | |
| from auto_repair_executor import AutoRepairExecutor | |
class LoopState(Enum):
    """Lifecycle states of the repair loop."""
    STOPPED = "stopped"    # idle; LoopController.start_loop only starts from this state
    STARTING = "starting"  # start_loop accepted the request, main loop not yet running
    RUNNING = "running"    # main loop is executing iterations
    PAUSING = "pausing"    # NOTE(review): not assigned anywhere in the visible source
    PAUSED = "paused"      # loop is waiting for resume()
    STOPPING = "stopping"  # NOTE(review): not assigned anywhere in the visible source
    ERROR = "error"        # an exception escaped the main loop
class TerminationReason(Enum):
    """Reason recorded for why the repair loop ended."""
    MANUAL = "manual"                          # manual stop
    SUCCESS = "success"                        # success
    TIMEOUT = "timeout"                        # overall timeout reached
    MAX_ITERATIONS = "max_iterations"          # iteration cap reached
    ERROR = "error"                            # error
    NO_PROGRESS = "no_progress"                # no progress for too long
    RESOURCE_EXHAUSTED = "resource_exhausted"  # resources exhausted
@dataclass
class LoopConfig:
    """Tunable parameters of the repair loop.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``LoopConfig(max_iterations=5, timeout_minutes=30)``; every field has a default.
    """
    max_iterations: int = 10  # maximum number of loop iterations
    timeout_minutes: int = 60  # overall loop timeout (minutes)
    # Check interval (seconds). NOTE(review): not read anywhere in the visible source.
    check_interval_seconds: int = 30
    success_wait_seconds: int = 60  # wait after a successful repair (seconds)
    failure_wait_seconds: int = 120  # wait after a failed repair (seconds)
    enable_progress_check: bool = True  # enable the no-progress termination check
    no_progress_timeout_minutes: int = 15  # no-progress timeout (minutes)
    max_concurrent_repairs: int = 3  # maximum number of concurrent repairs
@dataclass
class LoopStatistics:
    """Mutable counters and timestamps for one run of the repair loop.

    A dataclass so it can be built as ``LoopStatistics(start_time=datetime.now())``;
    only ``start_time`` is required.
    """
    start_time: datetime  # when the loop run started
    iterations: int = 0  # number of iterations started so far
    successful_repairs: int = 0  # iterations whose repair succeeded
    failed_repairs: int = 0  # iterations whose repair failed or raised
    total_repair_time: float = 0.0  # cumulative iteration time in seconds
    current_iteration_start: Optional[datetime] = None  # start of the current iteration
    last_successful_repair: Optional[datetime] = None  # time of the last successful repair
    last_error: Optional[str] = None  # message of the most recent failure
    # Forward reference so this class does not depend on definition order.
    termination_reason: Optional["TerminationReason"] = None
class ConditionEvaluator:
    """Decide whether the repair loop keeps going and whether a Space needs repair.

    Every method is a pure function of its arguments plus the current wall-clock time.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def should_continue_loop(self, stats: LoopStatistics, config: LoopConfig) -> Tuple[bool, Optional[str]]:
        """Return ``(True, None)`` to keep looping, otherwise ``(False, reason)``.

        Checks, in order: the iteration cap, the overall timeout, and (when enabled)
        the no-progress window.
        """
        if stats.iterations >= config.max_iterations:
            return False, f"达到最大迭代次数: {config.max_iterations}"

        elapsed = (datetime.now() - stats.start_time).total_seconds()
        if elapsed >= config.timeout_minutes * 60:
            return False, f"循环超时: {config.timeout_minutes} 分钟"

        stalled = config.enable_progress_check and self._check_no_progress(stats, config)
        if stalled:
            return False, f"长期无进展: {config.no_progress_timeout_minutes} 分钟"

        return True, None

    def _check_no_progress(self, stats: LoopStatistics, config: LoopConfig) -> bool:
        """Return True when the loop appears stalled."""
        last_ok = stats.last_successful_repair
        if last_ok:
            idle_seconds = (datetime.now() - last_ok).total_seconds()
            return idle_seconds >= config.no_progress_timeout_minutes * 60
        # No success yet: the first three iterations get a chance before we call it a stall.
        return stats.iterations > 3

    def should_attempt_repair(self, space_info: SpaceInfo, last_status: Optional[SpaceStatus]) -> bool:
        """Repair only a Space currently in ERROR whose previously seen status was not ERROR.

        A repeated ERROR is treated as "possibly still being handled".
        """
        return (
            space_info.current_status == SpaceStatus.ERROR
            and last_status != SpaceStatus.ERROR
        )

    def evaluate_repair_success(self, previous_status: SpaceStatus, current_status: SpaceStatus,
                                error_before: Optional[ErrorInfo], error_after: Optional[ErrorInfo]) -> bool:
        """Judge a repair successful if the Space left ERROR or its error disappeared."""
        if previous_status == SpaceStatus.ERROR and current_status != SpaceStatus.ERROR:
            return True
        if not error_before:
            return False
        if not error_after:
            return True
        # Both errors present.
        # NOTE(review): every branch below returns False, so the type/confidence
        # checks do not change the outcome. Confirm whether a confidence drop was
        # meant to count as success.
        if error_before.error_type != error_after.error_type:
            return False  # error type changed
        if error_after.confidence < error_before.confidence * 0.5:
            return False  # confidence fell sharply
        return False

    def calculate_wait_time(self, repair_success: bool, config: LoopConfig) -> int:
        """Seconds to wait before the next iteration, chosen by repair outcome."""
        return config.success_wait_seconds if repair_success else config.failure_wait_seconds
class TimeoutManager:
    """Track named deadlines and report whether each has passed.

    Deadlines live in ``self.timeouts`` as naive ``datetime`` values keyed by a
    caller-chosen name.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.timeouts: Dict[str, datetime] = {}

    def set_timeout(self, key: str, timeout_seconds: int) -> None:
        """Start (or restart) deadline *key*, *timeout_seconds* from now."""
        self.timeouts[key] = datetime.now() + timedelta(seconds=timeout_seconds)
        self.logger.debug(f"设置超时: {key} - {timeout_seconds} 秒")

    def is_expired(self, key: str) -> bool:
        """Return True once *key*'s deadline has passed; unknown keys count as expired."""
        try:
            deadline = self.timeouts[key]
        except KeyError:
            return True
        return deadline < datetime.now()

    def get_remaining_time(self, key: str) -> Optional[float]:
        """Seconds left until *key* expires (never negative), or None for unknown keys."""
        if key in self.timeouts:
            seconds_left = (self.timeouts[key] - datetime.now()).total_seconds()
            return max(0, seconds_left)
        return None

    def cancel_timeout(self, key: str) -> None:
        """Forget deadline *key*; unknown keys are ignored."""
        if key not in self.timeouts:
            return
        del self.timeouts[key]
        self.logger.debug(f"取消超时: {key}")

    def cleanup_expired(self) -> None:
        """Drop every deadline that has already passed."""
        now = datetime.now()
        stale = [name for name, deadline in self.timeouts.items() if deadline < now]
        for key in stale:
            del self.timeouts[key]
            self.logger.debug(f"清理过期超时: {key}")
class LoopController:
    """Control the repair loop's lifecycle: start, pause/resume and stop.

    The loop itself runs as a coroutine (``start_loop``). ``stop``, ``pause`` and
    ``resume`` only set or clear ``threading.Event`` flags that the loop polls.
    Optional callbacks (plain or coroutine functions) fire at iteration start and
    end, on loop completion, and on error.
    """

    def __init__(self, config: LoopConfig):
        self.logger = logging.getLogger(__name__)
        self.config = config
        self.state = LoopState.STOPPED
        self.stats = None
        self.condition_evaluator = ConditionEvaluator()
        self.timeout_manager = TimeoutManager()

        # Control flags, polled by the running loop.
        self._stop_requested = threading.Event()
        self._pause_requested = threading.Event()
        # Guards assignments to self.state.
        self._lock = threading.Lock()

        # Optional callbacks; each may be a plain function or a coroutine function.
        self.on_iteration_start: Optional[Callable] = None
        self.on_iteration_complete: Optional[Callable] = None
        self.on_loop_complete: Optional[Callable] = None
        self.on_error: Optional[Callable] = None

    async def start_loop(self) -> None:
        """Run the loop until it terminates.

        Raises:
            RuntimeError: if the loop is not in the STOPPED state.

        An exception escaping the main loop is logged with its traceback and passed
        to ``on_error``. It leaves the state at ERROR and is not re-raised.
        """
        with self._lock:
            if self.state != LoopState.STOPPED:
                raise RuntimeError(f"循环已在运行或正在启动: {self.state.value}")
            self.state = LoopState.STARTING

        self._stop_requested.clear()
        self._pause_requested.clear()

        try:
            await self._run_loop()
        except Exception as e:
            with self._lock:
                self.state = LoopState.ERROR
            self.logger.exception(f"循环运行异常: {e}")
            if self.on_error:
                await self._safe_call(self.on_error, e)

    async def _run_loop(self) -> None:
        """Main loop: honour stop/pause requests, run iterations, wait between them."""
        self.stats = LoopStatistics(start_time=datetime.now())

        with self._lock:
            self.state = LoopState.RUNNING

        self.logger.info("修复循环已启动")

        try:
            while True:
                if self._stop_requested.is_set():
                    self.logger.info("收到停止请求")
                    # Record why the loop ended; otherwise the summary shows "unknown".
                    self.stats.termination_reason = TerminationReason.MANUAL
                    break

                if self._pause_requested.is_set():
                    with self._lock:
                        self.state = LoopState.PAUSED
                    self.logger.info("循环已暂停")
                    await self._wait_for_resume()
                    with self._lock:
                        self.state = LoopState.RUNNING
                    self.logger.info("循环已恢复")
                    continue

                iteration_result = await self._execute_iteration()

                if not iteration_result.continue_loop:
                    self.stats.termination_reason = iteration_result.termination_reason
                    break

                wait_time = iteration_result.wait_time
                if wait_time > 0:
                    # Wake early on stop() instead of sleeping out the full wait.
                    await self._sleep_unless_stopped(wait_time)

        finally:
            with self._lock:
                self.state = LoopState.STOPPED

            self.logger.info("修复循环已停止")

            if self.on_loop_complete:
                await self._safe_call(self.on_loop_complete, self.stats)

    async def _execute_iteration(self) -> Any:
        """Run one iteration and return an IterationResult.

        Increments the iteration counter, then checks the continuation conditions,
        attempts a repair and updates the statistics. An exception raised while doing
        this is counted as a failed repair, and the loop continues after
        ``failure_wait_seconds``.
        """
        self.stats.iterations += 1
        self.stats.current_iteration_start = datetime.now()

        if self.on_iteration_start:
            await self._safe_call(self.on_iteration_start, self.stats)

        try:
            should_continue, reason = self.condition_evaluator.should_continue_loop(self.stats, self.config)

            if not should_continue:
                termination_reason = self._determine_termination_reason(reason)
                self.logger.info(f"循环终止: {reason}")
                return IterationResult(continue_loop=False, termination_reason=termination_reason, wait_time=0)

            repair_result = await self._attempt_repair()

            if repair_result.success:
                self.stats.successful_repairs += 1
                self.stats.last_successful_repair = datetime.now()
                wait_time = self.condition_evaluator.calculate_wait_time(True, self.config)
            else:
                self.stats.failed_repairs += 1
                self.stats.last_error = repair_result.error_message
                wait_time = self.condition_evaluator.calculate_wait_time(False, self.config)

            iteration_time = (datetime.now() - self.stats.current_iteration_start).total_seconds()
            self.stats.total_repair_time += iteration_time

            return IterationResult(
                continue_loop=True,
                termination_reason=None,
                wait_time=wait_time,
                repair_success=repair_result.success
            )

        except Exception as e:
            self.stats.failed_repairs += 1
            self.stats.last_error = str(e)
            self.logger.error(f"迭代执行异常: {e}")

            return IterationResult(
                continue_loop=True,
                termination_reason=None,
                wait_time=self.config.failure_wait_seconds,
                repair_success=False,
                error_message=str(e)
            )

        finally:
            if self.on_iteration_complete:
                await self._safe_call(self.on_iteration_complete, self.stats)

    async def _attempt_repair(self) -> Any:
        """Placeholder repair step: always returns a failed RepairResult.

        Replace this (on a subclass or per instance) with real repair logic.
        """
        return RepairResult(success=False, error_message="需要实现具体修复逻辑")

    def _determine_termination_reason(self, reason: str) -> TerminationReason:
        """Map a ConditionEvaluator stop message to a TerminationReason.

        Matches keywords in the message; unrecognised messages map to SUCCESS.
        """
        if "迭代次数" in reason:
            return TerminationReason.MAX_ITERATIONS
        elif "超时" in reason:
            return TerminationReason.TIMEOUT
        elif "无进展" in reason:
            return TerminationReason.NO_PROGRESS
        elif "资源" in reason:
            return TerminationReason.RESOURCE_EXHAUSTED
        else:
            return TerminationReason.SUCCESS

    async def _wait_for_resume(self) -> None:
        """Poll once per second until resumed or a stop is requested."""
        while self._pause_requested.is_set() and not self._stop_requested.is_set():
            await asyncio.sleep(1)

    async def _sleep_unless_stopped(self, seconds: float) -> None:
        """Sleep up to *seconds*, returning within about 1 s once stop() is requested."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + seconds
        while not self._stop_requested.is_set():
            remaining = deadline - loop.time()
            if remaining <= 0:
                return
            await asyncio.sleep(min(remaining, 1.0))

    async def _safe_call(self, callback: Callable, *args) -> None:
        """Invoke *callback* (sync or async), logging instead of propagating its exceptions."""
        try:
            if asyncio.iscoroutinefunction(callback):
                await callback(*args)
            else:
                callback(*args)
        except Exception as e:
            self.logger.error(f"回调函数执行异常: {e}")

    def stop(self) -> None:
        """Request the loop to stop; it exits at its next check."""
        self._stop_requested.set()
        self.logger.info("请求停止循环")

    def pause(self) -> None:
        """Request the loop to pause before its next iteration."""
        self._pause_requested.set()
        self.logger.info("请求暂停循环")

    def resume(self) -> None:
        """Clear a pause request so the loop continues."""
        self._pause_requested.clear()
        self.logger.info("请求恢复循环")

    def get_state(self) -> LoopState:
        """Return the current loop state."""
        return self.state

    def get_statistics(self) -> Optional[LoopStatistics]:
        """Return the statistics of the current/last run, or None before the first run."""
        return self.stats
@dataclass
class IterationResult:
    """Outcome of one loop iteration.

    A dataclass so LoopController can build it with keyword arguments.
    """
    continue_loop: bool  # False ends the loop
    # Forward reference so this class does not depend on definition order.
    termination_reason: Optional["TerminationReason"]  # set when continue_loop is False
    wait_time: int  # seconds to wait before the next iteration
    repair_success: Optional[bool] = None  # None when no repair was attempted
    error_message: Optional[str] = None  # set when the iteration raised
@dataclass
class RepairResult:
    """Outcome of one repair attempt.

    A dataclass so callers can build it as ``RepairResult(success=..., error_message=...)``.
    """
    success: bool
    error_message: Optional[str] = None  # reason for failure, None on success
    commit_sha: Optional[str] = None  # second value returned by the executor's execute_repair
    repair_time: Optional[float] = None  # duration of the attempt in seconds
class RepairLoopEngine:
    """Top-level repair loop: tracks Spaces and drives repairs through a LoopController."""

    def __init__(self, repair_executor: AutoRepairExecutor, config: LoopConfig):
        self.logger = logging.getLogger(__name__)
        self.repair_executor = repair_executor
        self.config = config

        # Controller that schedules iterations.
        self.controller = LoopController(config)

        # Monitored Spaces and their latest reported status / error.
        self.monitored_spaces: Dict[str, SpaceInfo] = {}
        self.space_errors: Dict[str, ErrorInfo] = {}
        self.last_space_status: Dict[str, SpaceStatus] = {}

        # Concurrency bookkeeping; initialised before the controller is wired up.
        self.active_repairs: Set[str] = set()
        self.repair_lock = asyncio.Lock()

        self._setup_callbacks()

    def _setup_callbacks(self) -> None:
        """Wire this engine's hooks into the controller."""
        self.controller.on_iteration_start = self._on_iteration_start
        self.controller.on_iteration_complete = self._on_iteration_complete
        self.controller.on_loop_complete = self._on_loop_complete
        self.controller.on_error = self._on_error
        # The controller's iterations call self._attempt_repair(). Its class-level
        # version is a placeholder that always fails, so shadow it on this
        # controller instance to make iterations run the engine's repair logic.
        self.controller._attempt_repair = self._attempt_repair

    async def _on_iteration_start(self, stats: LoopStatistics) -> None:
        """Log the start of an iteration."""
        self.logger.info(f"开始第 {stats.iterations} 次迭代")

    async def _on_iteration_complete(self, stats: LoopStatistics) -> None:
        """Log the running success rate and average iteration time."""
        success_rate = stats.successful_repairs / max(stats.iterations, 1) * 100
        avg_time = stats.total_repair_time / max(stats.iterations, 1)

        self.logger.info(
            f"迭代 {stats.iterations} 完成 - "
            f"成功率: {success_rate:.1f}%, "
            f"平均时间: {avg_time:.1f}秒"
        )

    async def _on_loop_complete(self, stats: LoopStatistics) -> None:
        """Log a summary of the finished loop run."""
        total_time = (datetime.now() - stats.start_time).total_seconds()
        success_rate = stats.successful_repairs / max(stats.iterations, 1) * 100

        self.logger.info(
            f"修复循环完成 - "
            f"总时间: {total_time:.1f}秒, "
            f"迭代次数: {stats.iterations}, "
            f"成功修复: {stats.successful_repairs}, "
            f"失败修复: {stats.failed_repairs}, "
            f"成功率: {success_rate:.1f}%, "
            f"终止原因: {stats.termination_reason.value if stats.termination_reason else 'unknown'}"
        )

    async def _on_error(self, error: Exception) -> None:
        """Log an error raised by the loop."""
        self.logger.error(f"循环执行错误: {error}")

    def add_space(self, space_info: SpaceInfo) -> None:
        """Start monitoring *space_info* (keyed by its space_id)."""
        self.monitored_spaces[space_info.space_id] = space_info
        self.logger.info(f"添加监控 Space: {space_info.space_id}")

    def remove_space(self, space_id: str) -> None:
        """Stop monitoring *space_id* and forget its status and error."""
        self.monitored_spaces.pop(space_id, None)
        self.space_errors.pop(space_id, None)
        self.last_space_status.pop(space_id, None)
        self.logger.info(f"移除监控 Space: {space_id}")

    def update_space_status(self, space_id: str, status: SpaceStatus,
                            error_info: Optional[ErrorInfo] = None) -> None:
        """Record the latest status (and optional error) reported for *space_id*.

        NOTE(review): this stores the newest status as the "last" status and does
        not touch the SpaceInfo's current_status. should_attempt_repair skips a Space
        whose last status is ERROR, so confirm the intended semantics with callers.
        """
        self.last_space_status[space_id] = status
        if error_info:
            self.space_errors[space_id] = error_info
        self.logger.debug(f"更新 Space 状态: {space_id} -> {status.value}")

    async def _attempt_repair(self) -> RepairResult:
        """Pick one Space that needs repair and run the executor on it.

        Exceptions are caught and reported as a failed RepairResult. The chosen
        Space is registered in ``active_repairs`` for the duration of the attempt.
        """
        start_time = datetime.now()
        space_to_repair: Optional[SpaceInfo] = None
        error_to_fix: Optional[ErrorInfo] = None

        try:
            # Iterate over a snapshot so add_space/remove_space cannot break the loop.
            for space_id, space_info in list(self.monitored_spaces.items()):
                last_status = self.last_space_status.get(space_id)
                if not self.controller.condition_evaluator.should_attempt_repair(space_info, last_status):
                    continue

                async with self.repair_lock:
                    if space_id in self.active_repairs:
                        continue
                    if len(self.active_repairs) >= self.config.max_concurrent_repairs:
                        break
                    space_to_repair = space_info
                    error_to_fix = self.space_errors.get(space_id)
                    self.active_repairs.add(space_id)
                    break

            if not space_to_repair or not error_to_fix:
                return RepairResult(success=False, error_message="没有需要修复的 Space")

            strategy = await self._generate_repair_strategy(error_to_fix, space_to_repair)
            if not strategy:
                # The active_repairs entry is released in the finally clause.
                return RepairResult(success=False, error_message="无法生成修复策略")

            success, commit_sha = await self.repair_executor.execute_repair(
                space_to_repair, error_to_fix, strategy
            )

            repair_time = (datetime.now() - start_time).total_seconds()

            # TODO: re-check the Space's actual status after the repair.
            return RepairResult(
                success=success,
                error_message=None if success else "修复执行失败",
                commit_sha=commit_sha,
                repair_time=repair_time
            )

        except Exception as e:
            return RepairResult(success=False, error_message=str(e))

        finally:
            if space_to_repair:
                async with self.repair_lock:
                    self.active_repairs.discard(space_to_repair.space_id)

    async def _generate_repair_strategy(self, error_info: ErrorInfo, space_info: SpaceInfo) -> Optional[RepairStrategy]:
        """Placeholder strategy generator: always returns None (not implemented)."""
        return None

    async def start(self) -> None:
        """Start the repair loop.

        Raises:
            ValueError: if no Space is being monitored.
        """
        if not self.monitored_spaces:
            raise ValueError("没有要监控的 Space")

        self.logger.info(f"启动修复循环,监控 {len(self.monitored_spaces)} 个 Space")
        await self.controller.start_loop()

    def stop(self) -> None:
        """Request the repair loop to stop."""
        self.controller.stop()

    def pause(self) -> None:
        """Request the repair loop to pause."""
        self.controller.pause()

    def resume(self) -> None:
        """Resume a paused repair loop."""
        self.controller.resume()

    def get_state(self) -> LoopState:
        """Return the loop state."""
        return self.controller.get_state()

    def get_statistics(self) -> Optional[LoopStatistics]:
        """Return the loop statistics (None before the first run)."""
        return self.controller.get_statistics()

    def get_active_repairs(self) -> List[str]:
        """Return the ids of Spaces currently being repaired."""
        return list(self.active_repairs)

    def get_monitored_spaces(self) -> List[str]:
        """Return the ids of monitored Spaces."""
        return list(self.monitored_spaces.keys())
if __name__ == "__main__":

    async def main():
        """Example wiring of a RepairLoopEngine (the engine parts are left commented out)."""
        config = LoopConfig(
            max_iterations=5,
            timeout_minutes=30,
            check_interval_seconds=10,
        )

        # A real run needs a Hugging Face API client, for example:
        #
        #   hf_client = HuggingFaceAPIClient(token="your-token")
        #   repair_executor = AutoRepairExecutor(hf_client)
        #   loop_engine = RepairLoopEngine(repair_executor, config)
        #   loop_engine.add_space(SpaceInfo(
        #       space_id="test/test-space",
        #       name="test-space",
        #       repository_url="https://huggingface.co/spaces/test/test-space",
        #       current_status=SpaceStatus.ERROR,
        #       last_updated=datetime.now(),
        #   ))
        #   await loop_engine.start()

        print("RepairLoopEngine 示例代码")

    asyncio.run(main())