import os
import platform

import numpy as np

try:
    import torch

    _HAS_TORCH = True
except ImportError:
    _HAS_TORCH = False

try:
    import mlx.core as mx

    # Verify it can run: building an array exercises the installed runtime,
    # not just the import machinery.
    _ = mx.array([1.0])
    _HAS_MLX = True
except (ImportError, RuntimeError, AttributeError):
    _HAS_MLX = False


def get_framework():
    """
    Detect the high-level compute framework.

    Priority:
    1. SCIMLX_BACKEND environment variable ('torch' or 'mlx'), honoured only
       when the requested framework is importable.
    2. 'mlx' if on Apple Silicon and mlx is installed.
    3. 'torch' if installed, otherwise 'numpy'.

    Returns:
        One of 'mlx', 'torch' or 'numpy'.
    """
    env_backend = os.environ.get("SCIMLX_BACKEND", "").lower()
    if env_backend == "mlx" and _HAS_MLX:
        return "mlx"
    if env_backend == "torch" and _HAS_TORCH:
        return "torch"

    # Auto-detection: MLX is preferred on Apple Silicon if available.
    on_apple_silicon = (
        platform.system() == "Darwin" and platform.machine() == "arm64"
    )
    if _HAS_MLX and on_apple_silicon:
        return "mlx"

    return "torch" if _HAS_TORCH else "numpy"


FRAMEWORK = get_framework()


def _torch_accelerator():
    """Return 'cuda', 'mps' or 'cpu' for the best device torch reports.

    Callers must only invoke this when torch was imported successfully.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Older torch releases have no ``torch.backends.mps`` attribute at all.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def get_framework_backend():
    """
    Detect the specific hardware backend.

    Returns:
        'MLX', 'CUDA', 'MPS', or 'CPU'. The numpy fallback reports 'CPU'.
    """
    if FRAMEWORK == "mlx":
        return "MLX"
    if FRAMEWORK == "torch":
        return _torch_accelerator().upper()
    return "CPU"


BACKEND = get_framework_backend()


def get_device():
    """Return the framework-specific device object or string.

    Returns:
        The string 'mlx' for the MLX framework, a ``torch.device`` for the
        torch framework, and the string 'cpu' for the numpy fallback.
    """
    if FRAMEWORK == "mlx":
        return "mlx"
    if FRAMEWORK == "torch":
        return torch.device(_torch_accelerator())
    return "cpu"


DEVICE = get_device()


def get_torch_device():
    """Return the best available ``torch.device`` regardless of FRAMEWORK.

    Returns:
        A ``torch.device`` ('cuda', 'mps' or 'cpu'), or ``None`` when torch
        is not installed.
    """
    if not _HAS_TORCH:
        return None
    return torch.device(_torch_accelerator())


TORCH_DEVICE = get_torch_device()


def _resolve_torch_dtype(dtype):
    """Translate a dtype name or numpy dtype into the matching ``torch.dtype``.

    Anything that cannot be resolved is returned unchanged so torch reports
    the problem itself.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    if isinstance(dtype, str):
        name = dtype
    elif isinstance(dtype, np.dtype) or (
        isinstance(dtype, type) and issubclass(dtype, np.generic)
    ):
        name = np.dtype(dtype).name
    else:
        return dtype

    candidate = getattr(torch, name, None)
    # ``torch.<name>`` may be a module or function (e.g. 'cuda'); accept only
    # genuine dtype objects.
    return candidate if isinstance(candidate, torch.dtype) else dtype


def to_array(data, dtype=None):
    """
    Convert data to a framework-native array/tensor on the active device.

    Args:
        data: Array-like input. Torch tensors are accepted under the MLX
            framework and MLX arrays under the torch framework; both are
            converted through numpy.
        dtype: Optional target dtype. Under torch, dtype names such as
            'float32' or 'int64' and numpy dtypes are mapped to the
            corresponding ``torch.dtype``. Under MLX the value is passed to
            MLX as given.

    Returns:
        An ``mx.array``, a ``torch.Tensor`` on DEVICE, or an ``np.ndarray``,
        depending on FRAMEWORK.
    """
    if FRAMEWORK == "mlx":
        if isinstance(data, mx.array):
            return data if dtype is None else data.astype(dtype)
        # Handle torch tensors if passed to mlx to_array.
        if _HAS_TORCH and torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        return mx.array(data, dtype=dtype)

    if FRAMEWORK == "torch":
        if torch.is_tensor(data):
            res = data
        else:
            # Handle mlx arrays if passed to torch to_array.
            if _HAS_MLX and isinstance(data, mx.array):
                data = np.array(data)
            res = torch.as_tensor(data)

        # Cast before moving: some devices (e.g. MPS) reject float64, so a
        # float64 input must be narrowed while it is still on its source
        # device.
        if dtype is not None:
            res = res.to(_resolve_torch_dtype(dtype))
        return res.to(DEVICE)

    # Fallback to numpy if no backend framework is available.
    res = np.array(data)
    if dtype is not None:
        res = res.astype(dtype)
    return res


def to_device(data):
    """
    Move tensors or models to the active device.

    Torch tensors and modules are moved to TORCH_DEVICE whatever the global
    FRAMEWORK is. Dicts, lists and tuples (including namedtuples) are
    traversed recursively and rebuilt. Every other object, MLX and numpy
    arrays included, is returned unchanged.

    Args:
        data: A torch tensor/module, a (nested) container, or any other
            object.

    Returns:
        The moved object, a rebuilt container, or ``data`` itself.
    """
    # Guard on _HAS_TORCH: without it the name ``torch`` does not exist.
    if _HAS_TORCH and isinstance(data, (torch.Tensor, torch.nn.Module)):
        return data.to(TORCH_DEVICE)

    if isinstance(data, dict):
        return {k: to_device(v) for k, v in data.items()}

    if isinstance(data, (list, tuple)):
        moved = [to_device(v) for v in data]
        # namedtuple constructors take their fields positionally, not as a
        # single iterable.
        if hasattr(data, "_fields"):
            return type(data)(*moved)
        return type(data)(moved)

    return data


def to_framework_device(data):
    """Alias for to_device for backward compatibility."""
    return to_device(data)