# SciMLx_Production/core/device.py
# Moatasim Farooque
# Remove problematic files
# 54fa103
import os
import platform
import numpy as np
# Optional compute frameworks. Each flag records whether the package could be
# imported (and, for MLX, whether it could also build a trivial array).
try:
    import torch
except ImportError:
    _HAS_TORCH = False
else:
    _HAS_TORCH = True

try:
    import mlx.core as mx

    # Smoke test: an importable MLX may still fail when it first touches its
    # runtime, so allocate a tiny array before declaring it usable.
    _ = mx.array([1.0])
except (ImportError, RuntimeError, AttributeError):
    _HAS_MLX = False
else:
    _HAS_MLX = True
def get_framework():
    """
    Detect the high-level compute framework.

    Priority:
      1. SCIMLX_BACKEND environment variable ('torch' or 'mlx'; surrounding
         whitespace and letter case are ignored), if that framework is usable.
      2. 'mlx' on Apple Silicon (macOS / arm64) when MLX is usable.
      3. 'torch' when PyTorch is installed.
      4. 'numpy' as the last-resort fallback.

    A non-empty SCIMLX_BACKEND that names an unknown or unavailable framework
    emits a RuntimeWarning and falls through to auto-detection, instead of
    being ignored silently.

    Returns:
        One of 'mlx', 'torch' or 'numpy'.
    """
    import warnings

    requested = os.environ.get("SCIMLX_BACKEND", "").strip().lower()
    usable = {"mlx": _HAS_MLX, "torch": _HAS_TORCH}
    if requested:
        if usable.get(requested):
            return requested
        if requested in usable:
            reason = "it is not installed or failed to initialise"
        else:
            reason = "expected 'torch' or 'mlx'"
        warnings.warn(
            f"SCIMLX_BACKEND={requested!r} ignored ({reason}); "
            "falling back to auto-detection.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Auto-detection: MLX is preferred on Apple Silicon if available
    if _HAS_MLX and platform.system() == "Darwin" and platform.machine() == "arm64":
        return "mlx"
    return "torch" if _HAS_TORCH else "numpy"


FRAMEWORK = get_framework()
def get_framework_backend():
    """
    Report which hardware backend the active FRAMEWORK will run on.

    Returns:
        'MLX' for the MLX framework. For torch: 'CUDA' if available, else
        'MPS' if available, else 'CPU'. 'CPU' for any other framework
        (e.g. the numpy fallback).
    """
    if FRAMEWORK == "mlx":
        return "MLX"
    if FRAMEWORK != "torch":
        return "CPU"
    if torch.cuda.is_available():
        return "CUDA"
    # hasattr guard: presumably for PyTorch builds without torch.backends.mps.
    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "MPS" if mps_ready else "CPU"


BACKEND = get_framework_backend()
def get_device():
    """
    Return the device handle for the active FRAMEWORK.

    Returns:
        Under torch, a torch.device ('cuda', 'mps' or 'cpu') chosen from
        get_framework_backend(). Under MLX, the plain string 'mlx'. For the
        numpy fallback, the string 'cpu'.
    """
    if FRAMEWORK == "torch":
        backend_to_torch = {"CUDA": "cuda", "MPS": "mps"}
        return torch.device(backend_to_torch.get(get_framework_backend(), "cpu"))
    return "mlx" if FRAMEWORK == "mlx" else "cpu"


DEVICE = get_device()
def get_torch_device():
    """
    Pick the best available torch device, independent of the global FRAMEWORK.

    The preference order is CUDA, then MPS, then CPU.

    Returns:
        A torch.device, or None when PyTorch is not installed (callers must
        handle that case).
    """
    if not _HAS_TORCH:
        return None
    if torch.cuda.is_available():
        name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


TORCH_DEVICE = get_torch_device()
def _resolve_dtype_name(namespace, dtype_type, dtype):
    """Translate a dtype name such as 'float32' into ``namespace``'s dtype object.

    Non-string values (framework dtype objects, or None) pass through unchanged.
    ``dtype_type`` is the framework's dtype class; it rejects names that exist
    on the module but are not dtypes (e.g. 'tensor').

    Raises:
        ValueError: if ``dtype`` is a string that is not a dtype of ``namespace``.
    """
    if not isinstance(dtype, str):
        return dtype
    resolved = getattr(namespace, dtype, None)
    if not isinstance(resolved, dtype_type):
        raise ValueError(f"Unknown dtype name: {dtype!r}")
    return resolved


def to_array(data, dtype=None):
    """
    Convert data to the active framework's native array/tensor type.

    Args:
        data: Array-like input (list, numpy array, torch tensor, MLX array, ...).
            Objects from the other framework are converted through numpy.
        dtype: Optional target dtype. Either a framework dtype object or a
            name such as 'float32', 'float64' or 'int64'.

    Returns:
        An mx.array under MLX, a torch.Tensor on DEVICE under torch, and a
        numpy array otherwise.

    Raises:
        ValueError: if ``dtype`` is a string that does not name a dtype of the
            active framework (torch/MLX branches only).
    """
    if FRAMEWORK == "mlx":
        dtype = _resolve_dtype_name(mx, mx.Dtype, dtype)
        if isinstance(data, mx.array):
            return data if dtype is None else data.astype(dtype)
        # Handle torch tensors if passed to mlx to_array
        if _HAS_TORCH and torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        return mx.array(data, dtype=dtype)

    if FRAMEWORK == "torch":
        dtype = _resolve_dtype_name(torch, torch.dtype, dtype)
        if torch.is_tensor(data):
            # Move and cast in a single call; dtype=None keeps the current dtype.
            return data.to(device=DEVICE, dtype=dtype)
        # Handle mlx arrays if passed to torch to_array
        if _HAS_MLX and isinstance(data, mx.array):
            data = np.array(data)
        # Cast while creating the tensor, so the data never exists on DEVICE in
        # its source dtype (e.g. float64, which MPS does not support).
        return torch.as_tensor(data, dtype=dtype, device=DEVICE)

    # Fallback to numpy if no backend framework is available
    res = np.array(data)
    if dtype is not None:
        res = res.astype(dtype)
    return res
def to_device(data):
    """
    Move torch tensors/models, possibly nested in containers, to TORCH_DEVICE.

    Behaviour, independent of the global FRAMEWORK:
      - torch.Tensor / torch.nn.Module objects are moved with ``.to(TORCH_DEVICE)``.
        This also applies when FRAMEWORK is 'mlx'. Under torch, TORCH_DEVICE
        uses the same CUDA > MPS > CPU preference as DEVICE.
      - dict, list and tuple containers (namedtuples included) are rebuilt, with
        every element passed through to_device recursively. Dicts come back as
        plain dicts.
      - Anything else (MLX arrays, numpy arrays, scalars, ...) is returned
        unchanged. MLX arrays need no explicit move.

    The torch check is guarded by _HAS_TORCH, so this works without PyTorch.
    """
    if _HAS_TORCH and isinstance(data, (torch.Tensor, torch.nn.Module)):
        return data.to(TORCH_DEVICE)
    if isinstance(data, dict):
        return {k: to_device(v) for k, v in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple constructors take positional fields, not a single iterable.
        return type(data)(*(to_device(v) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(v) for v in data)
    return data
def to_framework_device(data):
    """Backward-compatible alias that forwards ``data`` to to_device() unchanged."""
    moved = to_device(data)
    return moved