Text Generation
Transformers
Safetensors
English
Arabic
quasar_long
silx-ai
quasar-preview
quasar
foundation-model
Mixture of Experts
18b
2b-active
long-context
bittensor
sn24
decentralized-training
distillation
hybrid-transformer
loop-transformer
safe-nope
drope
conversational
custom_code
Instructions to use mainline777/base_IIXIV with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mainline777/base_IIXIV with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="mainline777/base_IIXIV", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mainline777/base_IIXIV", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use mainline777/base_IIXIV with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "mainline777/base_IIXIV"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mainline777/base_IIXIV",
        "messages": [
            { "role": "user", "content": "What is the capital of France?" }
        ]
    }'
Use Docker
docker model run hf.co/mainline777/base_IIXIV
- SGLang
How to use mainline777/base_IIXIV with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "mainline777/base_IIXIV" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mainline777/base_IIXIV",
        "messages": [
            { "role": "user", "content": "What is the capital of France?" }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "mainline777/base_IIXIV" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mainline777/base_IIXIV",
        "messages": [
            { "role": "user", "content": "What is the capital of France?" }
        ]
    }'
- Docker Model Runner
How to use mainline777/base_IIXIV with Docker Model Runner:
docker model run hf.co/mainline777/base_IIXIV
| # Copyright (c) 2023-2025, Tri Dao, Yu Zhang, Songlin Yang. | |
| import torch | |
| import torch.nn.functional as F | |
| import triton | |
| import triton.language as tl | |
| from fla.ops.utils.op import exp, log | |
| from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard | |
# Candidate warp counts: AMD builds stop at 16, every other backend also tries 32.
# NOTE(review): nothing visible in this file reads this constant; presumably it fed
# the kernels' autotune configs -- confirm against the original module.
NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16] if IS_AMD else [1, 2, 4, 8, 16, 32]
| def _get_stride(x: torch.Tensor) -> int: | |
| """Get the row stride for viewing a tensor as 2D (num_rows, D) where D = shape[-1]. | |
| Returns stride(-2) if the tensor is at least 2D, or 0 for 1D tensors. | |
| The caller must ensure the tensor is "inner-contiguous" (stride(-1) == 1 and | |
| higher dims are contiguous relative to dim -2) before using this value. | |
| """ | |
| if x.ndim < 2: | |
| return 0 | |
| return x.stride(-2) | |
| def _is_inner_contiguous(x: torch.Tensor) -> bool: | |
| """Check if a tensor can be safely viewed as 2D (num_rows, D) with row stride = stride(-2). | |
| This holds when stride(-1) == 1 and all dimensions above -2 are contiguous | |
| with respect to the dimension below them. | |
| """ | |
| ndim = x.ndim | |
| if ndim < 2: | |
| return True | |
| if x.stride(-1) != 1: | |
| return False | |
| if ndim == 2: | |
| # 2D: any layout with stride(-1)==1 is valid (can view as (T, D)) | |
| return True | |
| if ndim == 3: | |
| # 3D (B, T, D): stride should be (T*D, D, 1) | |
| return x.stride(0) == x.stride(-2) * x.shape[-2] | |
| if ndim == 4: | |
| # 4D (B, H, T, D): stride should be (H*T*D, T*D, D, 1) | |
| if x.stride(1) != x.stride(-2) * x.shape[-2]: | |
| return False | |
| return x.stride(0) == x.stride(1) * x.shape[1] | |
| # 5D+ fallback to loop | |
| expected = x.stride(-2) * x.shape[-2] | |
| for d in range(ndim - 3, -1, -1): | |
| if x.stride(d) != expected: | |
| return False | |
| expected *= x.shape[d] | |
| return True | |
def _ensure_inner_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged when it is inner-contiguous, otherwise ``x.contiguous()``."""
    return x if _is_inner_contiguous(x) else x.contiguous()
| def _alloc_output(x: torch.Tensor, contiguous: bool = False) -> torch.Tensor: | |
| """Allocate output tensor: contiguous buffer or same layout as input.""" | |
| if contiguous: | |
| return x.new_empty(x.shape) | |
| return torch.empty_like(x) | |
def sigmoid_fwd_kernel(
    x, y,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Element-wise sigmoid: ``y = 1 / (1 + exp(-x))``.

    Each program handles ``B`` consecutive flat element indices below ``T``.
    Flat index ``i`` maps to row ``i // D`` and column ``i % D``, and is
    addressed as ``row * stride_*_row + col``. The math is done in float32 and
    the result is cast to the element type of ``y``.

    NOTE(review): the launcher indexes this function with a grid and never passes
    ``B``; presumably ``triton.jit``/autotune decorators belong here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    out = 1.0 / (1.0 + exp(-inp))
    tl.store(y + (r * stride_y_row + c), out.to(y.dtype.element_ty), mask=in_bounds)
def sigmoid_bwd_kernel(
    x, dy, dx,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dy_row,
    stride_dx_row,
    B: tl.constexpr,
):
    """Sigmoid gradient: ``dx = dy * s * (1 - s)`` with ``s = 1 / (1 + exp(-x))``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)`` in each
    buffer. The sigmoid is recomputed from ``x`` in float32 and the result is
    cast to the element type of ``dx``.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    grad_out = tl.load(dy + (r * stride_dy_row + c), mask=in_bounds, other=0.).to(tl.float32)
    sig = 1.0 / (1.0 + exp(-inp))
    grad_in = grad_out * sig * (1.0 - sig)
    tl.store(dx + (r * stride_dx_row + c), grad_in.to(dx.dtype.element_ty), mask=in_bounds)
