File size: 4,589 Bytes
5c43f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""

MPS optimizations for Vortex model on Apple Silicon.

Uses PyTorch MPS backend with MPS-compatible ops only.

"""

import torch
import torch.nn as nn
from typing import Optional, Dict, Any


def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """Apply MPS optimizations to model.

    Moves the model to the MPS device, casts it to an MPS-friendly dtype,
    and optionally swaps the attention implementation for PyTorch SDPA.

    Args:
        model: VortexModel (any ``nn.Module`` works).
        config: Model config; ``config["dtype"]`` selects the parameter dtype.
        use_sdpa: Use PyTorch scaled dot product attention (MPS compatible).

    Returns:
        Optimized model (moved, cast, and patched in place).
    """
    device = torch.device("mps")

    # Move to MPS
    model = model.to(device)

    # Set dtype - MPS supports float32 and float16 well; bfloat16 support is
    # limited, so both half-precision requests map to float16. (Previously a
    # "float16" request incorrectly fell through to float32.)
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str in ("bfloat16", "float16"):
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = model.to(dtype)

    # Replace Flash Attention with standard SDPA (MPS-compatible path).
    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")

    return model


def _apply_sdpa(model: nn.Module) -> nn.Module:
    """

    Replace custom attention with PyTorch SDPA.

    SDPA is optimized for MPS backend.

    """
    for name, module in model.named_modules():
        if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
            # Use the SDPA path
            original_forward = module.attn.forward

            def sdpa_forward(self, x, *args, **kwargs):
                return self._standard_attention(x, kwargs.get('attention_mask'))

            module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))

    return model


def get_mps_memory_usage() -> Dict[str, float]:
    """Get current process memory usage in GB (proxy for MPS usage).

    MPS has no direct memory-query API; since Apple Silicon uses unified
    memory, the host process footprint is the closest available proxy.

    Returns:
        ``{"rss_gb": ..., "vms_gb": ...}`` on success, or an ``{"error": ...}``
        dict when MPS or psutil is unavailable.
    """
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}

    # psutil is an optional third-party dependency; report its absence the
    # same way as missing MPS instead of raising ImportError.
    try:
        import psutil
    except ImportError:
        return {"error": "psutil not installed"}

    memory_info = psutil.Process().memory_info()

    return {
        "rss_gb": memory_info.rss / 1e9,  # Resident set size
        "vms_gb": memory_info.vms / 1e9,  # Virtual memory size
    }


def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """Profile forward-pass performance of a model on MPS.

    Args:
        model: Model to profile.
        input_ids: Example input batch.
        num_warmup: Number of warmup runs (untimed).
        num_runs: Number of timed profiling runs.

    Returns:
        Dictionary with ``avg_time_sec`` and ``tokens_per_sec``.
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    on_mps = device.type == "mps"

    def _sync() -> None:
        # MPS executes asynchronously; block until queued kernels finish
        # so the wall-clock measurement is meaningful.
        if on_mps:
            torch.mps.synchronize()

    with torch.no_grad():
        # Warmup runs amortize one-time compilation/allocation costs.
        for _ in range(num_warmup):
            model(input_ids)
            _sync()

        _sync()
        start = time.time()
        for _ in range(num_runs):
            model(input_ids)
            _sync()
        elapsed = time.time() - start

    avg_time = elapsed / num_runs
    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": input_ids.shape[1] / avg_time,
    }


def test_mps_optimize():
    """Test MPS optimizations."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 512
    config["num_layers"] = 2
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize for MPS
    model = optimize_for_mps(model, config, use_sdpa=True)

    # Test forward
    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).to("mps")

    with torch.no_grad():
        output = model(input_ids)
        logits = output["logits"]

    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")


if __name__ == "__main__":
    test_mps_optimize()