File size: 15,620 Bytes
6f09125
 
 
 
 
f70ae43
8bc54c9
e3d287d
 
 
 
 
 
6f09125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3d287d
6f09125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3d287d
6f09125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import importlib

try:
    xm = importlib.import_module('torch_xla.core.xla_model')
    xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
except ImportError:
    xm = None
    xs = None


class Rotary3D(nn.Module):
    """Rotary position embedding driven by three position axes (x, y, t).

    The feature dimension ``dim`` is shared between the axes: 6/16 of it for
    x, 6/16 for y and the remainder for t. Each axis owns a bank of inverse
    frequencies (one per pair of features), registered as buffers
    ``inv_freq_x``, ``inv_freq_y`` and ``inv_freq_t``.
    """

    def __init__(self, dim, base=100):
        super().__init__()
        assert dim % 16 == 0, "Embedding dim must be divisible by 16"

        # Feature budget per axis; t receives whatever x and y leave over.
        spatial_dim = (6 * dim) // 16
        self.x_dim = spatial_dim
        self.y_dim = spatial_dim
        self.t_dim = dim - self.x_dim - self.y_dim

        # One inverse frequency per feature pair, registered in x, y, t order.
        for axis, width in (('x', self.x_dim), ('y', self.y_dim), ('t', self.t_dim)):
            exponents = torch.arange(0, width, 2).float() / width
            self.register_buffer(f'inv_freq_{axis}', 1.0 / (base ** exponents))

    def forward(self, x, pos):
        """
        Rotate the two halves of the last dimension of ``x`` by position-dependent angles.

        x: [batch, nh, seq_len, head_dim]
        pos: [batch, seq_len, 3] integer positions along (x, y, t)

        Returns a tensor with the same shape as ``x``.
        """
        _, _, _, hs = x.shape
        assert pos.shape[-1] == 3, "Position tensor must have shape [batch, seq_len, 3]"
        assert hs % 2 == 0, "head_dim (hs) must be divisible by 2 for rotary embedding."

        # Angles are computed in the dtype of the frequency buffers.
        freq_dtype = self.inv_freq_x.dtype
        banks = (self.inv_freq_x, self.inv_freq_y, self.inv_freq_t)
        per_axis_angles = [
            torch.einsum('bt,f -> btf', pos[..., axis].to(freq_dtype), bank)
            for axis, bank in enumerate(banks)
        ]
        angles = torch.cat(per_axis_angles, dim=-1)

        # Insert a head axis so the tables broadcast over all heads.
        cos_emb = angles.cos().unsqueeze(1)
        sin_emb = angles.sin().unsqueeze(1)

        # Feature i in the first half is paired with feature i in the second half.
        half = hs // 2
        first, second = x[..., :half], x[..., half:]
        rotated_first = first * cos_emb - second * sin_emb
        rotated_second = first * sin_emb + second * cos_emb

        return torch.cat([rotated_first, rotated_second], dim=-1)


