Ev3Dev's picture
Upload folder using huggingface_hub
ad39f2a verified
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import torch
import torch.nn.functional as F
import os
import shutil
import sys
import importlib.util
from typing import Optional, Tuple
from torch.autograd import Function
# Get compile location
# Default directory name for Unsloth's compiled-module cache.
# Read once at import time; _get_compile_location() re-reads the env var
# at call time so later changes to UNSLOTH_COMPILE_LOCATION still win.
UNSLOTH_COMPILE_LOCATION = os.environ.get(
    "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache"
)
def _get_compile_location() -> str:
    """Return the absolute path of the compiled-module cache directory.

    Re-reads UNSLOTH_COMPILE_LOCATION on every call so runtime changes to
    the environment take effect; falls back to the import-time default.
    """
    location = os.environ.get("UNSLOTH_COMPILE_LOCATION", UNSLOTH_COMPILE_LOCATION)
    return os.path.abspath(location)
def _log_info(message: str):
if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1":
print(message)
def install_to_cache(source_path, destination_filename=None):
    """
    Copies a file to the unsloth_compiled_cache directory
    to ensure it is available for compiled modules.

    Args:
        source_path: Path of the file to copy into the cache.
        destination_filename: Optional name for the copy; defaults to the
            source file's basename.

    Best-effort by design: directory-creation and copy failures are ignored
    so importing this module never fails because of cache problems.
    """
    compile_location = _get_compile_location()
    try:
        # exist_ok=True is race-free and replaces the previous
        # bare `except:` (which also swallowed KeyboardInterrupt/SystemExit).
        os.makedirs(compile_location, exist_ok=True)
    except OSError:
        pass
    current_file = os.path.abspath(source_path)
    if destination_filename is None:
        destination_filename = os.path.basename(current_file)
    destination = os.path.abspath(os.path.join(compile_location, destination_filename))
    # Only copy when source and destination actually differ.
    if current_file != destination:
        try:
            shutil.copy(current_file, destination)
        except Exception:
            pass