def sigmoid_fwd(x: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``sigmoid_fwd_kernel`` over every element of ``x`` and return the result.

    Args:
        x: Input tensor; copied to a contiguous buffer first if it is not
            inner-contiguous.
        output_contiguous: Forwarded to ``_alloc_output`` to choose the output layout.

    Returns:
        A new tensor with the shape of ``x`` holding the kernel output.
    """
    x = _ensure_inner_contiguous(x)
    numel = x.numel()
    y = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    sigmoid_fwd_kernel[grid](
        x, y,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
def sigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``sigmoid_bwd_kernel`` and return the gradient with respect to ``x``.

    Args:
        x: Forward input; made inner-contiguous if necessary.
        dy: Upstream gradient; made inner-contiguous if necessary.
        output_contiguous: Forwarded to ``_alloc_output`` for the ``dx`` buffer.

    Returns:
        A new tensor ``dx`` with the shape of ``x``.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    numel = x.numel()
    dx = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    sigmoid_bwd_kernel[grid](
        x, dy, dx,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dy_row=_get_stride(dy),
        stride_dx_row=_get_stride(dx),
    )
    return dx
class SigmoidFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``sigmoid_fwd`` with ``sigmoid_bwd``.

    NOTE(review): ``input_guard`` is imported at the top of the file but not
    applied anywhere visible; confirm whether it should decorate these methods.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for the backward pass and return ``sigmoid_fwd(x)``."""
        ctx.save_for_backward(x)
        return sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``x`` computed by ``sigmoid_bwd`` from the saved input."""
        x, = ctx.saved_tensors
        return sigmoid_bwd(x, dout)


sigmoid = SigmoidFunction.apply
def logsigmoid_fwd_kernel(
    x,
    y,
    temperature,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Element-wise ``y = (min(0, x) - log(1 + exp(-|x|))) / temperature``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)``. The math
    is done in float32 and cast to the element type of ``y`` on store.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    # exp is only ever applied to a non-positive argument here.
    neg_part = tl.minimum(0., inp)
    denom = 1. + exp(-tl.abs(inp))
    out = (neg_part - log(denom)) / temperature
    tl.store(y + (r * stride_y_row + c), out.to(y.dtype.element_ty), mask=in_bounds)
def logsigmoid_bwd_kernel(
    x,
    dx,
    dy,
    temperature,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dx_row,
    stride_dy_row,
    B: tl.constexpr,
):
    """Gradient of the scaled log-sigmoid: ``dx = dy * (1 - sigmoid(x)) / temperature``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)`` in each
    buffer. The math is done in float32 and cast to the element type of ``dx``.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    grad_out = tl.load(dy + (r * stride_dy_row + c), mask=in_bounds, other=0.).to(tl.float32)
    grad_in = grad_out * ((1. - tl.sigmoid(inp)) / temperature)
    tl.store(dx + (r * stride_dx_row + c), grad_in.to(dx.dtype.element_ty), mask=in_bounds)
