File size: 6,449 Bytes
f24563f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Embedding layers for the LLM model.
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Tuple, Callable
import math


class TokenEmbedding(nn.Module):
    """
    Learned lookup table mapping integer token IDs to dense vectors.

    Attributes:
        vocab_size: Number of entries in the vocabulary.
        embed_dim: Width of each embedding vector.
        dtype: Numeric dtype of the embedding table.
    """
    vocab_size: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Table of shape [vocab_size, embed_dim], initialized N(0, 0.02^2).
        self.embedding = self.param(
            'embedding',
            nn.initializers.normal(stddev=0.02),
            (self.vocab_size, self.embed_dim),
            self.dtype,
        )

    def __call__(self, input_ids: jnp.ndarray) -> jnp.ndarray:
        """
        Look up the embedding row for each token ID.

        Args:
            input_ids: Integer token IDs [batch_size, seq_len]

        Returns:
            Token embeddings [batch_size, seq_len, embed_dim]
        """
        # Row gather along axis 0; equivalent to jnp.take(..., axis=0).
        return self.embedding[input_ids]


class PositionalEmbedding(nn.Module):
    """
    Learned absolute positional embedding.

    Attributes:
        max_seq_len: Largest position index supported (table height).
        embed_dim: Width of each embedding vector.
        dtype: Numeric dtype of the embedding table.
    """
    max_seq_len: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Table of shape [max_seq_len, embed_dim], initialized N(0, 0.02^2).
        self.embedding = self.param(
            'embedding',
            nn.initializers.normal(stddev=0.02),
            (self.max_seq_len, self.embed_dim),
            self.dtype,
        )

    def __call__(self, positions: jnp.ndarray) -> jnp.ndarray:
        """
        Look up the embedding row for each position index.

        Args:
            positions: Position indices [batch_size, seq_len]

        Returns:
            Positional embeddings [batch_size, seq_len, embed_dim]
        """
        # Row gather along axis 0; equivalent to jnp.take(..., axis=0).
        return self.embedding[positions]


