# File size: 1,593 Bytes
# 196c526
"""Minimal utilities for Bean Vision."""
import logging
import torch
from pathlib import Path
from typing import Union
class ModelError(Exception):
    """Raised for model-related failures.

    Examples are a checkpoint file that is missing or that cannot be loaded.
    """
def validate_device(device_str: Union[str, torch.device]) -> torch.device:
    """Resolve ``device_str`` to a usable ``torch.device``, falling back to CPU.

    Args:
        device_str: A device specification such as ``"cpu"``, ``"cuda"`` or
            ``"mps"``, or an existing ``torch.device`` instance.

    Returns:
        The requested device when its backend reports itself available,
        otherwise ``torch.device('cpu')``. A specification that cannot be
        parsed does not raise; a warning is logged and the CPU device is
        returned instead.
    """
    if isinstance(device_str, torch.device):
        device = device_str
    else:
        try:
            device = torch.device(device_str)
        except Exception as e:
            # Deliberately broad: this helper is best-effort, so any
            # unparsable specification degrades to CPU instead of raising.
            logging.warning("Invalid device '%s': %s. Using CPU", device_str, e)
            return torch.device('cpu')

    # The availability checks apply to both input forms, so an already-built
    # torch.device('cuda') falls back to CPU just like the string 'cuda'.
    if device.type == 'cuda' and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU")
        return torch.device('cpu')

    if device.type == 'mps':
        # Some torch builds have no torch.backends.mps; treat that the same
        # as the backend being unavailable.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is None or not mps_backend.is_available():
            logging.warning("MPS requested but not available, falling back to CPU")
            return torch.device('cpu')

    return device
def safe_load_model_checkpoint(checkpoint_path: Union[str, Path], device: torch.device):
    """Load a checkpoint from disk, reporting any failure as ``ModelError``.

    Args:
        checkpoint_path: Location of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object that ``torch.load`` deserializes from the file.

    Raises:
        ModelError: If the path does not exist, or if ``torch.load`` fails.
            In the latter case the original exception is attached as
            ``__cause__``.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise ModelError(f"Checkpoint file not found: {checkpoint_path}")

    try:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from trusted sources; consider passing
        # weights_only=True where the installed torch version supports it.
        return torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        # Chain explicitly so the underlying failure stays in the traceback.
        raise ModelError(f"Failed to load checkpoint: {e}") from e