class PSIAttentionLayer(nn.Module):

    def __init__(self, config):

        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # positional embedding
        self.rope = Rotary3D(config.n_embd // config.n_head)

        # check if we are using causal attention
        if config.attention_mask == "causal":
            self.is_causal = True
        else:
            self.is_causal = False

        # check if GPU Flash Attention is available
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # check if we are running on TPU
        try:
            # Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
            xm_local = importlib.import_module('torch_xla.core.xla_model')
            self.tpu = True
        except ImportError:
            self.tpu = False

        # Apply XLA sharding for model parallelism
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

    @torch.compiler.disable
    def emplace_kv(self, T, k_cache, v_cache, k, v):
        # torch.compile doesn't play well with this op (5x slowdown)
        # so we insert a graph break and copy eagerly
        k_cache[:,:,-T:].copy_(k)
        v_cache[:,:,-T:].copy_(v)
        return k_cache, v_cache

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, mask=None):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # Apply rotary positional embedding
        k = self.rope(k, pos)
        q = self.rope(q, pos)
        
        if inplace_kv and k_cache is not None and v_cache is not None:
            # assign into kv cache in-place
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            # append cached keys and values with new keys and values
            if k_cache is not None:
                k = torch.cat((k_cache, k), dim=2)
            if v_cache is not None:
                v = torch.cat((v_cache, v), dim=2)

        # Apply attention
        if self.tpu:
            # (1)
            flash_attention = importlib.import_module('torch_xla.experimental.custom_kernel.flash_attention')
            q_norm = q / math.sqrt(k.size(-1))
            y = flash_attention(
                q_norm, k, v, 
                causal=True, partition_spec=('fsdp', None, None, None))
            # (2)
            # y = torch.nn.functional.scaled_dot_product_attention(
            #     q, k, v,
            #     # dropout_p=self.dropout if self.training else 0,
            #     # attn_mask=None if mask is None else mask.to(q.dtype),
            #     is_causal=True
            # )
        elif self.flash:
            # efficient attention using Flash Attention CUDA kernels
            L, S = q.size(-2), k.size(-2)
            is_causal = self.is_causal and mask is None
            # is_causal doesn't work when not square, so replace with a manual mask if needed
            if is_causal and L < S:
                if L > 1:   # if L=1, just use no mask
                    mask = torch.ones(L, S, dtype=q.dtype, device=q.device)
                    mask.masked_fill_(mask.to(torch.bool).triu(S-L+1), float('-inf'))
                is_causal = False

            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0,
                attn_mask=None if mask is None else mask.to(q.dtype),
                is_causal=is_causal
            )
        else:
            # manual implementation of attention
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # apply mask, or use causal if default
            if mask is not None:
                att = att + mask
            elif self.is_causal:
                L, S = q.size(-2), k.size(-2)
                mask = torch.ones(1, 1, L, S).triu(S-L+1).to(dtype=torch.bool).to(x.device)
                att.masked_fill_(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per llama implementation
            att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            att = self.attn_dropout(att)
            # multiply attention weights with values to get output
            y = torch.einsum('bnsk,bnkh->bnsh', att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        # return key and value caches if requested
        if return_kv:
            return y, k, v

        return y

    def kv_cache_forward(self, x, pos, k_cache=None, v_cache=None):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # Apply rotary positional embedding (before concat)
        k = self.rope(k, pos)
        q = self.rope(q, pos)

        # append cached keys and values with new keys and values
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        return y, k, v


class MLP(nn.Module):
    """Feed-forward block: Linear (n_embd -> 4*n_embd), GELU, Linear back, dropout.

    When torch_xla is importable, an XLA device is reachable and a global SPMD
    mesh has been set, the two linear layers are annotated for model-parallel
    sharding at construction time.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Apply XLA sharding for model parallelism
        mesh = self._global_spmd_mesh()
        if mesh is not None and mesh.mesh_shape[1] > 1:  # If the 'model' axis has size > 1
            xs.mark_sharding(self.c_fc.weight, mesh, (1, 0))
            if self.c_fc.bias is not None:
                xs.mark_sharding(self.c_fc.bias, mesh, (1,))
            print(f"MLP: Applied MP sharding to c_fc {mesh.mesh_shape} spec weight(1,0), bias(1,)")

            xs.mark_sharding(self.c_proj.weight, mesh, (0, 1))
            if self.c_proj.bias is not None:
                xs.mark_sharding(self.c_proj.bias, mesh, (0,))
            print(f"MLP: Applied MP sharding to c_proj {mesh.mesh_shape} spec weight(0,1), bias(0,)")

    @staticmethod
    def _global_spmd_mesh():
        """Return the global SPMD mesh, or None when XLA/SPMD is not usable."""
        if xm is None or xs is None:
            return None
        try:
            if xm.xla_device_kind() is None:
                return None
        except RuntimeError:
            # no XLA device could be queried
            return None
        # NOTE(review): the torch_xla SPMD module is believed to expose this
        # accessor as `get_global_mesh`; the spelling used by the original
        # code (`global_mesh`) is kept as a fallback -- confirm against the
        # installed torch_xla version.
        accessor = getattr(xs, 'get_global_mesh', None)
        if accessor is None:
            accessor = xs.global_mesh
        return accessor()

    def forward(self, x, spmd_mesh=None):
        """Run the feed-forward block.

        Args:
            x: input whose last dimension is n_embd.
            spmd_mesh: optional SPMD mesh; when given, the hidden activation
                and the output are annotated with the partition spec
                (('dcn', 'data'), None, 'model'), so the mesh must define
                those axis names and torch_xla must be available.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.c_fc(x)
        x = self.gelu(x)

        if spmd_mesh is not None:
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        x = self.c_proj(x)
        x = self.dropout(x)

        if spmd_mesh is not None:
            xs.mark_sharding(x, spmd_mesh, (('dcn', 'data'), None, 'model'))

        return x
    

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    ``weight`` controls whether a learnable per-feature scale (initialised to
    ones) is created. ``bias`` is accepted for signature compatibility but is
    not used: no bias term is created or applied.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-5): # whl
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (eps added under the root)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then apply the scale if present."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight


class PSIBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, bias=config.bias)
        self.attn = PSIAttentionLayer(config)
        self.ln_2 = RMSNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, pos, k_cache=None, v_cache=None, return_kv=False, inplace_kv=False, spmd_mesh=None, mask=None):
        """Run the block on ``x``.

        The key/value caches, ``inplace_kv`` and ``mask`` are handed to the
        attention layer unchanged. With ``return_kv`` set, the result is
        ``(x, k, v)`` where ``k`` and ``v`` come from the attention layer;
        otherwise only ``x`` is returned.

        NOTE(review): ``spmd_mesh`` is accepted but not forwarded to the MLP --
        confirm whether that is intentional.
        """
        attn_result = self.attn(
            self.ln_1(x), pos,
            k_cache=k_cache, v_cache=v_cache,
            return_kv=bool(return_kv), inplace_kv=inplace_kv, mask=mask,
        )

        # The attention layer returns a (y, k, v) triple only when asked for the caches.
        if return_kv:
            attn_out, k, v = attn_result
        else:
            attn_out = attn_result

        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln_2(hidden))

        if return_kv:
            return hidden, k, v
        return hidden


