"""
MPS optimizations for Vortex model on Apple Silicon.
Uses PyTorch MPS backend with MPS-compatible ops only.
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """
    Apply MPS optimizations to model.

    Moves the model to the MPS device, selects an MPS-friendly dtype,
    and optionally swaps the attention path for PyTorch SDPA.

    Args:
        model: VortexModel (any nn.Module works)
        config: Model config; ``config["dtype"]`` selects precision.
            Both "bfloat16" and "float16" map to float16 (MPS has only
            limited bfloat16 support); anything else uses float32.
        use_sdpa: Use PyTorch scaled dot product attention (MPS compatible)

    Returns:
        Optimized model on the "mps" device.

    Raises:
        RuntimeError: If the MPS backend is not available.
    """
    # Fail fast with a clear message instead of an opaque .to() error.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS backend is not available")
    device = torch.device("mps")
    model = model.to(device)
    # Map the requested dtype to one MPS supports well: float16 for both
    # half-precision requests ("float16" previously fell through to
    # float32), float32 otherwise.
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str in ("bfloat16", "float16"):
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = model.to(dtype)
    # Replace Flash Attention with standard SDPA (Flash kernels are not
    # available on MPS).
    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")
    return model
def _apply_sdpa(model: nn.Module) -> nn.Module:
"""
Replace custom attention with PyTorch SDPA.
SDPA is optimized for MPS backend.
"""
for name, module in model.named_modules():
if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
# Use the SDPA path
original_forward = module.attn.forward
def sdpa_forward(self, x, *args, **kwargs):
return self._standard_attention(x, kwargs.get('attention_mask'))
module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
return model
def get_mps_memory_usage() -> Dict[str, float]:
    """Report current process memory usage in GB (proxy for MPS usage)."""
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}
    # MPS has no direct device-memory query; since Apple Silicon uses
    # unified memory, the process footprint is the closest proxy.
    import psutil
    mem = psutil.Process().memory_info()
    gigabyte = 1e9
    return {
        "rss_gb": mem.rss / gigabyte,  # Resident set size
        "vms_gb": mem.vms / gigabyte,  # Virtual memory size
    }
def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """
    Profile model forward-pass performance (MPS or any other device).

    Args:
        model: Model to profile
        input_ids: Example input of shape (batch, seq_len)
        num_warmup: Number of warmup runs (not timed)
        num_runs: Number of timed profiling runs

    Returns:
        Dictionary with timing statistics:
            avg_time_sec: mean wall-clock seconds per forward pass
            tokens_per_sec: throughput counting every token in the batch
                (batch_size * seq_len tokens per pass)
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    # Warmup (not timed)
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
    # MPS executes asynchronously: drain pending warmup work before
    # starting the clock, and again before stopping it.
    if device.type == "mps":
        torch.mps.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
    if device.type == "mps":
        torch.mps.synchronize()
    elapsed = time.time() - start
    avg_time = elapsed / num_runs
    # Count all tokens processed per pass: batch_size * seq_len (the
    # previous version used only seq_len, understating throughput by a
    # factor of the batch size).
    tokens_per_pass = input_ids.shape[0] * input_ids.shape[1]
    tokens_per_sec = tokens_per_pass / avg_time
    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
def test_mps_optimize():
    """Smoke-test the MPS optimization path on a tiny model."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config to something that builds instantly.
    config = VORTEX_7B_CONFIG.copy()
    config.update(d_model=512, num_layers=2, num_heads=8, vocab_size=1000)

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize for MPS
    model = optimize_for_mps(model, config, use_sdpa=True)

    # One forward pass on a random (2, 128) batch.
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).to("mps")
    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")
# Run the smoke test when executed directly (requires Apple Silicon / MPS).
if __name__ == "__main__":
    test_mps_optimize()