"""Minimal utilities for Bean Vision."""

import logging
from pathlib import Path
from typing import Any, Optional, Union

import torch

logger = logging.getLogger(__name__)


class ModelError(Exception):
    """Raised when a model checkpoint cannot be found or loaded."""


def validate_device(device_str: Union[str, torch.device]) -> torch.device:
    """Parse *device_str* into a usable ``torch.device``.

    An existing ``torch.device`` is returned unchanged. Anything else is parsed
    with ``torch.device(...)``. A warning is logged and the CPU device is
    returned instead in these cases:

    * the spec is malformed;
    * it names CUDA but CUDA is not available;
    * it names a CUDA index that is not visible;
    * it names MPS but MPS is not available.

    Args:
        device_str: Device spec such as ``"cpu"``, ``"cuda"``, ``"cuda:1"``
            or ``"mps"``, or an existing ``torch.device``.

    Returns:
        The requested device if it is usable here, otherwise ``cpu``.
    """
    if isinstance(device_str, torch.device):
        return device_str

    # Only the parse is guarded: availability probes below must not be
    # misreported as an "invalid device" string.
    try:
        device = torch.device(device_str)
    except (RuntimeError, TypeError, ValueError) as e:
        logger.warning("Invalid device '%s': %s. Using CPU", device_str, e)
        return torch.device("cpu")

    if device.type == "cuda":
        if not torch.cuda.is_available():
            logger.warning("CUDA requested but not available, falling back to CPU")
            return torch.device("cpu")
        # torch.device() accepts any non-negative index; reject indices that
        # do not exist so the failure surfaces here, not at the first .to().
        visible = torch.cuda.device_count()
        if device.index is not None and device.index >= visible:
            logger.warning(
                "CUDA device index %d out of range (%d visible), "
                "falling back to CPU",
                device.index,
                visible,
            )
            return torch.device("cpu")
    elif device.type == "mps":
        # Older torch builds have no torch.backends.mps; treat as unavailable.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            logger.warning("MPS requested but not available, falling back to CPU")
            return torch.device("cpu")

    return device


def safe_load_model_checkpoint(
    checkpoint_path: Union[str, Path],
    device: torch.device,
    *,
    weights_only: Optional[bool] = None,
) -> Any:
    """Load the checkpoint at *checkpoint_path*, remapping tensors onto *device*.

    Args:
        checkpoint_path: Path to a file previously written with ``torch.save``.
        device: Passed to ``torch.load`` as ``map_location``.
        weights_only: Forwarded to ``torch.load`` when not ``None``. The
            default ``None`` leaves torch's own default in effect, which
            differs between torch versions.

    Returns:
        The deserialized object stored in the checkpoint.

    Raises:
        ModelError: If the file does not exist, or if ``torch.load`` fails.
            In the latter case the original exception is chained as
            ``__cause__``.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise ModelError(f"Checkpoint file not found: {checkpoint_path}")

    # SECURITY: torch.load unpickles the file. With weights_only=False a
    # crafted checkpoint can execute arbitrary code; only load trusted files.
    load_kwargs: dict = {"map_location": device}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only

    try:
        return torch.load(checkpoint_path, **load_kwargs)
    except Exception as e:  # torch.load raises many unrelated exception types
        raise ModelError(f"Failed to load checkpoint: {e}") from e