import os
import torch
import multiprocessing
from typing import Tuple, Optional
import torch.nn.functional as F
import filelock  # used in place of framework.utils.LockFile
# Just-in-time import and build of the CUDA extension, see:
# https://pytorch.org/tutorials/advanced/cpp_extension
dirname = os.path.dirname(__file__)
filename = os.path.join(dirname, 'cuda_interface.cu')
outdir = "./cache/geometric_attention"
os.makedirs(outdir, exist_ok=True)
cuda_log_sigmoid_backward = None
cuda_log_sigmoid_forward = None
cuda_window_sum_forward = None
cuda_window_sum_backward = None


def load_extension():
    global cuda_log_sigmoid_forward, cuda_log_sigmoid_backward
    global cuda_window_sum_forward, cuda_window_sum_backward

    # Already loaded: nothing to do.
    if cuda_log_sigmoid_forward is not None:
        return

    # Use filelock in place of framework.utils.LockFile so that concurrent
    # processes do not race on the JIT build. Note that outdir only holds the
    # lock file; load() builds into torch's default extensions cache unless a
    # build_directory is given.
    lock = filelock.FileLock(outdir + "/lock.lock")
    with lock:
        from torch.utils.cpp_extension import load
        os.environ["MAX_JOBS"] = str(multiprocessing.cpu_count())
        ext = load(
            extra_cuda_cflags=['--ftemplate-depth=1024'],
            name="geometric_attention_cuda_interface",
            sources=[filename], verbose=True)

        cuda_log_sigmoid_forward = ext.cuda_log_sigmoid_forward
        cuda_log_sigmoid_backward = ext.cuda_log_sigmoid_backward
        cuda_window_sum_forward = ext.cuda_window_sum_forward
        cuda_window_sum_backward = ext.cuda_window_sum_backward
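

# Optional warm-up sketch (an assumption, not part of the original module):
# calling load_extension() once at start-up moves the one-off nvcc compile out
# of the first forward pass, e.g.
#
#     load_extension()  # blocks until the extension is built and imported
#     assert cuda_log_sigmoid_forward is not None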


class LogSigmoidFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = x.detach().contiguous()
        ctx.save_for_backward(x)
        a, b = cuda_log_sigmoid_forward(x)
        return a, b

    @staticmethod
    def backward(ctx, grad_in_sigm: torch.Tensor, grad_in_one_minus: torch.Tensor) -> torch.Tensor:
        xf, = ctx.saved_tensors
        ga = grad_in_sigm.contiguous()
        gb = grad_in_one_minus.contiguous()
        return cuda_log_sigmoid_backward(xf, ga, gb)[0]


class WindowSumFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, csum: torch.Tensor, offset: int) -> torch.Tensor:
        ctx.saved_offset = offset
        # Fold all leading batch dimensions together so the kernel sees a
        # 3-D tensor, then restore the original shape afterwards.
        c2 = csum.detach().contiguous().flatten(end_dim=-3)
        res = cuda_window_sum_forward(c2, offset)
        return res.view_as(csum)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        offset = ctx.saved_offset
        go = grad_output.contiguous().flatten(end_dim=-3)
        res = cuda_window_sum_backward(go, offset)
        # No gradient for the integer offset argument.
        return res.view_as(grad_output), None


def window_sum(x: torch.Tensor, offset: int) -> torch.Tensor:
    load_extension()
    return WindowSumFunction.apply(x, offset)


def log_sigmoid(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    load_extension()
    return LogSigmoidFunction.apply(x)
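

# Reference sketch (an assumption inferred from how the pair is consumed in
# geometric_attention_activation below, not taken from the CUDA source): the
# two tensors returned by log_sigmoid(x) play the roles of log(sigmoid(x)) and
# log(1 - sigmoid(x)). A pure-PyTorch fallback with that contract could be:
def log_sigmoid_reference(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # log(1 - sigmoid(x)) == log(sigmoid(-x)), evaluated stably via logsigmoid.
    return F.logsigmoid(x), F.logsigmoid(-x)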


def geometric_attention_activation(logits: torch.Tensor, mask: Optional[torch.Tensor] = None, pos_offset: int = 0,
                                   normalize: bool = True) -> torch.Tensor:
    # `mask` is accepted for interface compatibility but is not used here.
    p, one_minus_p = log_sigmoid(logits)
    # Accumulated "not selected at an earlier position" log-term, shifted by pos_offset.
    not_previous = window_sum(one_minus_p.cumsum(-1), pos_offset)
    probs = (not_previous + p).exp()
    return F.normalize(probs, p=1.0, dim=-1) if normalize else probs
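

# Minimal usage sketch (not part of the original module; the shape below is
# illustrative only). Requires a CUDA device and an nvcc toolchain so that
# load_extension() can JIT-compile cuda_interface.cu on first use.
if __name__ == "__main__":
    if torch.cuda.is_available():
        # Any >=3-D layout works with the flatten(end_dim=-3) in WindowSumFunction;
        # here a [batch, heads, queries, keys]-style shape is used.
        logits = torch.randn(2, 4, 16, 16, device="cuda", requires_grad=True)
        att = geometric_attention_activation(logits, pos_offset=0, normalize=True)
        # With normalize=True each slice along the last dimension sums to 1.
        print(att.shape, att.sum(-1).mean().item())
        att.sum().backward()
        print(logits.grad.shape)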