def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1., output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``logsigmoid_fwd_kernel`` over every element of ``x``.

    Args:
        x: Input tensor; made inner-contiguous if necessary.
        temperature: Divisor applied to the log-sigmoid inside the kernel.
        output_contiguous: Forwarded to ``_alloc_output`` to choose the output layout.

    Returns:
        A new tensor with the shape of ``x``.
    """
    x = _ensure_inner_contiguous(x)
    numel = x.numel()
    y = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    logsigmoid_fwd_kernel[grid](
        x=x,
        y=y,
        temperature=temperature,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1., output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``logsigmoid_bwd_kernel`` and return the gradient with respect to ``x``.

    Args:
        x: Forward input; made inner-contiguous if necessary.
        dy: Upstream gradient; made inner-contiguous if necessary.
        temperature: Same divisor that was used in the forward pass.
        output_contiguous: Forwarded to ``_alloc_output`` for the ``dx`` buffer.

    Returns:
        A new tensor ``dx`` with the shape of ``x``.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    numel = x.numel()
    dx = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    logsigmoid_bwd_kernel[grid](
        x=x,
        dx=dx,
        dy=dy,
        temperature=temperature,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dx_row=_get_stride(dx),
        stride_dy_row=_get_stride(dy),
    )
    return dx
class LogSigmoidFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``logsigmoid_fwd`` with ``logsigmoid_bwd``.

    NOTE(review): ``input_guard`` is imported at the top of the file but not
    applied anywhere visible; confirm whether it should decorate these methods.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x, temperature):
        """Save ``x`` and ``temperature`` on ``ctx`` and return ``logsigmoid_fwd(x, temperature)``."""
        ctx.save_for_backward(x)
        ctx.temperature = temperature
        return logsigmoid_fwd(x, temperature)

    @staticmethod
    def backward(ctx, dy):
        """Return ``(dx, None)``; ``temperature`` receives no gradient."""
        x, = ctx.saved_tensors
        return logsigmoid_bwd(x, dy, ctx.temperature), None
def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    """Apply ``LogSigmoidFunction`` to ``x`` with the given ``temperature`` (default 1)."""
    result = LogSigmoidFunction.apply(x, temperature)
    return result
def swish_fwd_kernel(
    x, y,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Element-wise swish/SiLU: ``y = x * s`` with ``s = 1 / (1 + exp(-x))``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)``. The math
    is done in float32 and cast to the element type of ``y`` on store.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    sig = 1.0 / (1.0 + exp(-inp))
    out = inp * sig
    tl.store(y + (r * stride_y_row + c), out.to(y.dtype.element_ty), mask=in_bounds)
def swish_bwd_kernel(
    x, dy, dx,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dy_row,
    stride_dx_row,
    B: tl.constexpr,
):
    """Swish gradient: ``dx = dy * s * (1 + x * (1 - s))`` with ``s = 1 / (1 + exp(-x))``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)`` in each
    buffer. The math is done in float32 and cast to the element type of ``dx``.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    inp = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    grad_out = tl.load(dy + (r * stride_dy_row + c), mask=in_bounds, other=0.).to(tl.float32)
    sig = 1.0 / (1.0 + exp(-inp))
    grad_in = grad_out * sig * (1.0 + inp * (1.0 - sig))
    tl.store(dx + (r * stride_dx_row + c), grad_in.to(dx.dtype.element_ty), mask=in_bounds)
