|
|
"""Minimal utilities for Bean Vision.""" |
|
|
|
|
|
import logging |
|
|
import torch |
|
|
from pathlib import Path |
|
|
from typing import Union |
|
|
|
|
|
|
|
|
class ModelError(Exception):
    """Error raised for model-related failures.

    Used in this module to report a checkpoint that is missing or that
    cannot be loaded.
    """
|
|
|
|
|
|
|
|
def validate_device(device_str: Union[str, torch.device]) -> torch.device:
    """Resolve a device specification to a usable ``torch.device``.

    Args:
        device_str: A device string such as ``'cpu'``, ``'cuda'`` or
            ``'mps'``, or an already constructed ``torch.device``.

    Returns:
        The requested device when its backend reports itself available;
        otherwise ``torch.device('cpu')``. A ``torch.device`` argument
        that passes the check is returned as the same object.

    An invalid specification is not raised to the caller: it is logged
    as a warning and CPU is returned instead.
    """
    if isinstance(device_str, torch.device):
        # Do not return early: a pre-built device can still name a backend
        # that is unavailable on this host, so it gets the same check below.
        device = device_str
    else:
        try:
            device = torch.device(device_str)
        except Exception as e:
            # Deliberate best-effort: any construction failure means CPU.
            logging.warning(f"Invalid device '{device_str}': {e}. Using CPU")
            return torch.device('cpu')

    if device.type == 'cuda' and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU")
        return torch.device('cpu')

    if device.type == 'mps':
        # Look the backend up defensively so a torch build without
        # ``torch.backends.mps`` is treated as "MPS unavailable".
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is None or not mps_backend.is_available():
            logging.warning("MPS requested but not available, falling back to CPU")
            return torch.device('cpu')

    return device
|
|
|
|
|
|
|
|
def safe_load_model_checkpoint(checkpoint_path: Union[str, Path], device: torch.device):
    """Load a checkpoint file, reporting any failure as ``ModelError``.

    Args:
        checkpoint_path: Location of the checkpoint, as a string or ``Path``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` returns for the file.

    Raises:
        ModelError: If the path does not exist, or if ``torch.load`` raises.
            In the latter case the original exception is attached as
            ``__cause__``.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise ModelError(f"Checkpoint file not found: {checkpoint_path}")

    try:
        # NOTE(security): torch.load unpickles the file. Depending on the
        # installed torch version's ``weights_only`` default, loading an
        # untrusted checkpoint can execute arbitrary code; only load
        # checkpoints from trusted sources.
        return torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        # Chain explicitly so the underlying failure stays in the traceback.
        raise ModelError(f"Failed to load checkpoint: {e}") from e