# mps_optimize.py — MPS optimizations for the Vortex model (Vortex-13b-V1)
"""
MPS optimizations for Vortex model on Apple Silicon.
Uses PyTorch MPS backend with MPS-compatible ops only.
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """Move *model* to the MPS device and apply MPS-friendly settings.

    Args:
        model: VortexModel instance to optimize.
        config: Model config dict; ``config["dtype"]`` selects precision.
        use_sdpa: If True, swap the custom attention path for PyTorch
            scaled dot product attention (MPS compatible).

    Returns:
        The model, placed on MPS in an MPS-supported dtype.
    """
    mps_device = torch.device("mps")
    model = model.to(mps_device)
    # MPS has limited bfloat16 support, so a bfloat16 request is mapped to
    # float16; anything else falls back to float32.
    requested = config.get("dtype", "bfloat16")
    target_dtype = torch.float16 if requested == "bfloat16" else torch.float32
    model = model.to(target_dtype)
    # Replace the Flash-Attention-style path with standard SDPA if asked.
    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")
    return model
def _apply_sdpa(model: nn.Module) -> nn.Module:
"""
Replace custom attention with PyTorch SDPA.
SDPA is optimized for MPS backend.
"""
for name, module in model.named_modules():
if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
# Use the SDPA path
original_forward = module.attn.forward
def sdpa_forward(self, x, *args, **kwargs):
return self._standard_attention(x, kwargs.get('attention_mask'))
module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
return model
def get_mps_memory_usage() -> Dict[str, float]:
    """Report process memory usage in GB as a proxy for MPS usage.

    MPS exposes no direct allocator query; on Apple Silicon's unified memory,
    the process resident/virtual sizes are the best available estimate.

    Returns:
        Dict with ``rss_gb`` and ``vms_gb``, or an ``error`` entry when the
        MPS backend is unavailable.
    """
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}
    import psutil

    mem = psutil.Process().memory_info()
    bytes_per_gb = 1e9
    return {
        "rss_gb": mem.rss / bytes_per_gb,  # Resident set size
        "vms_gb": mem.vms / bytes_per_gb,  # Virtual memory size
    }
def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """
    Profile forward-pass performance on MPS (also runs on CPU/CUDA devices).

    Args:
        model: Model to profile; invoked as ``model(input_ids)``.
        input_ids: Example input of shape (batch, seq_len).
        num_warmup: Number of untimed warmup runs.
        num_runs: Number of timed runs.

    Returns:
        Dictionary with:
            - "avg_time_sec": mean wall time per forward pass.
            - "tokens_per_sec": throughput counting every token in the batch.

    Fixes vs. the previous version:
        - ``tokens_per_sec`` now counts batch * seq_len tokens; the old code
          divided only ``seq_len`` by the time, under-reporting throughput by
          a factor of the batch size.
        - Uses ``time.perf_counter()`` (monotonic, high resolution) instead
          of ``time.time()`` for interval timing.
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    def _sync() -> None:
        # MPS executes asynchronously; synchronize before reading the clock.
        if device.type == "mps":
            torch.mps.synchronize()

    # Warmup (untimed).
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
    _sync()

    # Timed runs.
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
    _sync()
    elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
    tokens_per_sec = batch_size * seq_len / avg_time
    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
def test_mps_optimize():
    """Smoke-test the MPS optimization path on a tiny Vortex model."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the model builds instantly.
    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 512
    config["num_layers"] = 2
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    model = optimize_for_mps(model, config, use_sdpa=True)

    # One forward pass to confirm the patched model still runs on MPS.
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).to("mps")
    with torch.no_grad():
        logits = model(input_ids)["logits"]
    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")
# Run the MPS smoke test when this module is executed as a script.
if __name__ == "__main__":
    test_mps_optimize()