def swish_fwd(x: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swish_fwd_kernel`` over every element of ``x`` and return the result.

    Args:
        x: Input tensor; made inner-contiguous if necessary.
        output_contiguous: Forwarded to ``_alloc_output`` to choose the output layout.

    Returns:
        A new tensor with the shape of ``x``.
    """
    x = _ensure_inner_contiguous(x)
    numel = x.numel()
    y = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    swish_fwd_kernel[grid](
        x, y,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
def swish_bwd(x: torch.Tensor, dy: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swish_bwd_kernel`` and return the gradient with respect to ``x``.

    Args:
        x: Forward input; made inner-contiguous if necessary.
        dy: Upstream gradient; made inner-contiguous if necessary.
        output_contiguous: Forwarded to ``_alloc_output`` for the ``dx`` buffer.

    Returns:
        A new tensor ``dx`` with the shape of ``x``.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    numel = x.numel()
    dx = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    swish_bwd_kernel[grid](
        x, dy, dx,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dy_row=_get_stride(dy),
        stride_dx_row=_get_stride(dx),
    )
    return dx
class SwishFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``swish_fwd`` with ``swish_bwd``.

    NOTE(review): ``input_guard`` is imported at the top of the file but not
    applied anywhere visible; confirm whether it should decorate these methods.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for the backward pass and return ``swish_fwd(x)``."""
        ctx.save_for_backward(x)
        return swish_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``x`` computed by ``swish_bwd`` from the saved input."""
        x, = ctx.saved_tensors
        return swish_bwd(x, dout)


swish = SwishFunction.apply
# Constants used by the GELU helpers below:
#   1/sqrt(2*pi) -> 0.3989423
#   1/sqrt(2)    -> 0.70710678
#   sqrt(2/pi)   -> 0.79788456
# This function is the tanh approximation of GELU. The exact GELU is:
#   x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
def bias_gelu(y, bias):
    """Return tanh-approximate GELU of ``bias + y``, cast to the dtype of ``y``."""
    pre_act = bias + y
    tanh_arg = 0.79788456 * pre_act * (1 + 0.044715 * pre_act * pre_act)
    activated = pre_act * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return activated.to(dtype=y.dtype)
# Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
#   0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
def bias_gelu_bwd(g, y, bias):
    """Backward of ``bias_gelu``: gradients with respect to ``y`` and ``bias``.

    Args:
        g: Upstream gradient, same shape as ``y``.
        y: Forward input of shape ``(..., D)``.
        bias: Forward bias of shape ``(D,)``.

    Returns:
        ``(grad_y, grad_bias)``. ``grad_y`` has the dtype of ``y``;
        ``grad_bias`` has shape ``(D,)`` and the dtype of ``bias``.
    """
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    # The bias is broadcast over every leading dimension, so its gradient is the
    # sum over all of them (not only dim 0); for 2D input this is identical to
    # the previous ``sum(dim=0)``.
    grad_bias = grad_y.reshape(-1, grad_y.shape[-1]).sum(dim=0, dtype=bias.dtype)
    return grad_y.to(dtype=y.dtype), grad_bias
class GeLUFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``bias_gelu`` with ``bias_gelu_bwd``."""

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, input, bias):
        """Save both operands and return ``bias_gelu(input, bias)``."""
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_bias)`` for the two forward arguments."""
        input, bias = ctx.saved_tensors
        # bias_gelu_bwd already returns the (grad_input, grad_bias) pair; the
        # previous ``return tmp, tmp`` handed autograd two tuples instead of
        # two tensors.
        grad_input, grad_bias = bias_gelu_bwd(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
# This function is the tanh approximation of GELU. The exact GELU is:
#   x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
def gelu_fwd(x):
    """Return tanh-approximate GELU of ``x``, cast back to the dtype of ``x``."""
    tanh_arg = 0.79788456 * x * (1 + 0.044715 * x * x)
    activated = x * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return activated.to(dtype=x.dtype)
# Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
#   0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
def gelu_bwd(g, x):
    """Return ``g`` times the derivative of tanh-approximate GELU at ``x``, in the dtype of ``x``."""
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    one_minus_t_sq = 1 - t * t
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    inner_grad = 0.79788456 + 0.1070322243 * x * x
    local_grad = 0.5 * x * (one_minus_t_sq * inner_grad) + 0.5 * (1 + t)
    return (local_grad * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``gelu_fwd`` with ``gelu_bwd`` (no bias term)."""

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, input):
        """Save ``input`` for the backward pass and return ``gelu_fwd(input)``."""
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input`` computed by ``gelu_bwd``."""
        (input,) = ctx.saved_tensors
        return gelu_bwd(grad_output, input)


fast_gelu_impl = FastGeLUFunction.apply
def relu_bwd(g, x):
    """Pass ``g`` through where ``x >= 0`` and zero it elsewhere; result has the dtype of ``x``."""
    keep = x >= 0
    masked = torch.where(keep, g, 0.0)
    return masked.to(dtype=x.dtype)
def sqrelu_fwd(x):
    """Squared ReLU: ``relu(x) ** 2`` computed in float32 and cast back to the dtype of ``x``."""
    rectified = torch.relu(x.float())
    squared = rectified * rectified
    return squared.to(dtype=x.dtype)
def sqrelu_bwd(g, x):
    """Gradient of squared ReLU: ``2 * g * relu(x)``, with ReLU taken in float32, cast to the dtype of ``x``."""
    rectified = torch.relu(x.float())
    grad = 2.0 * g * rectified
    return grad.to(dtype=x.dtype)
class SquaredReLUFunction(torch.autograd.Function):
    """Autograd wrapper pairing ``sqrelu_fwd`` with ``sqrelu_bwd``."""

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, input):
        """Save ``input`` for the backward pass and return ``sqrelu_fwd(input)``."""
        ctx.save_for_backward(input)
        return sqrelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input`` computed by ``sqrelu_bwd``."""
        input, = ctx.saved_tensors
        return sqrelu_bwd(grad_output, input)


