Kernels
File size: 4,087 Bytes
3261444
 
 
 
35894d1
3261444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d65066c
 
 
 
3261444
b0f46c7
d65066c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3261444
 
 
 
d65066c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3261444
d65066c
 
 
3261444
 
 
 
 
 
 
 
d65066c
 
3261444
d65066c
3261444
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import logging

import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

# Configure the root logger to emit INFO records so the per-case results
# logged below are visible when this script runs directly.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)


def load_model(fsdp: bool) -> torch.nn.Module:
    """Load the Motif-2.6B causal LM in bfloat16 on CUDA and attach random grads.

    Args:
        fsdp: When True, each module in ``model.model.layers`` is wrapped with
            ``fully_shard`` and then the root model is wrapped too, resharding
            after each wrap.

    Returns:
        The model with ``.grad`` populated on every parameter. For parameters
        whose ``.data`` is a DTensor the gradient is a DTensor with the same
        mesh and placements; otherwise it is a plain tensor.
    """
    checkpoint = "Motif-Technologies/Motif-2.6B"
    net = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Draw gradients BEFORE sharding so each one has the full, unsharded
    # shape of its parameter.
    full_grads = [
        torch.randn_like(weight, device=weight.device, dtype=weight.dtype)
        for weight in net.parameters()
    ]

    if fsdp:
        for block in net.model.layers:
            fully_shard(block)
            block.reshard()
        fully_shard(net)
        net.reshard()

    for idx, weight in enumerate(net.parameters()):
        grad = full_grads[idx]
        local = weight.data
        if not isinstance(local, DTensor):
            weight.grad = grad
            continue
        # Declare the full gradient as Replicate on the parameter's mesh, then
        # redistribute it to match the parameter's own placements. Treating
        # the local tensor as Replicate assumes every rank drew identical
        # values (i.e. callers seed the RNG identically) — TODO confirm for
        # any new caller.
        mesh = local.device_mesh
        replicated = DTensor.from_local(
            grad,
            device_mesh=mesh,
            placements=[Replicate()] * mesh.ndim,
        )
        weight.grad = replicated.redistribute(device_mesh=mesh,
                                              placements=local.placements)

    return net


def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module:
    """Seed the RNGs, build the test model, and apply one Muon optimizer step.

    Args:
        fsdp: Forwarded to ``load_model``; selects the FSDP-sharded variant.
        qk_clip: When True, random per-head logits are passed to
            ``optim.step`` as ``qk_logits``; otherwise ``qk_logits`` is None.
        seed: Seed for ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.

    Returns:
        The model after the optimizer step.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model = load_model(fsdp=fsdp)
    param_groups = get_default_muon_param_groups(model)

    config = model.config
    num_heads = config.num_attention_heads

    # One vector of ``num_heads`` values from ``torch.rand`` per layer index.
    qk_logits = None
    if qk_clip:
        qk_logits = {}
        for layer_idx in range(config.num_hidden_layers):
            qk_logits[layer_idx] = torch.rand(num_heads)

    # Every head index is listed for both q and k. The meaning of
    # ``threshold`` is defined by the Muon implementation — not visible here.
    head_dim = config.hidden_size // num_heads
    optim = Muon(
        params=param_groups,
        clip_config=dict(
            q_indices=list(range(num_heads)),
            k_indices=list(range(num_heads)),
            head_dim=head_dim,
            threshold=0.5,
        ),
    )
    optim.step(qk_logits=qk_logits)

    return model


def run_case(qk_clip: bool, seed: int = 0):
    """Run Muon sharded and unsharded with the same seed and compare results.

    Args:
        qk_clip: Forwarded to both ``run_muon`` calls.
        seed: Forwarded to both ``run_muon`` calls.

    Returns:
        A ``(success, label)`` pair: ``success`` is whatever
        ``compare_results`` returns, and ``label`` is ``"qk_clip=ON"`` or
        ``"qk_clip=OFF"``.
    """
    parallel = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed)
    sequential = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed)
    mode = "ON" if qk_clip else "OFF"
    label = f"qk_clip={mode}"
    return compare_results(parallel, sequential, label=label), label


def test_muon():
    """Check parallel (FSDP) Muon against sequential Muon, QK-clip off and on.

    Logs the outcome of every case, then raises AssertionError naming each
    case whose models differ.
    """
    results = [
        run_case(qk_clip=False, seed=0),
        run_case(qk_clip=True, seed=0),
    ]

    failed = []
    for success, label in results:
        if success:
            logger.info(f"[{label}] Models match")
        else:
            logger.error(f"[{label}] Models differ")
            failed.append(label)

    # Without this assertion a mismatch was only implied by missing log
    # output, so the test could never fail.
    assert not failed, f"Parallel and sequential Muon differ for: {failed}"


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module,
                    label: str) -> None:
    success = True
    for (name_p, p), (name_s,
                      s) in zip(parallel_muon_result.named_parameters(),
                                sequential_muon_result.named_parameters()):
        p = p.data.full_tensor()
        s = s.data
        # Parallel Muon should exactly match Sequential Muon
        if torch.abs(p - s).max() > 0:
            max_diff_index = torch.argmax(torch.abs(p - s))
            logger.info(f"Models differ at parameter {name_p}")
            success = False

    return success


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    try:
        # Bind each process to one GPU. rank % device_count assumes every node
        # has the same number of GPUs and ranks are laid out contiguously per
        # node — TODO confirm for multi-node launches.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        test_muon()
    finally:
        # Release the process group even when the test raises; PyTorch
        # documents destroy_process_group() as the way to clean it up.
        dist.destroy_process_group()