# Kunitomi's picture
# Upload folder using huggingface_hub
# 196c526 verified
"""Minimal utilities for Bean Vision."""
import logging
import torch
from pathlib import Path
from typing import Union
class ModelError(Exception):
    """Error raised by this module's model helpers.

    For example, ``safe_load_model_checkpoint`` raises it when a checkpoint
    is missing or cannot be loaded.
    """
def validate_device(device_str: Union[str, torch.device]) -> torch.device:
    """Validate a device specification and return a usable ``torch.device``.

    Args:
        device_str: A device string such as ``"cpu"``, ``"cuda"``,
            ``"cuda:1"`` or ``"mps"``, or an existing ``torch.device``.
            A ``torch.device`` is returned unchanged, without any
            availability check.

    Returns:
        The requested device. Returns ``torch.device('cpu')`` instead when
        the specification cannot be parsed, or when the requested CUDA/MPS
        backend is not available. A warning is logged on every fallback.
    """
    if isinstance(device_str, torch.device):
        return device_str

    # Only parsing can legitimately fail on bad input. Keep the availability
    # probes outside the try so their errors are not misreported as an
    # "invalid device".
    try:
        device = torch.device(device_str)
    except (RuntimeError, TypeError, ValueError) as e:
        logging.warning("Invalid device '%s': %s. Using CPU", device_str, e)
        return torch.device('cpu')

    if device.type == 'cuda' and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU")
        return torch.device('cpu')

    if device.type == 'mps':
        # torch.backends.mps does not exist on older PyTorch builds. Treat a
        # missing backend as "not available" rather than raising AttributeError.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is None or not mps_backend.is_available():
            logging.warning("MPS requested but not available, falling back to CPU")
            return torch.device('cpu')

    return device
def safe_load_model_checkpoint(
    checkpoint_path: Union[str, Path],
    device: torch.device,
    *,
    weights_only: Union[bool, None] = None,
):
    """Load a checkpoint with ``torch.load``, wrapping failures in ``ModelError``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``, so stored
            tensors are mapped onto this device.
        weights_only: Forwarded to ``torch.load`` only when not ``None``.
            Leave it as ``None`` to keep the installed PyTorch's default.
            Pass ``True`` to restrict unpickling when the file may not be
            trusted.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.

    Raises:
        ModelError: If the path does not exist, is not a regular file, or
            ``torch.load`` raises. In the last case the original exception
            is chained as ``__cause__``.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise ModelError(f"Checkpoint file not found: {checkpoint_path}")
    if not checkpoint_path.is_file():
        # e.g. a directory was passed. Report it clearly instead of surfacing
        # an opaque loader error.
        raise ModelError(f"Checkpoint path is not a file: {checkpoint_path}")

    load_kwargs = {"map_location": device}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only

    # SECURITY: torch.load is pickle-based. With weights_only=False it can
    # execute arbitrary code embedded in the file, so only load trusted
    # checkpoints that way.
    try:
        return torch.load(checkpoint_path, **load_kwargs)
    except Exception as e:
        # Boundary: translate any loader failure into this module's error
        # type, keeping the original exception as the cause for debugging.
        raise ModelError(f"Failed to load checkpoint: {e}") from e