sqrelu = SquaredReLUFunction.apply
def swiglu_fwd_kernel(
    x, y, z,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    stride_z_row,
    B: tl.constexpr,
):
    """Element-wise SwiGLU: ``z = x * s * y`` with ``s = 1 / (1 + exp(-x))``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)`` in each
    buffer. The math is done in float32 and cast to the element type of ``z``.

    NOTE(review): ``B`` is not passed by the launcher; presumably supplied by
    autotune decorators that are not visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    gate = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    up = tl.load(y + (r * stride_y_row + c), mask=in_bounds, other=0.).to(tl.float32)
    sig = 1.0 / (1.0 + exp(-gate))
    out = gate * sig * up
    tl.store(z + (r * stride_z_row + c), out.to(z.dtype.element_ty), mask=in_bounds)
def swiglu_fwdbwd_kernel(
    x, y, g, dx, dy, z,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    stride_g_row,
    stride_dx_row,
    stride_dy_row,
    stride_z_row,
    B: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
):
    """SwiGLU backward, optionally also recomputing the forward output.

    With ``s = 1 / (1 + exp(-x))`` it writes
    ``dx = g * s * (1 + x * (1 - s)) * y`` and ``dy = g * x * s``; when
    ``HAS_WEIGHT`` is set it additionally writes ``z = x * s * y``.

    Each program handles ``B`` consecutive flat element indices below ``T``;
    index ``i`` is addressed as ``(i // D) * stride_*_row + (i % D)`` in each
    buffer. The math is done in float32 and cast to each output's element type.

    NOTE(review): neither ``B`` nor ``HAS_WEIGHT`` is passed by the launcher;
    presumably both come from autotune/heuristics decorators that are not
    visible here -- confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * B + tl.arange(0, B)
    in_bounds = idx < T
    r = idx // D
    c = idx % D

    gate = tl.load(x + (r * stride_x_row + c), mask=in_bounds, other=0.).to(tl.float32)
    up = tl.load(y + (r * stride_y_row + c), mask=in_bounds, other=0.).to(tl.float32)
    grad_out = tl.load(g + (r * stride_g_row + c), mask=in_bounds, other=0.).to(tl.float32)

    sig = 1.0 / (1.0 + exp(-gate))
    swish_gate = gate * sig
    grad_gate = grad_out * sig * (1.0 + gate * (1.0 - sig)) * up
    grad_up = grad_out * swish_gate
    tl.store(dx + (r * stride_dx_row + c), grad_gate.to(dx.dtype.element_ty), mask=in_bounds)
    tl.store(dy + (r * stride_dy_row + c), grad_up.to(dy.dtype.element_ty), mask=in_bounds)

    if HAS_WEIGHT:
        # z is only dereferenced on this branch.
        recomputed = swish_gate * up
        tl.store(z + (r * stride_z_row + c), recomputed.to(z.dtype.element_ty), mask=in_bounds)
def swiglu_fwd(x: torch.Tensor, y: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swiglu_fwd_kernel`` on ``x`` and ``y`` and return the result.

    Args:
        x: Gate input; made inner-contiguous if necessary.
        y: Second input, same shape as ``x``; made inner-contiguous if necessary.
        output_contiguous: Forwarded to ``_alloc_output`` to choose the output layout.

    Returns:
        A new tensor ``z`` with the shape of ``x``.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape, f"swiglu_fwd: shape mismatch x={x.shape} y={y.shape}"
    x = _ensure_inner_contiguous(x)
    y = _ensure_inner_contiguous(y)
    numel = x.numel()
    z = _alloc_output(x, output_contiguous)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    swiglu_fwd_kernel[grid](
        x, y, z,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
        stride_z_row=_get_stride(z),
    )
    return z
def swiglu_fwdbwd(
    x: torch.Tensor,
    y: torch.Tensor,
    g: torch.Tensor,
    use_weight: bool = False,
    output_contiguous: bool = False,
):
    """Launch ``swiglu_fwdbwd_kernel`` and return the input gradients.

    Args:
        x: Gate input of the forward pass.
        y: Second input of the forward pass, same shape as ``x``.
        g: Upstream gradient, same shape as ``x``.
        use_weight: When True, a ``z`` buffer is also allocated, passed to the
            kernel and returned.
        output_contiguous: Forwarded to ``_alloc_output`` for every output buffer.

    Returns:
        ``(dx, dy)``, or ``(dx, dy, z)`` when ``use_weight`` is True.

    Raises:
        AssertionError: If the three inputs do not share one shape.
    """
    assert x.shape == y.shape == g.shape, f"swiglu_fwdbwd: shape mismatch x={x.shape} y={y.shape} g={g.shape}"
    x = _ensure_inner_contiguous(x)
    y = _ensure_inner_contiguous(y)
    g = _ensure_inner_contiguous(g)
    numel = x.numel()

    dx = _alloc_output(x, output_contiguous)
    dy = _alloc_output(y, output_contiguous)
    z = _alloc_output(x, output_contiguous) if use_weight else None
    # With no z buffer the stride is a placeholder.
    z_row_stride = 0 if z is None else _get_stride(z)

    def grid(meta):
        # One program per block of meta['B'] flat elements.
        return (triton.cdiv(numel, meta['B']),)

    swiglu_fwdbwd_kernel[grid](
        x, y, g, dx, dy, z,
        T=numel,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
        stride_g_row=_get_stride(g),
        stride_dx_row=_get_stride(dx),
        stride_dy_row=_get_stride(dy),
        stride_z_row=z_row_stride,
    )
    if not use_weight:
        return dx, dy
    return dx, dy, z
class SwiGLUFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function.

    .. math::
        \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y

    NOTE(review): ``input_guard`` is imported at the top of the file but not
    applied anywhere visible; confirm whether it should decorate these methods.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x, y):
        """Save both inputs for the backward pass and return ``swiglu_fwd(x, y)``."""
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        """Return ``(dx, dy)`` as computed by ``swiglu_fwdbwd`` from the saved inputs."""
        x, y = ctx.saved_tensors
        return swiglu_fwdbwd(x, y, dout)
class SwiGLULinearFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.

    .. math::
        \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b

    This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory.

    NOTE(review): ``autocast_custom_fwd`` / ``autocast_custom_bwd`` and
    ``input_guard`` are imported at the top of the file but not applied anywhere
    visible; confirm whether they should decorate these methods.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x, y, weight, bias):
        """Return ``F.linear(swiglu(x, y), weight, bias)``.

        Only ``x``, ``y`` and ``weight`` are saved; the SwiGLU output is
        recomputed in ``backward``.
        """
        z = swiglu_fwd(x, y, output_contiguous=True)
        out = F.linear(z, weight, bias)
        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients for ``(x, y, weight, bias)``; the bias gradient is None when no bias was given."""
        x, y, weight = ctx.saved_tensors
        # Flatten all leading dims so the linear layer is handled as one 2D matmul.
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient w.r.t. the SwiGLU output: dout @ weight, viewed with x's shape.
        dz = F.linear(dout, weight.t()).view_as(x)
        # One kernel launch yields dx, dy and the recomputed SwiGLU output z.
        dx, dy, z = swiglu_fwdbwd(x, y, dz, use_weight=True, output_contiguous=True)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias
# Functional entry points: call these rather than the Function classes directly.
swiglu = SwiGLUFunction.apply
swiglu_linear = SwiGLULinearFunction.apply
# Lookup table from activation name to callable.
# 'silu' and 'swish' are aliases for the same function; 'gelu' is the
# tanh-approximate variant, and 'bias_gelu' takes an extra bias argument.
ACT2FN = {
    'relu': F.relu,
    'sigmoid': sigmoid,
    'logsigmoid': logsigmoid,
    'silu': swish,
    'swish': swish,
    'sqrelu': sqrelu,
    'gelu': fast_gelu_impl,
    'bias_gelu': bias_gelu_impl,
}