# Install this file into the compile cache at import time so compiled
# modules can load it under the fixed name "moe_utils.py".
install_to_cache(__file__, "moe_utils.py")
# Memoized handles to the compiled-cache copy of this module (None until
# resolved by _load_cached_moe_utils_module / get_forward_moe_backend).
_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None
def _load_cached_moe_utils_module():
    """Import the compiled-cache copy of this module, if one exists.

    Returns the cached module object, or None when the cache file is missing,
    is this very file, or fails to import. Best-effort: all import errors are
    swallowed and reported as None.
    """
    global _CACHED_MOE_UTILS_MODULE
    cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils.py"))
    current_file = os.path.abspath(__file__)
    # Never "load" ourselves as the cache copy.
    if not os.path.isfile(cache_file) or cache_file == current_file:
        return None
    try:
        module_name = "unsloth_cached_moe_utils"
        # Reuse an already-imported copy only if it came from the same path.
        module = sys.modules.get(module_name, None)
        if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
            _CACHED_MOE_UTILS_MODULE = module
            return module
        spec = importlib.util.spec_from_file_location(module_name, cache_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        # Register in sys.modules before exec so imports during exec_module resolve.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        _CACHED_MOE_UTILS_MODULE = module
        return module
    except Exception:
        return None
def get_forward_moe_backend():
    """
    Resolve forward_moe_backend from the compiled cache copy when available.
    Falls back to this module's own definition otherwise.
    """
    global _CACHED_FORWARD_MOE_BACKEND
    cached_module = _load_cached_moe_utils_module()
    if cached_module is not None and hasattr(cached_module, "forward_moe_backend"):
        # Prefer the compiled-cache copy of the dispatcher.
        _CACHED_FORWARD_MOE_BACKEND = cached_module.forward_moe_backend
    else:
        # No usable cache copy: use the local implementation.
        _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend
    return _CACHED_FORWARD_MOE_BACKEND
# ============================================================================
# Grouped MM wrapper
# ============================================================================
# Simple pass-through wrapper around torch._grouped_mm; contiguity of the
# inputs is the caller's responsibility (the wrapper does not enforce it).
# Native backward works correctly - no custom autograd needed.
# ============================================================================
def _grouped_mm_with_backward_fix(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Grouped matmul dispatching straight to native torch._grouped_mm.

    Gradients flow through the native backward. Note: despite the name, this
    function performs no contiguity fix itself — call sites invoke
    .contiguous() on weights where grouped_mm's alignment rules require it.
    """
    return torch._grouped_mm(inputs, weight, offs=offsets)
# Global flag to check if grouped GEMM is available
# Tri-state: None = not yet probed, True/False = cached probe result.
_GROUPED_GEMM_AVAILABLE = None
# Whether the torch build even exposes the _grouped_mm symbol.
_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")
# Check if GPU supports torch._grouped_mm (verified via runtime check)
# Tri-state: None = not yet probed, True/False = cached probe result.
_TORCH_GROUPED_MM_SUPPORTED = None
def _check_torch_grouped_mm_supported():
    """
    Check if torch._grouped_mm is actually supported on the current GPU.

    The symbol may exist while the hardware path is unsupported (e.g. < H100),
    so a tiny runtime probe is the only reliable check. The answer is cached
    in the module-level _TORCH_GROUPED_MM_SUPPORTED tri-state.
    """
    global _TORCH_GROUPED_MM_SUPPORTED
    # Return the memoized answer once the probe has run.
    if _TORCH_GROUPED_MM_SUPPORTED is not None:
        return _TORCH_GROUPED_MM_SUPPORTED
    # No symbol or no GPU means no support, full stop.
    if not _TORCH_GROUPED_MM_AVAILABLE or not torch.cuda.is_available():
        _TORCH_GROUPED_MM_SUPPORTED = False
        return False
    try:
        # Probe with minimal dummy data: 1 expert, 1 token, dim 8 keeps
        # alignment requirements satisfied. Succeeding here also covers
        # backports and newer hardware without code changes.
        device = torch.cuda.current_device()
        probe_dtype = torch.float16
        probe_x = torch.ones((1, 8), device=device, dtype=probe_dtype)
        probe_w = torch.ones((1, 8, 8), device=device, dtype=probe_dtype)
        probe_offs = torch.tensor([1], device=device, dtype=torch.int32)
        torch._grouped_mm(probe_x, probe_w, offs=probe_offs)
        del probe_x, probe_w, probe_offs
        _TORCH_GROUPED_MM_SUPPORTED = True
    except Exception:
        _TORCH_GROUPED_MM_SUPPORTED = False
    return _TORCH_GROUPED_MM_SUPPORTED
# One-shot guard so the allocator is only installed once per process.
_TRITON_ALLOCATOR_INITIALIZED = False
# Grow-on-demand uint8 scratch tensor handed out to Triton kernels.
_PERSISTENT_BUFFER = None
def _init_triton_allocator():
    """
    Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
    This significantly reduces GPU utilization fluctuation.

    Best-effort: any failure (Triton missing, or an older Triton without
    set_allocator) is swallowed and the default allocator stays in place.
    """
    global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
    if _TRITON_ALLOCATOR_INITIALIZED: return
    try:
        import triton
        # Create a persistent buffer that grows as needed
        # This avoids allocating new memory on every kernel call
        def persistent_alloc_fn(size: int, alignment: int, stream):
            # NOTE(review): `alignment` and `stream` are unused; assumes the
            # 128-byte rounding below satisfies Triton's alignment needs — confirm.
            global _PERSISTENT_BUFFER
            # Round up size to avoid frequent reallocations
            # Round to nearest 128 bytes for alignment
            rounded_size = ((size + 128 - 1) // 128) * 128
            if (
                _PERSISTENT_BUFFER is None
                or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size()
                < rounded_size
            ):
                # Allocate with small headroom (10%) to reduce reallocations
                # Use ByteTensor (uint8) for raw byte storage
                _PERSISTENT_BUFFER = torch.empty(
                    int(rounded_size * 1.1), device="cuda", dtype=torch.uint8
                )
                # Presumably marks the buffer as ignorable for hibernation /
                # serialization tooling — TODO confirm consumer of __hibernate__.
                _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
            return _PERSISTENT_BUFFER
        triton.set_allocator(persistent_alloc_fn)
        # Marker so other code can detect that our allocator is installed.
        triton._unsloth_allocator_set = True
        _TRITON_ALLOCATOR_INITIALIZED = True
    except Exception:
        pass
def _check_grouped_gemm_available():
    """Check if Unsloth grouped GEMM kernels are available (cached probe)."""
    # Explicit opt-out via environment variable wins over everything.
    if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False
    global _GROUPED_GEMM_AVAILABLE
    # Memoized after the first import probe.
    if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE
    try:
        from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma
    except (ImportError, ModuleNotFoundError):
        _GROUPED_GEMM_AVAILABLE = False
    else:
        _GROUPED_GEMM_AVAILABLE = True
        # Kernels confirmed: set up the persistent Triton allocator once.
        _init_triton_allocator()
    return _GROUPED_GEMM_AVAILABLE
from functools import lru_cache
@lru_cache(maxsize=1)
def select_moe_backend():
    """
    Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability.
    Choices: "grouped_mm", "unsloth_triton", "native_torch".
    Default if unspecified: "grouped_mm".

    The result is memoized for the process lifetime (lru_cache), so the
    environment variable is consulted only on the first call.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    requested = os.environ.get("UNSLOTH_MOE_BACKEND")
    if requested:
        # Honor an explicit request only when that backend is usable here.
        if requested == "native_torch":
            return "native_torch"
        if requested == "grouped_mm" and _check_torch_grouped_mm_supported():
            return "grouped_mm"
        if requested == "unsloth_triton" and _check_grouped_gemm_available():
            return "unsloth_triton"
        _log_info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")
    # Preference order: torch grouped_mm > Unsloth Triton kernels > native loop.
    if _check_torch_grouped_mm_supported():
        _log_info("Unsloth: Using MoE backend 'grouped_mm'")
        return "grouped_mm"
    if _check_grouped_gemm_available():
        _log_info("Unsloth: Using MoE backend 'unsloth_triton'")
        return "unsloth_triton"
    return "native_torch"
def forward_moe_backend(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Dispatch MoE forward to the selected backend.
    Centralizes backend selection to keep model-specific patches minimal.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    backend = select_moe_backend()
    if backend == "grouped_mm":
        forward_fn = forward_native_grouped_mm
    elif backend == "unsloth_triton":
        forward_fn = forward_triton_grouped_gemm
    else:
        # "native_torch": plain per-expert loop fallback.
        forward_fn = forward_native_moe_loop
    return forward_fn(self, hidden_states, top_k_index, top_k_weights)
@torch.no_grad()
def _get_routing_indices(selected_experts, num_experts):
"""
Compute token→expert mapping for grouped GEMM.
Uses bincount instead of histc to avoid float conversion overhead.
Returns:
token_counts_by_expert: (num_experts,) token counts per expert
gather_indices: (total_tokens,) indices for gathering tokens in expert order
"""
# This Unsloth Zoo code section is licensed under AGPL3
flat_experts = selected_experts.view(-1)
# bincount is faster than histc since it doesn't require float conversion
token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)
# argsort with stable=True preserves order within each expert
gather_indices = flat_experts.argsort(stable=True)
return token_counts_by_expert, gather_indices
def _silu_and_mul(x):
"""Fused SiLU activation and element-wise multiply for gate/up projections."""
gate, up = x.chunk(2, dim=-1)
return F.silu(gate) * up
# ============================================================================
# Separated LoRA Helper Functions
# ============================================================================
def _has_lora_adapters(param) -> bool:
"""Check if parameter has active LoRA adapters (PEFT ParamWrapper)."""
# Check if this is a PEFT LoRA wrapper
if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"):
return False
if hasattr(param, "disable_adapters") and param.disable_adapters:
return False
if hasattr(param, "merged") and param.merged:
return False
return len(param.lora_A) > 0
def _extract_lora_from_wrapper(
wrapper, adapter_name: str = "default", experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
"""
Extract LoRA weights from PEFT ParamWrapper for MoE separated computation.
PEFT ParamWrapper for 3D parameters creates:
- lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim)
- lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R)
For grouped_mm: X @ first_weight @ second_weight
STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear
gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I)
down_proj: (E, H, I) - input X is (N, I), output is (N, H)
For gate_up with (E, 2*I, H):
lora_A: (E*R, H), lora_B: (2*I, E*R)
Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
first_weight from lora_A: (E*R, H) -> (E, H, R) after view/permute
second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) after view/permute
TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm
gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I)
down_proj: (E, I, H) - input X is (N, I), output is (N, H)
For gate_up with (E, H, 2*I):
lora_A: (E*R, H), lora_B: (2*I, E*R)
Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
first_weight from lora_A: (E*R, H) -> (E, H, R)
second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I)
Returns:
(first_weight, second_weight, scaling, num_experts) or None
"""
# This Unsloth Zoo code section is licensed under AGPL3
try:
if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"):
return None
if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters:
return None
if hasattr(wrapper, "merged") and wrapper.merged:
return None
if not wrapper.lora_A:
return None
if adapter_name not in wrapper.lora_A:
adapter_name = list(wrapper.lora_A.keys())[0]
lora_A_module = wrapper.lora_A[adapter_name]
lora_B_module = wrapper.lora_B[adapter_name]
weight_A = lora_A_module.weight # (E*R, dim1)
weight_B = lora_B_module.weight # (dim2, E*R)
scaling = wrapper.scaling[adapter_name]
num_experts = getattr(wrapper, "num_experts", 1)
# GET EXPERTS MODULE TO CHECK FOR REGISTERED EXTRACTOR
if experts_module is None:
experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
# Check for model-specific LoRA extractor attached to the experts module
extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None)
if extractor_fn is not None:
return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)
# DEFAULT BEHAVIOR (Standard Format / Non-MoE)
if num_experts > 1:
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim1 = weight_A.shape[1]
dim2 = weight_B.shape[0]
# STANDARD FORMAT (Qwen3-MoE / GLM4):
# Base weights are (E, out_dim, in_dim) for F.linear.
# LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R).
# We need X @ (E, in_dim, R) @ (E, R, out_dim).
# first_weight: (E, in_dim, R) - from lora_A
# second_weight: (E, R, out_dim) - from lora_B
first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R)
# second_weight (B): (E, R, out_dim)
second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2)
else:
# Non-MoE case: return weights for X @ A.T @ B.T
first_weight = weight_A.T # (dim1, R)
second_weight = weight_B.T # (R, dim2)
return first_weight, second_weight, scaling, num_experts
except Exception:
return None
def _extract_lora_weights(
    param, adapter_name: str = "default", num_experts: int = None, experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
    """
    Extract LoRA A and B weights from PEFT ParamWrapper.

    Compatibility shim over _extract_lora_from_wrapper that drops the trailing
    num_experts element; prefer the full helper in new code.

    Returns:
        (first_weight, second_weight, scaling) for (X @ first) @ second, or None.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Stash num_experts on the wrapper so the underlying extractor can reshape.
    if num_experts is not None and not hasattr(param, "num_experts"):
        param.num_experts = num_experts
    extracted = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module)
    if extracted is None:
        return None
    # Keep only (first_weight, second_weight, scaling).
    return tuple(extracted[:3])
def _get_base_weight(param):
"""Get base weight from potentially wrapped parameter or module."""
# This Unsloth Zoo code section is licensed under AGPL3
# Recursively unwrap PEFT layers
while hasattr(param, "base_layer"):
param = param.base_layer
if hasattr(param, "get_param"):
return param.get_param()
# Handle Modules (Linear, etc.)
if hasattr(param, "weight"):
return param.weight
return param
def _get_lora_wrapper_for_param(experts_module, param_name):
"""
Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
Uses the explicit key stored in __dict__ if available.
Does NOT lazily setup wrappers as that requires traversing logic not present here.
"""
# This Unsloth Zoo code section is licensed under AGPL3
if hasattr(experts_module, f"{param_name}_lora_wrapper"):
return getattr(experts_module, f"{param_name}_lora_wrapper")
# Check simple attributes if it's directly wrapped
if hasattr(experts_module, param_name):
attr = getattr(experts_module, param_name)
if hasattr(attr, "lora_A"): # Is a ParamWrapper
return attr
return None
def native_moe_grouped_mm(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Native grouped-matmul entry point.

    Delegates to _grouped_mm_with_backward_fix so there is a single seam to
    swap should the grouped_mm backward path ever need patching.
    """
    result = _grouped_mm_with_backward_fix(inputs, weight, offsets)
    return result
def _apply_lora_grouped_mm(
    inputs: torch.Tensor,
    lora_B: torch.Tensor,
    lora_A: torch.Tensor,
    offsets: torch.Tensor,
    scaling: float,
    grouped_mm_func=native_moe_grouped_mm,
) -> torch.Tensor:
    """
    Apply LoRA using grouped GEMM: result = ((X @ B) @ A) * scaling

    Args:
        inputs: (total_tokens, in_dim)
        lora_B: (num_experts, in_dim, rank) - first projection
        lora_A: (num_experts, rank, out_dim) - second projection
        offsets: grouped GEMM offsets (cumulative token counts per expert)
        scaling: LoRA scaling factor
        grouped_mm_func: grouped GEMM implementation (default: native_moe_grouped_mm)
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Both weights are already (E, in, out)-shaped for grouped_mm; only
    # contiguity must be enforced before each call.
    intermediate = grouped_mm_func(inputs, lora_B.contiguous(), offsets)
    delta = grouped_mm_func(intermediate, lora_A.contiguous(), offsets)
    return delta * scaling
def _should_use_separated_lora() -> bool:
"""
Check if separated LoRA approach should be used (default: True).
Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead.
"""
return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1"
# ============================================================================
# Model-specific Weight Preprocessing Hooks
# ============================================================================
# Each model can register its own preprocessing function for weight transposition.
# This allows the generic backend to work with different model weight layouts.
# Maps model_type string -> preprocessor_fn(weight, proj_type, hidden_dim).
_WEIGHT_PREPROCESSORS = {}
def register_weight_preprocessor(model_type: str, preprocessor_fn):
    """
    Register a weight preprocessor for a specific model type.

    Args:
        model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe")
        preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight
            proj_type is "gate_up" or "down"

    A later registration for the same model_type overwrites the earlier one.
    """
    _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn
def get_weight_preprocessor(model_type: str):
    """Get registered weight preprocessor for model type, or None when absent."""
    return _WEIGHT_PREPROCESSORS.get(model_type)
def preprocess_weight(
    weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
):
    """
    Preprocess weight tensor for grouped_mm compatibility.

    Uses a registered model-specific preprocessor when available; otherwise a
    shape heuristic decides whether a transpose is needed.

    Args:
        weight: Weight tensor (E, dim1, dim2) or similar
        proj_type: "gate_up" or "down"
        hidden_dim: Hidden dimension for shape inference
        model_type: Optional model type selecting a registered preprocessor

    Returns:
        Weight tensor in (E, in_dim, out_dim) format for grouped_mm
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    if model_type and model_type in _WEIGHT_PREPROCESSORS:
        return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim)
    # Default heuristic: gate_up wants hidden_dim on axis 1 (input side),
    # down wants hidden_dim on axis 2 (output side); transpose otherwise.
    expected_axis = 1 if proj_type == "gate_up" else 2
    if weight.shape[expected_axis] == hidden_dim:
        return weight
    return weight.transpose(-2, -1)
