""" MPS optimizations for Vortex model on Apple Silicon. Uses PyTorch MPS backend with MPS-compatible ops only. """ import torch import torch.nn as nn from typing import Optional, Dict, Any def optimize_for_mps( model: nn.Module, config: Dict, use_sdpa: bool = True, ) -> nn.Module: """ Apply MPS optimizations to model. Args: model: VortexModel config: Model config use_sdpa: Use PyTorch scaled dot product attention (MPS compatible) Returns: Optimized model """ device = torch.device("mps") # Move to MPS model = model.to(device) # Set dtype - MPS supports float32 and float16 (bfloat16 limited) dtype_str = config.get("dtype", "bfloat16") if dtype_str == "bfloat16": # MPS has limited bfloat16 support, use float16 dtype = torch.float16 else: dtype = torch.float32 model = model.to(dtype) # Replace Flash Attention with standard SDPA if use_sdpa: model = _apply_sdpa(model) print("Applied PyTorch SDPA for MPS") return model def _apply_sdpa(model: nn.Module) -> nn.Module: """ Replace custom attention with PyTorch SDPA. SDPA is optimized for MPS backend. """ for name, module in model.named_modules(): if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'): # Use the SDPA path original_forward = module.attn.forward def sdpa_forward(self, x, *args, **kwargs): return self._standard_attention(x, kwargs.get('attention_mask')) module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn)) return model def get_mps_memory_usage() -> Dict[str, float]: """Get current MPS memory usage in GB.""" if not torch.backends.mps.is_available(): return {"error": "MPS not available"} # MPS doesn't have direct memory query, use unified memory import psutil process = psutil.Process() memory_info = process.memory_info() return { "rss_gb": memory_info.rss / 1e9, # Resident set size "vms_gb": memory_info.vms / 1e9, # Virtual memory size } def profile_model_mps( model: nn.Module, input_ids: torch.Tensor, num_warmup: int = 10, num_runs: int = 50, ) -> Dict[str, float]: """ Profile model performance on MPS. Args: model: Model to profile input_ids: Example input num_warmup: Number of warmup runs num_runs: Number of profiling runs Returns: Dictionary with timing statistics """ model.eval() device = next(model.parameters()).device input_ids = input_ids.to(device) # Warmup with torch.no_grad(): for _ in range(num_warmup): _ = model(input_ids) # MPS is async, need to wait if device.type == "mps": torch.mps.synchronize() # Profile if device.type == "mps": torch.mps.synchronize() import time start = time.time() with torch.no_grad(): for _ in range(num_runs): _ = model(input_ids) if device.type == "mps": torch.mps.synchronize() elapsed = time.time() - start avg_time = elapsed / num_runs tokens_per_sec = input_ids.shape[1] / avg_time return { "avg_time_sec": avg_time, "tokens_per_sec": tokens_per_sec, } def test_mps_optimize(): """Test MPS optimizations.""" if not torch.backends.mps.is_available(): print("MPS not available, skipping test") return from models.vortex_model import VortexModel from configs.vortex_7b_config import VORTEX_7B_CONFIG config = VORTEX_7B_CONFIG.copy() config["d_model"] = 512 config["num_layers"] = 2 config["num_heads"] = 8 config["vocab_size"] = 1000 model = VortexModel(config) print(f"Model parameters: {model.get_num_params():,}") # Optimize for MPS model = optimize_for_mps(model, config, use_sdpa=True) # Test forward batch_size = 2 seq_len = 128 input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).to("mps") with torch.no_grad(): output = model(input_ids) logits = output["logits"] print(f"Output shape: {logits.shape}") print("MPS optimize test passed!") if __name__ == "__main__": test_mps_optimize()