class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Positional Embedding (RoPE) with support for long sequences.

    Attributes:
        dim: Dimension of the embedding (must be even; rotation acts on pairs)
        max_seq_len: Maximum sequence length
        base: Base for frequency computation
        scale: Scaling factor for RoPE frequencies (for longer contexts)
        dtype: Data type for embeddings
        use_dynamic_scaling: Whether to use dynamic scaling for longer contexts
        original_max_seq_len: Original max sequence length for scaling
    """
    dim: int
    max_seq_len: int = 32768  # Increased to support longer contexts
    base: int = 10000
    scale: float = 1.0  # Scaling factor for RoPE frequencies
    dtype: jnp.dtype = jnp.float32
    use_dynamic_scaling: bool = True  # Enable dynamic scaling for longer contexts
    original_max_seq_len: int = 4096  # Original max sequence length for scaling

    def setup(self):
        # NTK-aware base adjustment for contexts longer than the length the
        # base frequencies were tuned for.  With the default scale of 1.0
        # this is a no-op (1.0 ** anything == 1.0).
        effective_base = self.base
        if self.use_dynamic_scaling and self.max_seq_len > self.original_max_seq_len:
            scaling_factor = math.log(self.max_seq_len / self.original_max_seq_len) / math.log(2)
            effective_base = self.base * (self.scale ** scaling_factor)

        # Inverse frequency per dimension pair: shape [dim // 2].
        freqs = effective_base ** (-jnp.arange(0, self.dim, 2) / self.dim)
        pos = jnp.arange(self.max_seq_len)
        # Angle table (position x frequency): [max_seq_len, dim // 2].
        angles = jnp.outer(pos, freqs).astype(self.dtype)
        # Cache half-width cos/sin tables; __call__ duplicates them to the
        # full feature width when applying the rotation.
        self.cos_cached = jnp.cos(angles)
        self.sin_cached = jnp.sin(angles)

    def _rotate_half(self, x: jnp.ndarray) -> jnp.ndarray:
        """Map the (x1, x2) halves of the last axis to (-x2, x1)."""
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    def __call__(self, x: jnp.ndarray, positions: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """
        Apply rotary positional embedding.

        Args:
            x: Input tensor [batch_size, seq_len, ..., dim]; the feature
                axis is last, and any number of axes (e.g. a heads axis)
                may sit between seq_len and dim.
            positions: Optional position indices of shape [seq_len] or
                [batch_size, seq_len]. Defaults to 0..seq_len-1.

        Returns:
            Tensor of the same shape as ``x`` with the rotation applied.
        """
        seq_len = x.shape[1]

        if positions is None:
            positions = jnp.arange(seq_len)

        # Keep indices inside the cached table.
        positions = jnp.clip(positions, 0, self.max_seq_len - 1)

        # Gather per-position angles: (*positions.shape, dim // 2).
        # FIX: the previous `[:, :seq_len]` slice wrongly truncated the
        # *frequency* axis for 1-D positions; the gathered position axis
        # already has length seq_len, so no slicing is needed.
        cos = jnp.take(self.cos_cached, positions, axis=0)
        sin = jnp.take(self.sin_cached, positions, axis=0)

        # FIX: duplicate the half-width tables to the full feature width so
        # they line up elementwise with _rotate_half's (-x2, x1) layout;
        # without this, `x * cos` cannot broadcast (dim vs dim // 2).
        cos = jnp.concatenate([cos, cos], axis=-1)
        sin = jnp.concatenate([sin, sin], axis=-1)

        # FIX: insert singleton axes *between* seq_len and dim (the original
        # appended them after dim, misaligning the feature axis). Right-
        # alignment then supplies the batch axis when positions are 1-D.
        pad = (1,) * (x.ndim - 3)
        cos = cos.reshape(cos.shape[:-1] + pad + (cos.shape[-1],))
        sin = sin.reshape(sin.shape[:-1] + pad + (sin.shape[-1],))

        # Rotation: x * cos + rotate_half(x) * sin.
        return x * cos + self._rotate_half(x) * sin


def get_rope_embedding(
    dim: int,
    max_seq_len: int = 32768,
    base: int = 10000,
    scale: float = 1.0,
    use_dynamic_scaling: bool = True,
    original_max_seq_len: int = 4096,
    dtype: jnp.dtype = jnp.float32
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Build rotary positional embedding (RoPE) cos/sin lookup tables.

    Args:
        dim: Feature dimension the rotation acts on.
        max_seq_len: Number of positions to tabulate (default: 32768 for
            long context).
        base: Base of the geometric frequency progression.
        scale: Scaling factor applied under dynamic scaling.
        use_dynamic_scaling: Apply NTK-aware base adjustment when
            max_seq_len exceeds original_max_seq_len.
        original_max_seq_len: Context length the base frequencies target.
        dtype: Dtype of the returned tables.

    Returns:
        Tuple (cos, sin), each of shape [max_seq_len, dim // 2].
    """
    # NTK-aware adjustment: grow the base when tabulating more positions
    # than the frequencies were originally tuned for. No-op at scale=1.0.
    adjusted_base = base
    if use_dynamic_scaling and max_seq_len > original_max_seq_len:
        exponent = math.log(max_seq_len / original_max_seq_len) / math.log(2)
        adjusted_base = base * (scale ** exponent)

    # Inverse frequency per dimension pair: shape [dim // 2].
    inv_freq = adjusted_base ** (-jnp.arange(0, dim, 2) / dim)
    # Angle table via broadcasting (position x frequency): [max_seq_len, dim // 2].
    angles = (jnp.arange(max_seq_len)[:, None] * inv_freq[None, :]).astype(dtype)
    return jnp.cos(angles), jnp.sin(angles)