""" Safety guards and runtime limits for PyTorch Playground demos. Ensures demos run within reasonable time and resource bounds. """ import platform import time import threading from typing import Callable, Any, Optional, Tuple from contextlib import contextmanager from functools import wraps from pytorch_playground.state import Level class TimeGuard: """ Context manager for enforcing time limits on operations. Usage: with TimeGuard(timeout_seconds=30) as guard: # do work if guard.should_stop: break """ def __init__(self, timeout_seconds: float = 60.0, check_interval: float = 0.1): self.timeout = timeout_seconds self.check_interval = check_interval self.start_time: Optional[float] = None self._should_stop = False self._timer: Optional[threading.Timer] = None def __enter__(self): self.start_time = time.time() # Set timer to flag stop self._timer = threading.Timer(self.timeout, self._set_stop) self._timer.daemon = True self._timer.start() return self def __exit__(self, exc_type, exc_val, exc_tb): if self._timer: self._timer.cancel() return False def _set_stop(self): self._should_stop = True @property def should_stop(self) -> bool: """Check if time limit has been reached.""" return self._should_stop @property def elapsed(self) -> float: """Get elapsed time in seconds.""" if self.start_time is None: return 0.0 return time.time() - self.start_time @property def remaining(self) -> float: """Get remaining time in seconds.""" return max(0.0, self.timeout - self.elapsed) def check(self) -> bool: """Check if we should continue. Returns True if OK to continue.""" return not self._should_stop def check_dataloader_workers(requested_workers: int) -> Tuple[int, Optional[str]]: """ Check and adjust DataLoader num_workers based on platform. Args: requested_workers: Number of workers requested Returns: Tuple of (adjusted_workers, warning_message) """ system = platform.system() warning = None if system == "Windows": if requested_workers > 0: warning = ( "Windows multiprocessing with DataLoader can cause issues. " "Setting num_workers=0. For production, consider using if __name__ == '__main__' guard." ) return 0, warning elif system == "Darwin": # macOS # macOS works but may need spawn method if requested_workers > 4: warning = f"Reduced workers from {requested_workers} to 4 on macOS for stability." return 4, warning # Linux or adjusted value return requested_workers, warning def cap_epochs(requested_epochs: int, level: Level) -> Tuple[int, Optional[str]]: """ Cap training epochs based on level. Args: requested_epochs: Number of epochs requested level: Current user level Returns: Tuple of (capped_epochs, warning_message) """ max_epochs = level.max_epochs warning = None if requested_epochs > max_epochs: warning = f"Epochs capped from {requested_epochs} to {max_epochs} for {level.value} level." return max_epochs, warning return requested_epochs, warning def cap_dataset_size( requested_size: int, level: Level, dataset_type: str = "default" ) -> Tuple[int, Optional[str]]: """ Cap dataset size based on level and type. Args: requested_size: Dataset size requested level: Current user level dataset_type: Type of dataset (affects limits) Returns: Tuple of (capped_size, warning_message) """ # Level-based limits limits = { Level.BEGINNER: { "default": 1000, "image": 5000, "synthetic": 2000, }, Level.INTERMEDIATE: { "default": 10000, "image": 20000, "synthetic": 10000, }, Level.ADVANCED: { "default": 60000, "image": 60000, "synthetic": 50000, }, } max_size = limits[level].get(dataset_type, limits[level]["default"]) warning = None if requested_size > max_size: warning = f"Dataset size capped from {requested_size} to {max_size} for {level.value} level." return max_size, warning return requested_size, warning def safe_run( func: Callable, timeout: float = 60.0, default: Any = None, ) -> Tuple[Any, Optional[str]]: """ Run a function with timeout and error handling. Args: func: Function to run timeout: Timeout in seconds default: Default value to return on failure Returns: Tuple of (result, error_message) """ result = default error = None # Simple timeout approach using threading # Note: This won't actually kill the thread if it hangs result_container = [default, None] def target(): try: result_container[0] = func() except Exception as e: result_container[1] = str(e) thread = threading.Thread(target=target) thread.daemon = True thread.start() thread.join(timeout=timeout) if thread.is_alive(): return default, f"Operation timed out after {timeout}s" return result_container[0], result_container[1] def require_torch(func: Callable) -> Callable: """Decorator that ensures PyTorch is available.""" @wraps(func) def wrapper(*args, **kwargs): try: import torch return func(*args, **kwargs) except ImportError: return None, "PyTorch is not installed. Please install it first." return wrapper def require_cuda(func: Callable) -> Callable: """Decorator that ensures CUDA is available.""" @wraps(func) def wrapper(*args, **kwargs): try: import torch if not torch.cuda.is_available(): return None, "CUDA is not available on this system." return func(*args, **kwargs) except ImportError: return None, "PyTorch is not installed." return wrapper def get_safe_batch_size( requested: int, dataset_size: int, min_batches: int = 2 ) -> Tuple[int, Optional[str]]: """ Ensure batch size doesn't exceed dataset size limits. Args: requested: Requested batch size dataset_size: Size of the dataset min_batches: Minimum number of batches required Returns: Tuple of (adjusted_batch_size, warning_message) """ max_batch = dataset_size // min_batches warning = None if max_batch < 1: max_batch = 1 if requested > max_batch: warning = f"Batch size reduced from {requested} to {max_batch} to ensure {min_batches}+ batches." return max_batch, warning return requested, warning @contextmanager def memory_guard(max_gb: float = 8.0): """ Context manager that warns about high memory usage. Note: This is advisory only, not a hard limit. """ try: import torch if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() initial = torch.cuda.memory_allocated() yield if torch.cuda.is_available(): peak = torch.cuda.max_memory_allocated() peak_gb = peak / (1024**3) if peak_gb > max_gb: print( f"Warning: Peak GPU memory usage ({peak_gb:.2f} GB) exceeded {max_gb} GB limit." ) except ImportError: yield def is_jupyter_environment() -> bool: """Check if running in a Jupyter environment.""" try: from IPython import get_ipython if get_ipython() is not None: return True except ImportError: pass return False def get_platform_info() -> dict: """Get platform information for debugging.""" import sys info = { "platform": platform.platform(), "system": platform.system(), "python_version": sys.version, "machine": platform.machine(), } try: import torch info["torch_version"] = torch.__version__ info["cuda_available"] = torch.cuda.is_available() if torch.cuda.is_available(): info["cuda_device_count"] = torch.cuda.device_count() if hasattr(torch.backends, "mps"): info["mps_available"] = torch.backends.mps.is_available() except ImportError: info["torch_available"] = False return info