class PartitionedEmbedding(nn.Module):
    """Embedding table stored as several consecutive ``nn.Embedding`` partitions.

    Ids ``[i * partition_size, (i + 1) * partition_size)`` live in partition
    ``i``; the last partition holds the remainder and may be smaller.
    """

    def __init__(self, num_embeddings, embedding_dim, partition_size=65536):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.partition_size = partition_size
        # ceil(num_embeddings / partition_size)
        self.num_partitions = (num_embeddings + partition_size - 1) // partition_size

        self.embedding_layers = nn.ModuleList()
        for first_id in range(0, self.num_partitions * self.partition_size, self.partition_size):
            rows = min(first_id + self.partition_size, num_embeddings) - first_id
            self.embedding_layers.append(nn.Embedding(rows, embedding_dim))

    def forward(self, input_ids):
        """Look up ``input_ids`` and return embeddings of shape ``input_ids.shape + (embedding_dim,)``.

        Positions whose id does not fall into any partition keep the zero
        vector the output is initialised with.
        """
        which_partition = input_ids // self.partition_size
        row_in_partition = input_ids % self.partition_size

        result = torch.zeros(
            *input_ids.shape, self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_layers[0].weight.dtype,
        )

        for index in range(self.num_partitions):
            selected = which_partition == index
            if not selected.any():
                # nothing in this batch maps to the partition
                continue
            table = self.embedding_layers[index]
            result[selected] = table(row_in_partition[selected])

        return result


class PartitionedLinear(nn.Module):
    """Linear layer whose output features are split across several ``nn.Linear`` partitions.

    Each partition produces up to ``partition_size`` consecutive output
    features; the last one produces the remainder.
    """

    def __init__(self, in_features, out_features, partition_size=65536, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.partition_size = partition_size
        # ceil(out_features / partition_size)
        self.num_partitions = (out_features + partition_size - 1) // partition_size

        self.linear_layers = nn.ModuleList()
        for first_col in range(0, self.num_partitions * self.partition_size, self.partition_size):
            width = min(first_col + self.partition_size, out_features) - first_col
            self.linear_layers.append(nn.Linear(in_features, width, bias=bias))

    def forward(self, input):
        """Apply every partition to ``input`` and concatenate the results on the last axis."""
        pieces = []
        for partition in self.linear_layers:
            pieces.append(partition(input))
        return torch.cat(pieces, dim=-1)