# ============================================================================
# Generic MoE Detection and ParamWrapper Patching
# ============================================================================
def _is_moe_experts_module(module) -> bool:
"""
Check if module is an MoE experts layer (generic, not model-specific).
Detects modules with stacked expert weights as 3D nn.Parameter:
- gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.)
- w1/w2/w3 pattern (older MoE models)
"""
# This Unsloth Zoo code section is licensed under AGPL3
import torch.nn as nn
# Check for gate_up_proj pattern
# After PEFT's nn.utils.parametrize wrapping, accessing gate_up_proj
# returns torch.Tensor (not nn.Parameter), so we must accept both.
if hasattr(module, "gate_up_proj"):
param = module.gate_up_proj
# 4-bit parameters are packed into 2D tensors (n_params, 1) or similar.
# Standard MoE weights are 3D (num_experts, in, out).
if isinstance(param, (nn.Parameter, torch.Tensor)) and param.ndim in (2, 3):
return True
# Check for w1/w2 pattern (separate gate/up projections)
if hasattr(module, "w1") and hasattr(module, "w2"):
w1 = module.w1
if isinstance(w1, (nn.Parameter, torch.Tensor)) and w1.ndim in (2, 3):
return True
return False
# Aliases for compatibility with gpt_oss.py
_get_moe_lora_weights = _extract_lora_from_wrapper
# Store original ParamWrapper.forward for fallback
# (populated lazily by patch_param_wrapper_for_moe).
_original_param_wrapper_forward = None
def _patched_param_wrapper_forward(
    self, x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Patched ParamWrapper.forward for MoE separated LoRA.

    For MoE expert modules:
    - Bypasses PEFTs _activate_lora parametrization context
    - Stores LoRA data by parameter_name for forward_native_grouped_mm to use
    For non-MoE modules:
    - Falls back to original PEFT forward
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # CRITICAL: Use self.base_layer for forward call (immediate parent)
    # NOT self.get_base_layer() which recursively traverses to deepest layer!
    # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts
    immediate_base_layer = self.base_layer
    # For storing LoRA data, we DO need the actual experts module
    # Use get_base_layer() to find it (recursive traversal is correct here)
    experts_module = self.get_base_layer()
    use_separated = _should_use_separated_lora()
    param_name = getattr(self, "parameter_name", None)
    # Check if this is an MoE experts module that should use separated LoRA
    if (
        use_separated
        and param_name in ("gate_up_proj", "down_proj")
        and _is_moe_experts_module(experts_module)
    ):
        # MoE experts: bypass PEFT's _activate_lora, use separated computation
        # Check adapter state first: disabled/merged adapters skip injection.
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return immediate_base_layer(x, *args, **kwargs)
        if self.merged:
            return immediate_base_layer(x, *args, **kwargs)
        # Ensure wrapper.num_experts is set for LoRA weight reshaping
        if not hasattr(self, "num_experts"):
            if hasattr(experts_module, "num_experts"):
                self.num_experts = experts_module.num_experts
            elif hasattr(experts_module, param_name):
                p = getattr(experts_module, param_name)
                if hasattr(p, "shape") and len(p.shape) >= 1:
                    # Leading dim of the stacked expert weight is the expert count.
                    self.num_experts = p.shape[0]
        # Extract LoRA for this specific parameter
        lora_data = _extract_lora_from_wrapper(self)
        if lora_data is not None and param_name:
            # Store LoRA data on the EXPERTS MODULE (not base_layer)
            # e.g., _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
            lora_attr = f"_unsloth_lora_{param_name}"
            setattr(experts_module, lora_attr, lora_data)
        try:
            # Call IMMEDIATE base_layer to preserve wrapper chain
            # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts)
            result = immediate_base_layer(x, *args, **kwargs)
        finally:
            # Clean up even on error so stale LoRA data never leaks into a
            # later forward pass.
            if param_name:
                lora_attr = f"_unsloth_lora_{param_name}"
                if hasattr(experts_module, lora_attr):
                    delattr(experts_module, lora_attr)
        return result
    # Non-MoE: use original PEFT forward with _activate_lora
    return _original_param_wrapper_forward(self, x, *args, **kwargs)
def patch_param_wrapper_for_moe():
    """
    Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE.
    This should be called after PEFT is imported. Returns True on success,
    False when PEFT is unavailable.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    global _original_param_wrapper_forward
    # Prefer the compiled-cache copy of this patcher when one is installed.
    cached = _load_cached_moe_utils_module()
    if cached is not None and hasattr(cached, "patch_param_wrapper_for_moe"):
        try:
            return cached.patch_param_wrapper_for_moe()
        except Exception:
            pass
    try:
        from peft.tuners.lora.layer import ParamWrapper
    except ImportError:
        return False
    # Keep the pristine forward exactly once so non-MoE calls can fall back.
    if _original_param_wrapper_forward is None:
        _original_param_wrapper_forward = ParamWrapper.forward
    ParamWrapper.forward = _patched_param_wrapper_forward
    return True
def forward_native_grouped_mm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Native Pytorch grouped GEMM MoE forward pass.
    Uses torch._grouped_mm which is significantly faster than loop and works without Triton dependencies.
    Requires torch._grouped_mm support (verified via runtime check).

    Args:
        hidden_states: (batch, seq, hidden) or pre-flattened (tokens, hidden).
        top_k_index: per-token expert indices (flattened internally).
        top_k_weights: routing weights aligned with top_k_index.
    Returns:
        Tensor with the same leading shape as hidden_states.
    Raises:
        RuntimeError: when torch._grouped_mm is unsupported on this device.
        AttributeError: when the module exposes neither gate_up_proj/w1+w3
            nor down_proj/w2 parameters.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Runtime safety check - defense in depth
    if not _check_torch_grouped_mm_supported():
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        raise RuntimeError(
            f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). "
            f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend."
        )
    is_2d_input = hidden_states.dim() == 2
    if is_2d_input:
        sequence_length, hidden_dim = hidden_states.shape
        batch_size = 1
    else:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
    # 1. Calculate routing
    flat_top_k = top_k_index.view(-1)
    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()
    # 2. Sort indices to group tokens by expert (stable keeps original order per expert)
    sorted_indices = torch.argsort(flat_top_k, stable=True)
    # Integer-divide by top_k to map each routing slot back to its source token.
    token_indices = sorted_indices // top_k_index.shape[-1]
    # 3. Permute Input
    # We need to gather inputs. Since we may have expanded top_k, we use token_indices to map back to original input
    permuted_input = hidden_states[token_indices]
    # 4. Prepare Grouped MM arguments
    # grouped_mm expects int32 cumulative per-expert token offsets.
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    # ========================================================================
    # Gate + Up projection with optional separated LoRA (DEFAULT)
    # ========================================================================
    use_separated_lora = _should_use_separated_lora()
    gate_up_lora = None
    # Check for injected LoRA data from patched ParamWrapper (preferred path)
    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
        gate_up_lora = self._unsloth_lora_gate_up_proj[
            :3
        ]  # (first_weight, second_weight, scaling)
    # Fallback: check parameter directly (for older wrapping patterns)
    elif (
        use_separated_lora
        and hasattr(self, "gate_up_proj")
        and _has_lora_adapters(self.gate_up_proj)
    ):
        gate_up_lora = _extract_lora_weights(
            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
        )
    if hasattr(self, "gate_up_proj"):
        # Get base weights (raw, without LoRA)
        gate_up_base = _get_base_weight(self.gate_up_proj)
        # Get model type for preprocessing (if registered)
        model_type = getattr(self, "_unsloth_model_type", None)
        # Handle different weight shapes using preprocessor
        # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view.
        w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type)
        # Base forward: X @ W
        mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
        # Add separated LoRA contribution: + ((X @ first) @ second) * scaling
        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
        if gate_up_lora is not None:
            first_weight, second_weight, scaling = gate_up_lora
            # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
            # Ensure contiguous for grouped_mm alignment requirements
            first_weight = first_weight.to(permuted_input.dtype).contiguous()
            second_weight = second_weight.to(permuted_input.dtype).contiguous()
            # Step 1: permuted_input @ first_weight
            # NOTE(review): this try/except only re-raises; it has no effect
            # beyond documentation of a possible failure point.
            try:
                lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)
                lora_out = lora_out.contiguous()
            except RuntimeError as e:
                raise e
            # Step 2: result @ second_weight
            # Handle unaligned O dimension or other grouped_mm failures
            try:
                if second_weight.shape[-1] % 8 != 0:
                    # Pad the output dim up to a multiple of 8 for grouped_mm,
                    # then slice the padding back off below.
                    pad_size = 8 - (second_weight.shape[-1] % 8)
                    second_weight_padded = F.pad(
                        second_weight, (0, pad_size)
                    ).contiguous()
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight_padded, offsets
                    )
                    lora_delta = lora_delta[:, :-pad_size]
                else:
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight, offsets
                    )
            except RuntimeError:
                # Fallback to manual loop if grouped_mm fails (e.g. stride alignment)
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                # CPU round-trip of offsets forces a device sync — acceptable
                # here since this is a rare error-recovery path.
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end
            # Add scaled LoRA contribution
            mm1_out = mm1_out + lora_delta * scaling
        if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
            # Expand per-expert bias rows to one row per routed token.
            num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device)
            bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0)
            mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)
        if "GptOssExperts" in self.__class__.__name__:
            # GptOss interleaves gate/up channels instead of concatenating halves.
            gate = mm1_out[..., ::2]
            up = mm1_out[..., 1::2]
        else:
            gate, up = mm1_out.chunk(2, dim=-1)
    elif hasattr(self, "w1") and hasattr(self, "w3"):
        # Separate w1/w3 weights (older models)
        w1_base = _get_base_weight(self.w1)
        w3_base = _get_base_weight(self.w3)
        w1 = w1_base.transpose(-2, -1)
        w3 = w3_base.transpose(-2, -1)
        gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
        up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets)
        # Add LoRA for w1 and w3 separately if present
        if use_separated_lora:
            if _has_lora_adapters(self.w1):
                w1_lora = _extract_lora_weights(self.w1, experts_module=self)
                if w1_lora is not None:
                    lora_A, lora_B, scaling = w1_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    gate = gate + lora_B_out * scaling
            if _has_lora_adapters(self.w3):
                w3_lora = _extract_lora_weights(self.w3, experts_module=self)
                if w3_lora is not None:
                    lora_A, lora_B, scaling = w3_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    up = up + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.")
    # Activation
    if "GptOssExperts" in self.__class__.__name__:
        # Custom activation from GptOss
        limit = getattr(self, "limit", 7.0)
        alpha = getattr(self, "alpha", 1.702)
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        inter = (up + 1.0) * glu
    else:
        # Standard SwiGLU: silu(gate) * up
        inter = F.silu(gate) * up
    # ========================================================================
    # Down projection with optional separated LoRA (DEFAULT)
    # ========================================================================
    down_lora = None
    # Check for injected LoRA data from patched ParamWrapper (preferred path)
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[
            :3
        ]  # (first_weight, second_weight, scaling)
    # Fallback: check parameter directly (for older wrapping patterns)
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self)
    if hasattr(self, "down_proj"):
        # Get base weights
        down_base = _get_base_weight(self.down_proj)
        # Get model type for preprocessing (if registered)
        model_type = getattr(self, "_unsloth_model_type", None)
        # Handle different weight shapes using preprocessor
        w2 = preprocess_weight(down_base, "down", hidden_dim, model_type)
        # Base forward
        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
        # Add separated LoRA contribution if present
        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
        if down_lora is not None:
            first_weight, second_weight, scaling = down_lora
            # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
            first_weight = first_weight.to(inter.dtype).contiguous()
            second_weight = second_weight.to(inter.dtype).contiguous()
            # Step 1: inter @ first_weight
            lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets)
            lora_out = lora_out.contiguous()
            # Step 2: result @ second_weight
            try:
                lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
            except RuntimeError:
                # Fallback to manual loop
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end
            # Add scaled LoRA contribution
            mm2_out = mm2_out + lora_delta * scaling
        if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
            bias_expanded = self.down_proj_bias.repeat_interleave(
                num_tokens_per_expert.to(self.down_proj_bias.device), dim=0
            ).to(mm2_out.device)
            mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)
    elif hasattr(self, "w2"):
        w2_base = _get_base_weight(self.w2)
        w2 = w2_base.transpose(-2, -1)
        # Base forward
        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
        # Add LoRA if present
        if use_separated_lora and _has_lora_adapters(self.w2):
            w2_lora = _extract_lora_weights(self.w2, experts_module=self)
            if w2_lora is not None:
                lora_A, lora_B, scaling = w2_lora
                lora_A_t = lora_A.transpose(-2, -1).contiguous()
                lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets)
                lora_B_t = lora_B.transpose(-2, -1).contiguous()
                lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                mm2_out = mm2_out + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'down_proj' or 'w2'.")
    # 5. Apply Routing Weights and Scatter Add (Reduce)
    flat_weights = top_k_weights.view(-1)
    permuted_weights = flat_weights[sorted_indices]
    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    # index_add_ accumulates each token's expert outputs back into its slot.
    final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype))
    if is_2d_input:
        return final_hidden_states
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
def forward_triton_grouped_gemm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Grouped GEMM MoE forward pass using Triton kernels.
    Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).

    Args:
        hidden_states: activations, either (batch, seq, hidden) or already
            flat (num_tokens, hidden); 3D inputs are flattened below.
        top_k_index: per-token expert assignments; flattened to
            (num_tokens, top_k) if given in 3D.
        top_k_weights: routing weights aligned with ``top_k_index``;
            flattened the same way.

    Returns:
        Weighted sum of expert outputs with the same leading shape as the
        input ``hidden_states``.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Import grouped GEMM interface (project-local Triton kernels)
    from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
    # Import autotune cache
    from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels
    # Autotuned kernel configs are memoised per-module on
    # ``self._unsloth_moe_configs`` so the (expensive) tuning pass runs
    # only on the first call for a given module instance.
    if not hasattr(self, "_unsloth_moe_configs"):
        self._unsloth_moe_configs = None
    use_separated_lora = _should_use_separated_lora()
    # Handle 3D inputs (batch_size, seq_len, hidden_dim): flatten tokens.
    is_3d = hidden_states.dim() == 3
    if is_3d:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = batch_size * seq_len
        # Also flatten top_k inputs if they are 3D
        if top_k_index.dim() == 3:
            top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
        if top_k_weights.dim() == 3:
            top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
    else:
        num_tokens, hidden_dim = hidden_states.shape
    top_k = top_k_index.shape[1]
    # Cache model dimensions and kernel configs on first call
    if self._unsloth_moe_configs is None:
        # NOTE(review): assumes gate_up_proj dim 1 holds the fused
        # gate+up output (2 * intermediate) — TODO confirm for weights
        # stored in the transposed orientation handled further below.
        intermediate_dim = self.gate_up_proj.shape[1] // 2
        # Autotune first GEMM (hidden -> fused gate/up)
        gemm1_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim * 2,
            top_k=top_k,
            dtype=hidden_states.dtype,
        )
        # Autotune second GEMM (intermediate -> hidden)
        gemm2_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=intermediate_dim,
            intermediate_dim=hidden_dim,  # Output dim for 2nd GEMM is hidden_dim
            top_k=top_k,
            dtype=hidden_states.dtype,
        )
        self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)
        # Clear autotuning memory overhead
        torch.cuda.empty_cache()
    # Unpack cached configs
    intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs
    # Unpack specific kernel configs (forward, backward-dX, backward-dW)
    fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
    fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs
    # Compute routing indices for grouped GEMM: per-expert token counts and
    # the gather order that groups tokens by expert.
    token_counts_by_expert, gather_indices = _get_routing_indices(
        top_k_index, self.num_experts
    )
    offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)
    # Normalise weight orientation: use as-is when the last dim already
    # matches hidden_dim, otherwise transpose (and make contiguous for the
    # Triton kernel).
    if self.gate_up_proj.shape[-1] == hidden_dim:
        w1 = self.gate_up_proj
    else:
        w1 = self.gate_up_proj.transpose(-2, -1).contiguous()
    # First grouped GEMM: gate_up projection (permutes X into expert order)
    first_gemm_output = grouped_gemm(
        X=hidden_states,
        W=w1,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=True,
        permute_y=False,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_1,
        kernel_config_bwd_dX=bwd_dX_config_1,
        kernel_config_bwd_dW=bwd_dW_config_1,
        is_first_gemm=True,
    )
    # Apply SiLU activation and multiply gate with up
    intermediate = _silu_and_mul(first_gemm_output)
    # Grouped GEMM 2: down projection
    # Prepare LoRA data: prefer pre-extracted weights cached on self,
    # otherwise extract from the down_proj adapters when enabled.
    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts)
    # Same orientation normalisation for the down-projection weight.
    if self.down_proj.shape[-1] == intermediate.shape[-1]:
        w2 = self.down_proj
    else:
        w2 = self.down_proj.transpose(-2, -1).contiguous()
    # Second grouped GEMM (un-permutes Y back to token order)
    second_gemm_output = grouped_gemm(
        X=intermediate,
        W=w2,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=False,
        permute_y=True,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_2,
        kernel_config_bwd_dX=bwd_dX_config_2,
        kernel_config_bwd_dW=bwd_dW_config_2,
        is_first_gemm=False,
    )
    # Add separated LoRA contribution for Down
    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora
        # Intermediate is already permuted from step 1.
        # Offsets are same.
        first_weight = first_weight.to(intermediate.dtype)
        second_weight = second_weight.to(intermediate.dtype)
        lora_delta = _apply_lora_grouped_mm(
            intermediate,
            first_weight,
            second_weight,
            offsets,
            scaling,
            grouped_mm_func=native_moe_grouped_mm
        )
        second_gemm_output = second_gemm_output + lora_delta
    # Apply routing weights and sum across top_k experts
    # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim)
    # Ensure top_k_weights matches dtype (can be float32 from softmax)
    top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
    final_hidden_states = (
        second_gemm_output.view(num_tokens, top_k, hidden_dim)
        * top_k_weights_casted[..., None]
    )
    final_hidden_states = final_hidden_states.sum(dim=1)
    if is_3d:
        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
    return final_hidden_states
@torch.compiler.disable
def forward_native_moe_loop(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Loop-based MoE forward pass. Loops over experts that have tokens routed to them.
Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow.
"""
# This Unsloth Zoo code section is licensed under AGPL3
final_hidden_states = torch.zeros_like(hidden_states)
# Create expert mask and find which experts have tokens
with torch.no_grad():
expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, n_tokens)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
# Only loop over experts that actually have tokens routed to them
for expert_idx_t in expert_hit:
expert_idx = expert_idx_t.item()
# Find which tokens are routed to this expert
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
# Gather only the tokens for this expert
current_state = hidden_states[token_idx]
# Compute gate_up projection for this expert only
# Handle 'gate_up_proj' or 'w1'/'w3'
if hasattr(self, "gate_up_proj"):
gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
2, dim=-1
)
else:
gate = F.linear(current_state, self.w1[expert_idx])
up = F.linear(current_state, self.w3[expert_idx])
current_hidden_states = self.act_fn(gate) * up
# Compute down projection for this expert only
if hasattr(self, "down_proj"):
current_hidden_states = F.linear(
current_hidden_states, self.down_proj[expert_idx]
)
else:
current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])
# Apply routing weights
current_hidden_states = (
current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
)
# Scatter back to final output
final_hidden_states.index_add_(
0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
)
return final_hidden_states