| """ |
| Geometric Attention - CUDA加速版本 (支持FP16) |
| """ |

import math
import torch
from einops import rearrange
from typing import Optional


try:
    from forgetting_transformer.ops.geometric_attention.cuda_interface import (
        load_extension,
        geometric_attention_activation,
    )
    load_extension()
    HAS_CUDA = True
    print("✅ Using CUDA geometric attention (with FP16 support)")
except Exception as e:
    HAS_CUDA = False
    print(f"⚠️ CUDA not available: {e}")

def geometric_attention_cuda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Geometric attention using the CUDA activation kernel.

    Expects inputs of shape ``(B, T, H, D)`` by default, or ``(B, H, T, D)``
    when ``head_first=True``. FP16 inputs are upcast to FP32 for the kernel
    and the output is cast back to the original dtype. ``seq_start`` is
    accepted for signature compatibility but is not used by this path.
    """
    if not HAS_CUDA:
        raise RuntimeError("CUDA geometric attention extension is not available")

    original_dtype = q.dtype
    needs_cast = original_dtype == torch.float16

    # The CUDA activation runs in FP32; upcast FP16 inputs first.
    if needs_cast:
        q = q.float()
        k = k.float()
        v = v.float()

    # Bring tensors into (B, H, T, D) layout for the matmuls below.
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    B, H, T_q, D = q.shape

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Scaled attention logits: (B, H, T_q, T_k).
    logits = torch.matmul(q, k.transpose(-2, -1)) * sm_scale

    # Apply the geometric attention activation (CUDA kernel).
    attn_weights = geometric_attention_activation(
        logits, mask=None, pos_offset=0, normalize=normalize
    )

    # Weighted sum over values: (B, H, T_q, D).
    output = torch.matmul(attn_weights, v)

    # Restore the original (B, T, H, D) layout if needed.
    if not head_first:
        output = rearrange(output, "b h t d -> b t h d")

    # Cast back to the caller's dtype (e.g. FP16).
    if needs_cast:
        output = output.to(original_dtype)

    return output


def geometric_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Automatically select the CUDA or Python implementation."""

    if HAS_CUDA and q.is_cuda:
        try:
            return geometric_attention_cuda(
                q, k, v, head_first=head_first,
                seq_start=seq_start, sm_scale=sm_scale,
                normalize=normalize
            )
        except Exception:
            # Fall back to the pure-Python implementation below.
            pass

    from forgetting_transformer.ops.geometric_attention_std import geometric_attention_std
    return geometric_attention_std(
        q, k, v, head_first=head_first,
        seq_start=seq_start, sm_scale=sm_scale,
        normalize=normalize
    )
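

# Minimal usage sketch (hypothetical, not part of the module's public API):
# it assumes inputs in the default (batch, seq_len, num_heads, head_dim)
# layout and simply exercises the dispatcher, which falls back to the
# pure-Python implementation when the CUDA extension is unavailable.
if __name__ == "__main__":
    B, T, H, D = 2, 128, 4, 64
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    q = torch.randn(B, T, H, D, device=device, dtype=dtype)
    k = torch.randn(B, T, H, D, device=device, dtype=dtype)
    v = torch.randn(B, T, H, D, device=device, dtype=dtype)

    out = geometric_attention(q, k, v, head_first=False, normalize=True)
    print(f"output shape: {tuple(out.shape)}, dtype: {out.dtype}")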