"""Geometric Attention - CUDA-accelerated version (with FP16 support)."""

import math
from typing import Optional

import torch
from einops import rearrange

try:
    from forgetting_transformer.ops.geometric_attention.cuda_interface import (
        load_extension,
        geometric_attention_activation,
    )

    load_extension()
    HAS_CUDA = True
    print("✅ Using CUDA geometric attention (with FP16 support)")
except Exception as e:
    HAS_CUDA = False
    print(f"⚠️ CUDA not available: {e}")


def geometric_attention_cuda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Geometric attention using the CUDA activation kernel.

    Note: ``seq_start`` is accepted for interface compatibility but is
    currently unused by this implementation.
    """
    if not HAS_CUDA:
        raise RuntimeError("CUDA not available")

    # The activation kernel computes in FP32, so upcast FP16 inputs here
    # and cast the output back to the original dtype at the end.
    original_dtype = q.dtype
    needs_cast = original_dtype == torch.float16
    if needs_cast:
        q = q.float()
        k = k.float()
        v = v.float()

    # Work internally in head-first (B, H, T, D) layout.
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    B, H, T_q, D = q.shape

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Scaled dot-product logits.
    logits = torch.matmul(q, k.transpose(-2, -1)) * sm_scale

    # Geometric-attention activation (the CUDA kernel) in place of softmax.
    attn_weights = geometric_attention_activation(
        logits, mask=None, pos_offset=0, normalize=normalize
    )

    # Weighted sum over values.
    output = torch.matmul(attn_weights, v)

    if not head_first:
        output = rearrange(output, "b h t d -> b t h d")

    if needs_cast:
        output = output.to(original_dtype)

    return output
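

# For intuition: geometric_attention_cuda follows the usual
# logits -> activation -> weighted-sum pipeline; with softmax in place of
# the geometric activation it would reduce to standard scaled dot-product
# attention. A minimal reference sketch for comparison (hypothetical helper,
# not part of the module's API):
def _softmax_attention_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """Standard softmax attention over (B, H, T, D) tensors, for comparison."""
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
    logits = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    return torch.matmul(torch.softmax(logits, dim=-1), v)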


def geometric_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Automatically select the CUDA or Python implementation."""
    if HAS_CUDA and q.is_cuda:
        try:
            return geometric_attention_cuda(
                q, k, v, head_first=head_first,
                seq_start=seq_start, sm_scale=sm_scale,
                normalize=normalize,
            )
        except Exception as e:
            # Fall through to the Python implementation on failure
            # instead of swallowing the error silently.
            print(f"⚠️ CUDA geometric attention failed, falling back to Python: {e}")

    # Pure-Python fallback implementation.
    from forgetting_transformer.ops.geometric_attention_std import geometric_attention_std
    return geometric_attention_std(
        q, k, v, head_first=head_first,
        seq_start=seq_start, sm_scale=sm_scale,
        normalize=normalize,
    )
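

# Minimal usage sketch (assumed shapes, for illustration only): with
# head_first=False, q/k/v are (batch, seq_len, num_heads, head_dim) tensors
# and the output has the same shape. On a CUDA device with the extension
# built, the CUDA path is used; otherwise this dispatches to the fallback.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(2, 16, 4, 64, device=device)
    k = torch.randn(2, 16, 4, 64, device=device)
    v = torch.randn(2, 16, 4, 64, device=device)
    out = geometric_attention(q, k, v, head_first=False, normalize=True)
    print(out.shape)  # torch.Size([2, 16, 4, 64])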