File size: 21,306 Bytes
1b703d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
from __future__ import annotations

import math
from collections.abc import Callable

import torch
from torch import nn


class Rope1D(nn.Module):
    """
    Rotary Position Embedding (RoPE) 1D.

    Based on the reference LLaMA implementation (Hugging Face
    `modeling_llama.py`), adapted to this codebase without behavior changes.

    - dim: per-head dimension
    - max_position_embeddings: length used to precompute cached cos/sin (not required
      by forward)
    - base: RoPE base theta

    Forward expects:
      - x: (B, H, T, D)
      - position_ids: (B, T) integer positions
    Returns:
      - cos, sin: (B, T, D)
    """

    inv_freq: torch.Tensor
    _cos_cached: torch.Tensor
    _sin_cached: torch.Tensor

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000.0,
        device: torch.device | None = None,
        scaling_factor: float = 1.0,
    ) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise AssertionError("head_dim must be even for RoPE")
        self.scaling_factor: float = float(scaling_factor)
        self.dim: int = int(dim)
        self.max_position_embeddings: int = int(max_position_embeddings)
        self.base: float = float(base)
        inv_freq = self._build_inv_freq(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Cached cos/sin (not used in application, but kept for parity with reference)
        self.max_seq_len_cached: int = self.max_position_embeddings
        cos_cached, sin_cached = self._build_cached_trig(device=device)
        self.register_buffer("_cos_cached", cos_cached, persistent=False)
        self.register_buffer("_sin_cached", sin_cached, persistent=False)

    def _build_inv_freq(self, *, device: torch.device | None) -> torch.Tensor:
        """Return the RoPE inverse-frequency vector in float32."""

        return 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / float(self.dim)
            )
        )

    def _build_cached_trig(
        self, *, device: torch.device | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached RoPE trig tensors in float32."""

        inv_freq = self._build_inv_freq(device=device)
        t = torch.arange(
            self.max_seq_len_cached,
            device=device,
            dtype=torch.float32,
        )
        t = t / self.scaling_factor
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def _apply(
        self,
        fn: Callable[[torch.Tensor], torch.Tensor],
        recurse: bool = True,
    ) -> Rope1D:
        """Apply module moves/casts while preserving fp32 RoPE buffers."""

        out = super()._apply(fn, recurse=recurse)
        with torch.no_grad():
            device = self.inv_freq.device
            self.inv_freq.data = self._build_inv_freq(device=device)
            cos_cached, sin_cached = self._build_cached_trig(device=device)
            self._cos_cached.data = cos_cached
            self._sin_cached.data = sin_cached
        return out

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        inv_freq_tensor = self._build_inv_freq(device=x.device)
        inv_freq_expanded = (
            inv_freq_tensor[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float() / self.scaling_factor
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``; for an odd width the second
    half is the longer one.
    """
    width = x.shape[-1]
    first, second = torch.split(x, [width // 2, width - width // 2], dim=-1)
    return torch.cat((second.neg(), first), dim=-1)


def rotate_half_adjacent(x: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive pairs in the last dimension: ``(a, b) -> (-b, a)``.

    This matches the common EVA-02 / SpeedrunDiT RoPE convention where the last
    dimension is interpreted as pairs ``(x0, x1), (x2, x3), ...``.

    Raises:
        ValueError: If the last dimension is odd.
    """
    width = x.shape[-1]
    if width % 2 != 0:
        raise ValueError("rotate_half_adjacent requires an even last dimension")
    # (..., width) -> (..., width/2, 2), then split into even/odd members.
    even, odd = x.unflatten(-1, (width // 2, 2)).unbind(dim=-1)
    return torch.stack((odd.neg(), even), dim=-1).flatten(start_dim=-2)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    *,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply half-split rotary embeddings to ``q`` and ``k``.

    ``cos``/``sin`` receive a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q``/``k``. For example, ``(B, T, D)`` tables become
    ``(B, 1, T, D)`` against ``(B, H, T, D)`` inputs with the default
    ``unsqueeze_dim=1``.

    Returns:
        ``(q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class LearnableRoPE2D(nn.Module):
    r"""
    Learnable mixed 2D RoPE with axial RoPE2D-compatible initialization.

    - Learnable frequency banks for X and Y.
    - Frequencies can be shared across groups of attention heads (see
      ``rope_param_dim``).
    - Angle per pair: theta = x * fx[g, i] + y * fy[g, i]
    - Initialization matches the axial RoPE2D parameterization used by DiTTrunk
      for ``ROPE_2D_AXIAL_FREQ_AWARE`` (AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)):
        - Angle multiplier ``2π``.
        - Period base ``100`` (DINOv3-style), applied per-axis.
      Each head group starts identically (deterministic init) so the learnable
      variant is functionally identical to axial RoPE2D at step 0.
    - Rotation is implemented with real-valued sin/cos to avoid complex tensors
      (torch.compile/inductor cannot codegen complex dtypes).

    Shapes:
    - Expects q,k of shape (B, H, T, D) with D % 4 == 0.
    - Positions xy: (T, 2) or (B, T, 2), any real dtype (cast to float32).
    - Parameter `freqs`: (2, G, D//2) in float32; index 0 = x, 1 = y.

    Head grouping / parameter budget
    -------------------------------
    ``rope_param_dim`` controls the total number of learned RoPE frequency
    parameters (scalars) for this module.

    Let:
      - ``head_dim = D`` (per-head width)
      - ``num_heads = H``
      - ``rope_param_dim = P``

    Then the module uses:
      - ``num_groups = G = P // D``
      - ``heads_per_group = H // G``

    This is fail-fast: ``P`` must be divisible by ``D`` and ``H`` must be
    divisible by ``G``. When ``rope_param_dim`` is None (default), the module
    uses the classic per-head parameterization with ``P = H * D``.
    """

    def __init__(
        self,
        head_dim: int,
        *,
        num_heads: int,
        rope_param_dim: int | None = None,
        rope_base: float = 100.0,
        angle_multiplier: float = 2.0 * float(math.pi),
        learnable: bool = True,
        persist_buffers: bool = True,
    ) -> None:
        super().__init__()
        if head_dim % 4 != 0:
            raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
        self.head_dim: int = int(head_dim)
        # Avoid naming collisions with nn.Module.half() (dtype casting helper).
        self.half_dim: int = self.head_dim // 2
        self.num_heads: int = int(num_heads)
        effective_param_dim = (
            int(rope_param_dim)
            if rope_param_dim is not None
            else self.num_heads * self.head_dim
        )
        if effective_param_dim <= 0:
            raise ValueError("rope_param_dim must be positive for LearnableRoPE2D")
        self.rope_param_dim: int = int(effective_param_dim)
        self._learnable: bool = bool(learnable)
        theta = float(rope_base)
        mult = float(angle_multiplier)
        if not math.isfinite(theta) or theta <= 0.0:
            raise ValueError("rope_base must be finite and > 0 for LearnableRoPE2D")
        if not math.isfinite(mult) or mult <= 0.0:
            raise ValueError(
                "angle_multiplier must be finite and > 0 for LearnableRoPE2D"
            )

        if self.rope_param_dim % self.head_dim != 0:
            raise ValueError(
                "rope_param_dim must be divisible by head_dim for LearnableRoPE2D "
                f"(got rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
            )
        self.num_groups: int = self.rope_param_dim // self.head_dim
        if self.num_groups <= 0:
            raise RuntimeError("num_groups must be positive for LearnableRoPE2D")
        if self.num_heads % self.num_groups != 0:
            raise ValueError(
                "num_heads must be divisible by (rope_param_dim / head_dim) for LearnableRoPE2D "
                f"(got num_heads={self.num_heads}, num_groups={self.num_groups}, "
                f"rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
            )
        self.heads_per_group: int = self.num_heads // self.num_groups
        if self.heads_per_group <= 0:
            raise RuntimeError("heads_per_group must be positive for LearnableRoPE2D")

        # Axial-compatible deterministic init:
        # - periods match AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)
        # - angle = 2π * coord / period
        qtr = self.head_dim // 4
        exponents = (
            2.0
            * torch.arange(int(qtr), dtype=torch.float32)
            / float(self.head_dim // 2)
        )
        periods = torch.tensor(theta, dtype=torch.float32) ** exponents  # [qtr]
        axis_freqs = (mult / periods).to(dtype=torch.float32)  # [qtr]

        zeros = torch.zeros_like(axis_freqs)
        # Match AxialRoPE2D(HALF_SPLIT) flatten order: [y-axis, x-axis].
        # Our xy columns are (x, y), so:
        # - x contributes to the second quarter (x-axis part)
        # - y contributes to the first quarter (y-axis part)
        fx_half = torch.cat((zeros, axis_freqs), dim=0)  # [half_dim]
        fy_half = torch.cat((axis_freqs, zeros), dim=0)  # [half_dim]

        freqs_x = fx_half.expand(int(self.num_groups), -1).clone()
        freqs_y = fy_half.expand(int(self.num_groups), -1).clone()
        freqs = torch.stack([freqs_x, freqs_y], dim=0)  # (2, G, half)
        if self._learnable:
            self.freqs = nn.Parameter(freqs, requires_grad=True)
        else:
            self.register_buffer("freqs", freqs, persistent=persist_buffers)

    def _apply(
        self,
        fn: Callable[[torch.Tensor], torch.Tensor],
        recurse: bool = True,
    ) -> LearnableRoPE2D:
        """Apply module moves/casts while preserving fp32 frequency tensors."""

        out = super()._apply(fn, recurse=recurse)
        with torch.no_grad():
            self.freqs.data = self.freqs.data.to(dtype=torch.float32)
        return out

    def _apply_rotary_from_trig(
        self,
        x: torch.Tensor,
        *,
        sin: torch.Tensor,
        cos: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate Q/K using precomputed grouped sin/cos buffers (HALF_SPLIT layout).

        This matches AxialRoPE2DConfig(dim_layout=HALF_SPLIT) rotation and keeps
        the learnable variant identical at initialization when combined with
        axial-compatible frequency init.

        Args:
            x: Tensor shaped ``(B, H, T, D)``.
            sin: Sin tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
            cos: Cos tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.

        Returns:
            Tensor with the same shape/dtype/device as ``x``.
        """
        if x.dim() != 4:
            raise ValueError("x must be shaped (B, H, T, D)")
        B, H, T, D = x.shape
        if self.num_heads != int(H):
            raise ValueError("num_heads mismatch for LearnableRoPE2D")
        if self.head_dim != int(D):
            raise ValueError("head_dim mismatch for LearnableRoPE2D")

        if sin.dim() == 3 and cos.dim() == 3:
            sin = sin.unsqueeze(0)
            cos = cos.unsqueeze(0)
        if sin.dim() != 4 or cos.dim() != 4:
            raise RuntimeError("Unexpected sin/cos rank for LearnableRoPE2D")
        if int(D) % 2 != 0:
            raise RuntimeError("LearnableRoPE2D requires even head_dim for HALF_SPLIT")
        half = int(D) // 2
        if int(sin.shape[-1]) != half or int(cos.shape[-1]) != half:
            raise RuntimeError(
                "LearnableRoPE2D expected sin/cos last dim == head_dim//2 "
                f"(got sin={tuple(sin.shape)}, cos={tuple(cos.shape)}, head_dim={int(D)})"
            )

        sin = sin[:, :, None, :, :]  # [B, G, 1, T, half]
        cos = cos[:, :, None, :, :]  # [B, G, 1, T, half]

        grouped = x.reshape(
            int(B),
            int(self.num_groups),
            int(self.heads_per_group),
            int(T),
            int(D),
        )
        x1 = grouped[..., :half]
        x2 = grouped[..., half:]
        out1 = x1 * cos - x2 * sin
        out2 = x2 * cos + x1 * sin
        out = torch.cat((out1, out2), dim=-1).reshape(int(B), int(H), int(T), int(D))
        return out.to(dtype=x.dtype)

    def _compute_mixed_cis(self, xy: torch.Tensor) -> torch.Tensor:
        # Returns complex cis angles with shape (G, T, half) or (B, G, T, half)
        if xy.dim() == 2:
            # (T, 2) -> (G, T, half)
            t_x = xy[:, 0].to(dtype=torch.float32)
            t_y = xy[:, 1].to(dtype=torch.float32)
            with torch.autocast(device_type=t_x.device.type, enabled=False):
                # Memory notes:
                # - Avoid materializing both fx and fy; accumulate in-place into angles.
                # - Avoid torch.ones_like(angles) (full-size allocation); a scalar
                #   magnitude broadcasts in torch.polar.
                angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
                    0
                )  # (T, G, half)
                angles.add_(
                    t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
                )
                angles = angles.permute(1, 0, 2)  # (G, T, half)
                cis = torch.polar(
                    torch.ones((), device=angles.device, dtype=angles.dtype), angles
                )
            return cis
        elif xy.dim() == 3:
            # (B, T, 2) -> (B, G, T, half)
            t_x = xy[..., 0].to(dtype=torch.float32)
            t_y = xy[..., 1].to(dtype=torch.float32)
            with torch.autocast(device_type=t_x.device.type, enabled=False):
                angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
                    0
                ).unsqueeze(0)
                angles.add_(
                    t_y.unsqueeze(-1).unsqueeze(-1)
                    * self.freqs[1].unsqueeze(0).unsqueeze(0)
                )
                angles = angles.permute(0, 2, 1, 3)  # (B, G, T, half)
                cis = torch.polar(
                    torch.ones((), device=angles.device, dtype=angles.dtype), angles
                )
            return cis
        else:
            raise ValueError("xy must have shape (T,2) or (B,T,2)")

    def _compute_mixed_angles(self, xy: torch.Tensor) -> torch.Tensor:
        """Return mixed RoPE2D angles without applying cis/polar.

        Args:
            xy: XY positions shaped ``(T, 2)`` or ``(B, T, 2)``.

        Returns:
            Float tensor of angles shaped ``(G, T, half)`` or ``(B, G, T, half)``.
        """
        if xy.dim() == 2:
            t_x = xy[:, 0].to(dtype=torch.float32)
            t_y = xy[:, 1].to(dtype=torch.float32)
            with torch.autocast(device_type=t_x.device.type, enabled=False):
                angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(0)
                angles.add_(
                    t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
                )
                return angles.permute(1, 0, 2)
        if xy.dim() == 3:
            t_x = xy[..., 0].to(dtype=torch.float32)
            t_y = xy[..., 1].to(dtype=torch.float32)
            with torch.autocast(device_type=t_x.device.type, enabled=False):
                angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
                    0
                ).unsqueeze(0)
                angles.add_(
                    t_y.unsqueeze(-1).unsqueeze(-1)
                    * self.freqs[1].unsqueeze(0).unsqueeze(0)
                )
                return angles.permute(0, 2, 1, 3)
        raise ValueError("xy must have shape (T,2) or (B,T,2)")

    def _cos_sin_half_from_xy(
        self,
        xy: torch.Tensor,
        *,
        device: torch.device | None = None,
        out_dtype: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Helper used in tests to build real-valued cos/sin tensors.
        cis = self._compute_mixed_cis(xy.to(device=device) if device else xy)
        # Convert complex cis to cos/sin (real/imag) with matching shapes
        if cis.is_complex():
            cos_h = cis.real
            sin_h = cis.imag
        else:
            # Should not happen; torch.polar returns complex64/128
            raise RuntimeError("Expected complex cis tensor from polar")
        if out_dtype is not None:
            cos_h = cos_h.to(dtype=out_dtype)
            sin_h = sin_h.to(dtype=out_dtype)
        return cos_h, sin_h

    def _cos_sin_from_xy(
        self,
        xy: torch.Tensor,
        *,
        device: torch.device | None = None,
        out_dtype: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        cos_h, sin_h = self._cos_sin_half_from_xy(
            xy, device=device, out_dtype=out_dtype
        )
        emb_cos = torch.cat((cos_h, cos_h), dim=-1)
        emb_sin = torch.cat((sin_h, sin_h), dim=-1)
        return emb_cos, emb_sin

    def rotate_qk(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        xy: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if q.dim() != 4 or k.dim() != 4:
            raise ValueError("q,k must be shaped (B,H,T,D)")
        _, H, _, D = q.shape
        if self.num_heads != H:
            raise ValueError("num_heads mismatch for LearnableRoPE2D")
        if self.head_dim != D:
            raise ValueError("head_dim mismatch for LearnableRoPE2D")
        if D % 4 != 0:
            raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")

        # Use real-valued sin/cos rotation to keep torch.compile/inductor on the
        # fast path (inductor cannot codegen complex tensors).
        angles = self._compute_mixed_angles(xy.to(device=q.device))
        sin = torch.sin(angles)
        cos = torch.cos(angles)
        q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
        k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
        return q_out, k_out

    def rotate_qk_with_dilation(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        xy: torch.Tensor,
        scales: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate Q/K using mixed 2D RoPE with per-sample isotropic dilation.

        This implements dilation by scaling the RoPE angle, i.e.
        ``theta_dilated = scale * theta_base`` where ``theta_base`` comes from the
        undilated XY coordinates.

        Args:
            q: Query tensor shaped ``(B, H, T, D)``.
            k: Key tensor shaped ``(B, H, T, D)``.
            xy: Base XY coordinates shaped ``(T, 2)`` or ``(B, T, 2)``.
            scales: Per-sample dilation scales shaped ``(B,)``.

        Raises:
            ValueError: If shapes are inconsistent or scales are not 1D.
        """
        if q.dim() != 4 or k.dim() != 4:
            raise ValueError("q,k must be shaped (B,H,T,D)")
        B, H, T, D = q.shape
        if self.num_heads != H:
            raise ValueError("num_heads mismatch for LearnableRoPE2D")
        if self.head_dim != D:
            raise ValueError("head_dim mismatch for LearnableRoPE2D")
        if scales.dim() != 1 or scales.shape[0] != B:
            raise ValueError("scales must have shape (B,) matching q batch size")
        if xy.dim() == 2 and xy.shape[0] != T:
            raise ValueError("xy length must match q sequence length")
        if xy.dim() == 3 and (xy.shape[0] != B or xy.shape[1] != T):
            raise ValueError("xy must have shape (B,T,2) matching q batch/sequence")
        if xy.shape[-1] != 2:
            raise ValueError("xy must have last dimension 2")

        angles = self._compute_mixed_angles(xy.to(device=q.device))
        angles = angles * scales.to(device=q.device, dtype=torch.float32).view(
            B, 1, 1, 1
        )
        sin = torch.sin(angles)
        cos = torch.cos(angles)
        q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
        k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
        return q_out, k_out