Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| # Unsloth Zoo - Utilities for Unsloth | |
| # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. | |
| # | |
| # This program is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU Affero General Public License as published | |
| # by the Free Software Foundation, either version 3 of the License, or | |
| # (at your option) any later version. | |
| # | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU Affero General Public License for more details. | |
| # | |
| # You should have received a copy of the GNU Affero General Public License | |
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| import shutil | |
| import sys | |
| import importlib.util | |
| from typing import Optional, Tuple | |
| from torch.autograd import Function | |
# Directory used for Unsloth's compiled/cached modules. The value is captured
# once at import time; the environment variable overrides the built-in default.
UNSLOTH_COMPILE_LOCATION = os.environ.get(
    "UNSLOTH_COMPILE_LOCATION",
    "unsloth_compiled_cache",
)


def _get_compile_location() -> str:
    """
    Return the absolute path of the compile cache directory.

    The environment variable is consulted again on every call, so a value set
    after import takes precedence over the module-level snapshot.
    """
    environment = os.environ
    if "UNSLOTH_COMPILE_LOCATION" in environment:
        configured = environment["UNSLOTH_COMPILE_LOCATION"]
    else:
        configured = UNSLOTH_COMPILE_LOCATION
    return os.path.abspath(configured)
def _log_info(message: str):
    """Print `message`, but only when UNSLOTH_ENABLE_LOGGING is set to "1"."""
    logging_enabled = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
    if not logging_enabled:
        return
    print(message)
def install_to_cache(source_path, destination_filename=None):
    """
    Copy a file into the Unsloth compile cache directory (best effort).

    Args:
        source_path: Path of the file to copy.
        destination_filename: Name to give the copy inside the cache
            directory. Defaults to the basename of `source_path`.

    Failures to create the directory or to copy the file are never raised;
    they are reported through `_log_info` and the function returns normally.
    """
    compile_location = _get_compile_location()
    try:
        # exist_ok avoids the check-then-create race when several processes
        # import this module at the same time.
        os.makedirs(compile_location, exist_ok=True)
    except OSError as error:
        _log_info(
            f"Unsloth: Could not create cache directory '{compile_location}': {error}"
        )

    current_file = os.path.abspath(source_path)
    if destination_filename is None:
        destination_filename = os.path.basename(current_file)
    destination = os.path.abspath(os.path.join(compile_location, destination_filename))

    # Nothing to do when the source already is the cached copy.
    if current_file == destination:
        return
    try:
        shutil.copy(current_file, destination)
    except Exception as error:
        # Best effort: a missing cache copy must not break the caller.
        _log_info(
            f"Unsloth: Could not copy '{current_file}' to '{destination}': {error}"
        )
# Mirror this file into the compile cache as soon as the module is imported.
install_to_cache(__file__, "moe_utils.py")

# Last resolved forward_moe_backend callable / cached module (None until set).
_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None


def _load_cached_moe_utils_module():
    """
    Import the cached copy of this file ("moe_utils.py" in the compile cache).

    The module is registered in `sys.modules` as "unsloth_cached_moe_utils"
    and reused on later calls as long as it was loaded from the same path.

    Returns:
        The module loaded from the cache, or None when the cache file is
        missing, is this very file, or cannot be imported.
    """
    global _CACHED_MOE_UTILS_MODULE
    cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils.py"))
    current_file = os.path.abspath(__file__)
    if not os.path.isfile(cache_file) or cache_file == current_file:
        return None

    module_name = "unsloth_cached_moe_utils"
    try:
        module = sys.modules.get(module_name, None)
        # `__file__` can be absent or None; treat both as "not our module".
        loaded_from = getattr(module, "__file__", None) if module is not None else None
        if loaded_from and os.path.abspath(loaded_from) == cache_file:
            _CACHED_MOE_UTILS_MODULE = module
            return module

        spec = importlib.util.spec_from_file_location(module_name, cache_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        # Registered before execution so the module can be found while it runs.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Never leave a half-initialised module registered: the next call
            # would otherwise return it as if it had loaded correctly.
            sys.modules.pop(module_name, None)
            raise
        _CACHED_MOE_UTILS_MODULE = module
        return module
    except Exception:
        return None
def get_forward_moe_backend():
    """
    Return the `forward_moe_backend` callable to use.

    The copy living in the compiled cache module is preferred; when that
    module cannot be loaded or lacks the function, the definition from this
    module is used. The choice is also stored in `_CACHED_FORWARD_MOE_BACKEND`.
    """
    global _CACHED_FORWARD_MOE_BACKEND
    cached_module = _load_cached_moe_utils_module()
    use_cached_copy = cached_module is not None and hasattr(
        cached_module, "forward_moe_backend"
    )
    if use_cached_copy:
        resolved = cached_module.forward_moe_backend
    else:
        resolved = forward_moe_backend
    _CACHED_FORWARD_MOE_BACKEND = resolved
    return resolved
| # ============================================================================ | |
| # Grouped MM wrapper | |
| # ============================================================================ | |
| # Thin wrapper around torch._grouped_mm; tensors are passed through unchanged (no contiguity enforcement here). | |
| # Native backward works correctly - no custom autograd needed. | |
| # ============================================================================ | |
def _grouped_mm_with_backward_fix(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Run a grouped matmul through torch._grouped_mm.

    The tensors are forwarded exactly as given (no copies or layout changes
    happen here); the backward pass is whatever torch._grouped_mm provides.

    Args:
        inputs: Left-hand operand.
        weight: Right-hand operand.
        offsets: Passed to torch._grouped_mm as `offs`.
    """
    grouped_mm = torch._grouped_mm
    result = grouped_mm(inputs, weight, offs=offsets)
    return result
# Cached result for the Unsloth Triton grouped GEMM kernels (None = not probed).
_GROUPED_GEMM_AVAILABLE = None
# Whether this torch build exposes the torch._grouped_mm symbol at all.
_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")
# Cached result of the runtime probe below (None = not probed yet).
_TORCH_GROUPED_MM_SUPPORTED = None


def _check_torch_grouped_mm_supported():
    """
    Report whether torch._grouped_mm actually works on the current GPU.

    The symbol must exist and CUDA must be available; after that a tiny dummy
    call is attempted, and any exception it raises counts as "unsupported".
    The answer is cached in `_TORCH_GROUPED_MM_SUPPORTED`.
    """
    global _TORCH_GROUPED_MM_SUPPORTED
    if _TORCH_GROUPED_MM_SUPPORTED is not None:
        return _TORCH_GROUPED_MM_SUPPORTED

    supported = False
    if _TORCH_GROUPED_MM_AVAILABLE and torch.cuda.is_available():
        try:
            # The symbol can exist while the hardware is unsupported, so the
            # only reliable check is to run it once.
            device = torch.cuda.current_device()
            probe_dtype = torch.float16
            # Minimal dummy data: 1 expert, 1 token, dim 8 (safe alignment)
            tokens = torch.ones((1, 8), device=device, dtype=probe_dtype)
            expert_weights = torch.ones((1, 8, 8), device=device, dtype=probe_dtype)
            group_ends = torch.tensor([1], device=device, dtype=torch.int32)
            torch._grouped_mm(tokens, expert_weights, offs=group_ends)
            del tokens, expert_weights, group_ends
            supported = True
        except Exception:
            supported = False

    _TORCH_GROUPED_MM_SUPPORTED = supported
    return supported
# Set once the Triton allocator hook has been installed successfully.
_TRITON_ALLOCATOR_INITIALIZED = False
# Byte buffer handed to Triton; replaced by a larger one when it is too small.
_PERSISTENT_BUFFER = None


def _init_triton_allocator():
    """
    Install a Triton allocator that reuses one growing byte buffer.

    The hook returns the same uint8 CUDA tensor for every request and only
    allocates a new, larger one when the current buffer cannot hold the
    requested size. Any failure (for example Triton not being importable) is
    swallowed and leaves the initialised flag unset.
    """
    global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
    if _TRITON_ALLOCATOR_INITIALIZED:
        return
    try:
        import triton

        def persistent_alloc_fn(size: int, alignment: int, stream):
            global _PERSISTENT_BUFFER
            # Round the request up to a multiple of 128 bytes.
            needed_bytes = 128 * ((size + 127) // 128)
            held = _PERSISTENT_BUFFER
            too_small = (
                held is None or held.numel() * held.element_size() < needed_bytes
            )
            if too_small:
                # 10% headroom so slightly larger requests reuse the buffer.
                _PERSISTENT_BUFFER = torch.empty(
                    int(needed_bytes * 1.1), device="cuda", dtype=torch.uint8
                )
                _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
            return _PERSISTENT_BUFFER

        triton.set_allocator(persistent_alloc_fn)
        triton._unsloth_allocator_set = True
        _TRITON_ALLOCATOR_INITIALIZED = True
    except Exception:
        pass
def _check_grouped_gemm_available():
    """
    Check if Unsloth grouped GEMM kernels are available.

    Returns False immediately (without touching the cache) when
    UNSLOTH_DISABLE_MOE_TRITON is "1". Otherwise the outcome of importing the
    kernels is cached in `_GROUPED_GEMM_AVAILABLE`; on success the persistent
    Triton allocator is initialised as well.
    """
    global _GROUPED_GEMM_AVAILABLE
    if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1":
        return False
    if _GROUPED_GEMM_AVAILABLE is not None:
        return _GROUPED_GEMM_AVAILABLE
    try:
        # Imported only to verify that the kernels can be loaded.
        from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma  # noqa: F401
    except Exception:
        # An availability probe must not crash backend selection: importing
        # the kernels can fail for reasons other than a missing package, and
        # every such failure simply means "not available".
        _GROUPED_GEMM_AVAILABLE = False
    else:
        _GROUPED_GEMM_AVAILABLE = True
        _init_triton_allocator()
    return _GROUPED_GEMM_AVAILABLE
| from functools import lru_cache | |
def select_moe_backend():
    """
    Pick the MoE backend name: "grouped_mm", "unsloth_triton" or "native_torch".

    A backend named in UNSLOTH_MOE_BACKEND is honoured when it is available
    ("native_torch" always is). Otherwise the first available backend in the
    order grouped_mm -> unsloth_triton -> native_torch is returned.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    requested = os.environ.get("UNSLOTH_MOE_BACKEND")
    if requested:
        if requested == "native_torch":
            return "native_torch"
        if requested == "grouped_mm":
            if _check_torch_grouped_mm_supported():
                return "grouped_mm"
        elif requested == "unsloth_triton":
            if _check_grouped_gemm_available():
                return "unsloth_triton"
        # Unknown name or unavailable backend: report it and auto-select.
        _log_info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")

    if _check_torch_grouped_mm_supported():
        chosen = "grouped_mm"
        _log_info("Unsloth: Using MoE backend 'grouped_mm'")
    elif _check_grouped_gemm_available():
        chosen = "unsloth_triton"
        _log_info("Unsloth: Using MoE backend 'unsloth_triton'")
    else:
        chosen = "native_torch"
    return chosen
def forward_moe_backend(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Run the MoE forward pass with the backend chosen by `select_moe_backend`.

    All arguments are handed to the selected implementation unchanged, so
    model-specific patches only need to call this single entry point.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    backend = select_moe_backend()
    if backend == "grouped_mm":
        implementation = forward_native_grouped_mm
    elif backend == "unsloth_triton":
        implementation = forward_triton_grouped_gemm
    else:
        implementation = forward_native_moe_loop
    return implementation(self, hidden_states, top_k_index, top_k_weights)
def _get_routing_indices(selected_experts, num_experts):
    """
    Compute token→expert mapping for grouped GEMM.

    Args:
        selected_experts: Integer tensor of expert ids; it is flattened, so
            any shape (contiguous or not) is accepted.
        num_experts: Minimum length of the returned count vector.

    Returns:
        token_counts_by_expert: int32 tensor with one count per expert.
        gather_indices: positions into the flattened `selected_experts` that
            list the entries in expert order (ties keep their original order).
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # reshape (not view) so transposed or sliced index tensors work as well;
    # for contiguous input it is still a zero-copy view.
    flat_experts = selected_experts.reshape(-1)
    # bincount works on the integer ids directly (no float conversion as with histc)
    token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)
    # stable=True preserves the original order within each expert
    gather_indices = flat_experts.argsort(stable=True)
    return token_counts_by_expert, gather_indices
def _silu_and_mul(x):
    """Split `x` in two along the last dim and return silu(first) * second."""
    gate_half, up_half = torch.chunk(x, 2, dim=-1)
    activated = F.silu(gate_half)
    return activated * up_half
| # ============================================================================ | |
| # Separated LoRA Helper Functions | |
| # ============================================================================ | |
def _has_lora_adapters(param) -> bool:
    """
    Tell whether `param` carries LoRA adapters that are currently in effect.

    True only when both `lora_A` and `lora_B` exist, adapters are neither
    disabled nor merged, and `lora_A` holds at least one entry.
    """
    is_lora_wrapper = hasattr(param, "lora_A") and hasattr(param, "lora_B")
    if not is_lora_wrapper:
        return False
    # Missing flags are treated like False.
    if getattr(param, "disable_adapters", False):
        return False
    if getattr(param, "merged", False):
        return False
    adapter_count = len(param.lora_A)
    return adapter_count > 0
def _extract_lora_from_wrapper(
    wrapper, adapter_name: str = "default", experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
    """
    Pull LoRA weights out of a PEFT ParamWrapper for separated MoE computation.

    The wrapper stores lora_A.weight as (E*R, in_dim) and lora_B.weight as
    (out_dim, E*R), where E is the number of experts and R the rank per
    expert. The grouped path computes (X @ first_weight) @ second_weight, so
    for E > 1 the weights are reshaped to
        first_weight:  (E, in_dim, R)   from lora_A
        second_weight: (E, R, out_dim)  from lora_B
    For E == 1 (non-MoE) the plain transposes A.T and B.T are returned.

    If the experts module exposes `_unsloth_lora_extractor_fn`, that callable
    is used instead and its result is returned as is (for model-specific
    layouts such as transposed base weights).

    Args:
        wrapper: Object holding `lora_A`, `lora_B` and `scaling` keyed by
            adapter name; `num_experts` is read when present (default 1).
        adapter_name: Adapter to use; when it is not present the first
            adapter in `lora_A` is used.
        experts_module: Module checked for a custom extractor. Defaults to
            `wrapper.get_base_layer()` when the wrapper has that method.

    Returns:
        (first_weight, second_weight, scaling, num_experts), or None when no
        active adapter exists or anything goes wrong during extraction.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    try:
        is_lora_wrapper = hasattr(wrapper, "lora_A") and hasattr(wrapper, "lora_B")
        if not is_lora_wrapper:
            return None
        if getattr(wrapper, "disable_adapters", False):
            return None
        if getattr(wrapper, "merged", False):
            return None
        adapters_a = wrapper.lora_A
        if not adapters_a:
            return None

        # Fall back to the first registered adapter when the name is unknown.
        if adapter_name not in adapters_a:
            adapter_name = next(iter(adapters_a.keys()))

        weight_A = adapters_a[adapter_name].weight  # (E*R, dim1)
        weight_B = wrapper.lora_B[adapter_name].weight  # (dim2, E*R)
        scaling = wrapper.scaling[adapter_name]
        num_experts = getattr(wrapper, "num_experts", 1)

        # Locate the experts module so a registered extractor can take over.
        if experts_module is None:
            if hasattr(wrapper, "get_base_layer"):
                experts_module = wrapper.get_base_layer()
        custom_extractor = getattr(experts_module, "_unsloth_lora_extractor_fn", None)
        if custom_extractor is not None:
            return custom_extractor(wrapper, weight_A, weight_B, scaling, num_experts)

        if not num_experts > 1:
            # Non-MoE case: weights for X @ A.T @ B.T
            return weight_A.T, weight_B.T, scaling, num_experts

        # Standard format: split the stacked rank dimension per expert.
        in_dim = weight_A.shape[1]
        out_dim = weight_B.shape[0]
        rank_per_expert = weight_A.shape[0] // num_experts
        per_expert_a = weight_A.view(num_experts, rank_per_expert, in_dim)
        first_weight = per_expert_a.transpose(1, 2).contiguous()  # (E, in_dim, R)
        per_expert_b = weight_B.view(out_dim, num_experts, rank_per_expert)
        second_weight = per_expert_b.permute(1, 2, 0).contiguous()  # (E, R, out_dim)
        return first_weight, second_weight, scaling, num_experts
    except Exception:
        return None
def _extract_lora_weights(
    param, adapter_name: str = "default", num_experts: int = None, experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
    """
    Compatibility wrapper around `_extract_lora_from_wrapper`.

    When `num_experts` is given and `param` has no such attribute yet, it is
    stored on `param` first so the underlying extractor can read it. New code
    should call `_extract_lora_from_wrapper` directly.

    Returns:
        (first_weight, second_weight, scaling) for (X @ first) @ second, or
        None when the underlying extractor returns None.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    should_tag = num_experts is not None and not hasattr(param, "num_experts")
    if should_tag:
        param.num_experts = num_experts

    extracted = _extract_lora_from_wrapper(
        param, adapter_name, experts_module=experts_module
    )
    if extracted is None:
        return None
    # Drop the trailing num_experts entry.
    first_weight, second_weight, scaling = extracted[0], extracted[1], extracted[2]
    return first_weight, second_weight, scaling
def _get_base_weight(param):
    """
    Resolve the underlying weight of a possibly wrapped parameter or module.

    Follows `base_layer` links until none is left, then returns, in order of
    preference: the result of `get_param()`, the `weight` attribute, or the
    unwrapped object itself.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    unwrapped = param
    while hasattr(unwrapped, "base_layer"):
        unwrapped = unwrapped.base_layer

    if hasattr(unwrapped, "get_param"):
        return unwrapped.get_param()
    # Modules (Linear, etc.) expose their tensor as `.weight`.
    return unwrapped.weight if hasattr(unwrapped, "weight") else unwrapped
def _get_lora_wrapper_for_param(experts_module, param_name):
    """
    Look up the LoRA wrapper for `param_name` (e.g. gate_up_proj, down_proj).

    An explicit `<param_name>_lora_wrapper` attribute wins. Otherwise the
    attribute named `param_name` is returned when it has a `lora_A`
    attribute. Returns None when neither applies; nothing is created lazily.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    explicit_key = f"{param_name}_lora_wrapper"
    if hasattr(experts_module, explicit_key):
        return getattr(experts_module, explicit_key)

    # Directly wrapped parameter: recognised by its `lora_A` attribute.
    candidate = getattr(experts_module, param_name, None)
    return candidate if hasattr(candidate, "lora_A") else None
def native_moe_grouped_mm(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Grouped matmul entry point for the native backend.

    Delegates directly to `_grouped_mm_with_backward_fix` with the same
    arguments and returns its result.
    """
    return _grouped_mm_with_backward_fix(
        inputs=inputs, weight=weight, offsets=offsets
    )
def _apply_lora_grouped_mm(
    inputs: torch.Tensor,
    lora_B: torch.Tensor,
    lora_A: torch.Tensor,
    offsets: torch.Tensor,
    scaling: float,
    grouped_mm_func=native_moe_grouped_mm,
) -> torch.Tensor:
    """
    Compute the LoRA delta with two grouped GEMMs: ((X @ B) @ A) * scaling.

    Note the naming used by this function: `lora_B` is applied first and
    `lora_A` second.

    Args:
        inputs: (total_tokens, in_dim)
        lora_B: (num_experts, in_dim, rank) - first projection
        lora_A: (num_experts, rank, out_dim) - second projection
        offsets: Grouped GEMM offsets, passed to both matmuls
        scaling: LoRA scaling factor
        grouped_mm_func: Callable(inputs, weight, offsets) performing the
            grouped GEMM (default: native_moe_grouped_mm)
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Both weights are already in the layout the native path expects, so they
    # are only made contiguous, never transposed.
    first_projection = lora_B.contiguous()  # (E, in_dim, R)
    low_rank = grouped_mm_func(inputs, first_projection, offsets)

    second_projection = lora_A.contiguous()  # (E, R, out_dim)
    delta = grouped_mm_func(low_rank, second_projection, offsets)
    return delta * scaling
def _should_use_separated_lora() -> bool:
    """
    Tell whether the separated LoRA path is enabled (the default).

    Only UNSLOTH_MOE_LORA_MERGED set to exactly "1" switches to the merged
    approach.
    """
    merged_requested = os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") == "1"
    return not merged_requested
| # ============================================================================ | |
| # Model-specific Weight Preprocessing Hooks | |
| # ============================================================================ | |
| # Each model can register its own preprocessing function for weight transposition. | |
| # This allows the generic backend to work with different model weight layouts. | |
# Registry mapping a model type to its weight preprocessing callable.
_WEIGHT_PREPROCESSORS = dict()


def register_weight_preprocessor(model_type: str, preprocessor_fn):
    """
    Associate `preprocessor_fn` with `model_type`, replacing any earlier entry.

    Args:
        model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe")
        preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight,
            where proj_type is "gate_up" or "down"
    """
    _WEIGHT_PREPROCESSORS.update({model_type: preprocessor_fn})
def get_weight_preprocessor(model_type: str):
    """Return the preprocessor registered for `model_type`, or None if absent."""
    registry = _WEIGHT_PREPROCESSORS
    if model_type in registry:
        return registry[model_type]
    return None
def preprocess_weight(
    weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
):
    """
    Bring an expert weight into (E, in_dim, out_dim) layout for grouped_mm.

    A preprocessor registered for `model_type` takes precedence. Otherwise the
    position of `hidden_dim` decides: gate_up weights must have it on axis 1,
    every other projection type (i.e. "down") on axis 2; when it is not there
    the last two axes are swapped.

    Args:
        weight: Weight tensor (E, dim1, dim2) or similar
        proj_type: "gate_up" or "down"
        hidden_dim: Hidden dimension used for the layout check
        model_type: Optional model type selecting a registered preprocessor
    Returns:
        `weight` itself when already laid out correctly, otherwise a
        transposed view (or whatever the registered preprocessor returns).
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    if model_type and model_type in _WEIGHT_PREPROCESSORS:
        model_specific = _WEIGHT_PREPROCESSORS[model_type]
        return model_specific(weight, proj_type, hidden_dim)

    # Axis on which hidden_dim must sit for the weight to be usable as is:
    # gate_up -> (E, hidden_dim, 2*intermediate), down -> (E, intermediate, hidden_dim)
    hidden_axis = 1 if proj_type == "gate_up" else 2
    if weight.shape[hidden_axis] == hidden_dim:
        return weight
    return weight.transpose(-2, -1)
| # ============================================================================ | |
| # Generic MoE Detection and ParamWrapper Patching | |
| # ============================================================================ | |
def _is_moe_experts_module(module) -> bool:
    """
    Generic (not model-specific) test for an MoE experts layer.

    Recognised patterns:
    - a `gate_up_proj` attribute holding a 2D or 3D tensor/parameter
    - `w1` and `w2` attributes where `w1` is a 2D or 3D tensor/parameter

    3D is the stacked (num_experts, in, out) layout; 2D is accepted because
    4-bit parameters are packed into 2D tensors.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    import torch.nn as nn

    # After PEFT's parametrize wrapping the attribute is a plain torch.Tensor
    # rather than an nn.Parameter, so both are accepted.
    tensor_types = (nn.Parameter, torch.Tensor)

    def _is_stacked_weight(candidate) -> bool:
        return isinstance(candidate, tensor_types) and candidate.ndim in (2, 3)

    if hasattr(module, "gate_up_proj") and _is_stacked_weight(module.gate_up_proj):
        return True

    # Separate gate/up projections (w1/w2 pattern).
    has_w_pair = hasattr(module, "w1") and hasattr(module, "w2")
    if has_w_pair and _is_stacked_weight(module.w1):
        return True
    return False
# Aliases for compatibility with gpt_oss.py
_get_moe_lora_weights = _extract_lora_from_wrapper

# Original ParamWrapper.forward, used as the fallback for non-MoE modules.
# Stays None until patch_param_wrapper_for_moe stores it.
_original_param_wrapper_forward = None


def _patched_param_wrapper_forward(
    self, x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Replacement for ParamWrapper.forward that enables separated LoRA on MoE.

    For gate_up_proj/down_proj wrappers around an MoE experts module (and
    while separated LoRA is enabled) the LoRA weights are extracted and stored
    on the experts module as `_unsloth_lora_<parameter_name>` for the duration
    of the wrapped forward call, then removed again. With adapters disabled or
    merged, the base layer is called without storing anything.

    Every other wrapper goes through the original PEFT forward.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # The forward call must go through the IMMEDIATE parent (self.base_layer)
    # so the wrapper chain stays intact, e.g.
    # down_proj -> gate_up_proj -> Qwen3MoeExperts.
    immediate_base_layer = self.base_layer
    # The LoRA data, however, lives on the innermost experts module, which
    # get_base_layer() finds by walking the whole chain.
    experts_module = self.get_base_layer()
    use_separated = _should_use_separated_lora()
    param_name = getattr(self, "parameter_name", None)

    handles_moe = (
        use_separated
        and param_name in ("gate_up_proj", "down_proj")
        and _is_moe_experts_module(experts_module)
    )
    if not handles_moe:
        # Non-MoE: use original PEFT forward with _activate_lora
        return _original_param_wrapper_forward(self, x, *args, **kwargs)

    # Adapter state: disabled or merged adapters need no separated LoRA.
    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return immediate_base_layer(x, *args, **kwargs)
    if self.merged:
        return immediate_base_layer(x, *args, **kwargs)

    # The LoRA reshaping needs num_experts on the wrapper; infer it from the
    # experts module or from the leading dimension of the wrapped parameter.
    if not hasattr(self, "num_experts"):
        if hasattr(experts_module, "num_experts"):
            self.num_experts = experts_module.num_experts
        elif hasattr(experts_module, param_name):
            stacked = getattr(experts_module, param_name)
            if hasattr(stacked, "shape") and len(stacked.shape) >= 1:
                self.num_experts = stacked.shape[0]

    # e.g. _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
    lora_attr = f"_unsloth_lora_{param_name}"
    lora_data = _extract_lora_from_wrapper(self)
    if lora_data is not None:
        # Stored on the EXPERTS MODULE (not base_layer).
        setattr(experts_module, lora_attr, lora_data)

    try:
        return immediate_base_layer(x, *args, **kwargs)
    finally:
        # Always remove the temporary attribute, even if the forward raised.
        if hasattr(experts_module, lora_attr):
            delattr(experts_module, lora_attr)
def patch_param_wrapper_for_moe():
    """
    Install `_patched_param_wrapper_forward` as PEFT's ParamWrapper.forward.

    When the cached copy of this module is loadable and provides the same
    function, the work is delegated to it; if that call raises, patching is
    done locally instead. The original forward is remembered only the first
    time. Call this after PEFT has been imported.

    Returns:
        True when the patch was applied, False when ParamWrapper cannot be
        imported from PEFT (or whatever the cached module's function returns).
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    global _original_param_wrapper_forward

    cached_module = _load_cached_moe_utils_module()
    delegate = None
    if cached_module is not None:
        delegate = getattr(cached_module, "patch_param_wrapper_for_moe", None)
    if delegate is not None:
        try:
            return delegate()
        except Exception:
            # Fall through to local patching.
            pass

    try:
        from peft.tuners.lora.layer import ParamWrapper
    except ImportError:
        return False

    # Keep the very first original so repeated calls do not overwrite it.
    if _original_param_wrapper_forward is None:
        _original_param_wrapper_forward = ParamWrapper.forward
    ParamWrapper.forward = _patched_param_wrapper_forward
    return True
| def forward_native_grouped_mm( | |
| self, | |
| hidden_states: torch.Tensor, | |
| top_k_index: torch.Tensor, | |
| top_k_weights: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """ | |
| Native Pytorch grouped GEMM MoE forward pass. | |
| Uses torch._grouped_mm which is significantly faster than loop and works without Triton dependencies. | |
| Requires torch._grouped_mm support (verified via runtime check). | |
| """ | |
| # This Unsloth Zoo code section is licensed under AGPL3 | |
| # Runtime safety check - defense in depth | |
| if not _check_torch_grouped_mm_supported(): | |
| major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) | |
| raise RuntimeError( | |
| f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). " | |
| f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend." | |
| ) | |
| is_2d_input = hidden_states.dim() == 2 | |
| if is_2d_input: | |
| sequence_length, hidden_dim = hidden_states.shape | |
| batch_size = 1 | |
| else: | |
| batch_size, sequence_length, hidden_dim = hidden_states.shape | |
| hidden_states = hidden_states.view(-1, hidden_dim) | |
| # 1. Calculate routing | |
| flat_top_k = top_k_index.view(-1) | |
| num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() | |
| # 2. Sort indices to group tokens by expert | |
| sorted_indices = torch.argsort(flat_top_k, stable=True) | |
| token_indices = sorted_indices // top_k_index.shape[-1] | |
| # 3. Permute Input | |
| # We need to gather inputs. Since we may have expanded top_k, we use token_indices to map back to original input | |
| permuted_input = hidden_states[token_indices] | |
| # 4. Prepare Grouped MM arguments | |
| offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) | |
| # ======================================================================== | |
| # Gate + Up projection with optional separated LoRA (DEFAULT) | |
| # ======================================================================== | |
| use_separated_lora = _should_use_separated_lora() | |
| gate_up_lora = None | |
| # Check for injected LoRA data from patched ParamWrapper (preferred path) | |
| if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: | |
| gate_up_lora = self._unsloth_lora_gate_up_proj[ | |
| :3 | |
| ] # (first_weight, second_weight, scaling) | |
| # Fallback: check parameter directly (for older wrapping patterns) | |
| elif ( | |
| use_separated_lora | |
| and hasattr(self, "gate_up_proj") | |
| and _has_lora_adapters(self.gate_up_proj) | |
| ): | |
| gate_up_lora = _extract_lora_weights( | |
| self.gate_up_proj, num_experts=self.num_experts, experts_module=self | |
| ) | |
| if hasattr(self, "gate_up_proj"): | |
| # Get base weights (raw, without LoRA) | |
| gate_up_base = _get_base_weight(self.gate_up_proj) | |
| # Get model type for preprocessing (if registered) | |
| model_type = getattr(self, "_unsloth_model_type", None) | |
| # Handle different weight shapes using preprocessor | |
| # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view. | |
| w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type) | |
| # Base forward: X @ W | |
| mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) | |
| # Add separated LoRA contribution: + ((X @ first) @ second) * scaling | |
| # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) | |
| if gate_up_lora is not None: | |
| first_weight, second_weight, scaling = gate_up_lora | |
| # Cast to input dtype (LoRA weights are float32, input may be bfloat16) | |
| # Ensure contiguous for grouped_mm alignment requirements | |
| first_weight = first_weight.to(permuted_input.dtype).contiguous() | |
| second_weight = second_weight.to(permuted_input.dtype).contiguous() | |
| # Step 1: permuted_input @ first_weight | |
| try: | |
| lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets) | |
| lora_out = lora_out.contiguous() | |
| except RuntimeError as e: | |
| raise e | |
| # Step 2: result @ second_weight | |
| # Handle unaligned O dimension or other grouped_mm failures | |
| try: | |
| if second_weight.shape[-1] % 8 != 0: | |
| pad_size = 8 - (second_weight.shape[-1] % 8) | |
| second_weight_padded = F.pad( | |
| second_weight, (0, pad_size) | |
| ).contiguous() | |
| lora_delta = _grouped_mm_with_backward_fix( | |
| lora_out, second_weight_padded, offsets | |
| ) | |
| lora_delta = lora_delta[:, :-pad_size] | |
| else: | |
| lora_delta = _grouped_mm_with_backward_fix( | |
| lora_out, second_weight, offsets | |
| ) | |
| except RuntimeError: | |
| # Fallback to manual loop if grouped_mm fails (e.g. stride alignment) | |
| lora_delta = torch.empty( | |
| (lora_out.shape[0], second_weight.shape[-1]), | |
| dtype=lora_out.dtype, | |
| device=lora_out.device, | |
| ) | |
| cpu_offsets = offsets.cpu().tolist() | |
| prev_offset = 0 | |
| for i, end in enumerate(cpu_offsets): | |
| if prev_offset < end: | |
| lora_delta[prev_offset:end] = torch.matmul( | |
| lora_out[prev_offset:end], second_weight[i] | |
| ) | |
| prev_offset = end | |
| # Add scaled LoRA contribution | |
| mm1_out = mm1_out + lora_delta * scaling | |
| if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: | |
| num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device) | |
| bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0) | |
| mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) | |
| if "GptOssExperts" in self.__class__.__name__: | |
| gate = mm1_out[..., ::2] | |
| up = mm1_out[..., 1::2] | |
| else: | |
| gate, up = mm1_out.chunk(2, dim=-1) | |
| elif hasattr(self, "w1") and hasattr(self, "w3"): | |
| # Separate w1/w3 weights (older models) | |
| w1_base = _get_base_weight(self.w1) | |
| w3_base = _get_base_weight(self.w3) | |
| w1 = w1_base.transpose(-2, -1) | |
| w3 = w3_base.transpose(-2, -1) | |
| gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) | |
| up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets) | |
| # Add LoRA for w1 and w3 separately if present | |
| if use_separated_lora: | |
| if _has_lora_adapters(self.w1): | |
| w1_lora = _extract_lora_weights(self.w1, experts_module=self) | |
| if w1_lora is not None: | |
| lora_A, lora_B, scaling = w1_lora | |
| lora_A_t = lora_A.transpose(-2, -1) | |
| lora_A_out = _grouped_mm_with_backward_fix( | |
| permuted_input, lora_A_t, offsets | |
| ) | |
| lora_B_t = lora_B.transpose(-2, -1) | |
| lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) | |
| gate = gate + lora_B_out * scaling | |
| if _has_lora_adapters(self.w3): | |
| w3_lora = _extract_lora_weights(self.w3, experts_module=self) | |
| if w3_lora is not None: | |
| lora_A, lora_B, scaling = w3_lora | |
| lora_A_t = lora_A.transpose(-2, -1) | |
| lora_A_out = _grouped_mm_with_backward_fix( | |
| permuted_input, lora_A_t, offsets | |
| ) | |
| lora_B_t = lora_B.transpose(-2, -1) | |
| lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) | |
| up = up + lora_B_out * scaling | |
| else: | |
| raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.") | |
| # Activation | |
| if "GptOssExperts" in self.__class__.__name__: | |
| # Custom activation from GptOss | |
| limit = getattr(self, "limit", 7.0) | |
| alpha = getattr(self, "alpha", 1.702) | |
| gate = gate.clamp(min=None, max=limit) | |
| up = up.clamp(min=-limit, max=limit) | |
| glu = gate * torch.sigmoid(gate * alpha) | |
| inter = (up + 1.0) * glu | |
| else: | |
| inter = F.silu(gate) * up | |
| # ======================================================================== | |
| # Down projection with optional separated LoRA (DEFAULT) | |
| # ======================================================================== | |
| down_lora = None | |
| # Check for injected LoRA data from patched ParamWrapper (preferred path) | |
| if getattr(self, "_unsloth_lora_down_proj", None) is not None: | |
| down_lora = self._unsloth_lora_down_proj[ | |
| :3 | |
| ] # (first_weight, second_weight, scaling) | |
| # Fallback: check parameter directly (for older wrapping patterns) | |
| elif ( | |
| use_separated_lora | |
| and hasattr(self, "down_proj") | |
| and _has_lora_adapters(self.down_proj) | |
| ): | |
| down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self) | |
| if hasattr(self, "down_proj"): | |
| # Get base weights | |
| down_base = _get_base_weight(self.down_proj) | |
| # Get model type for preprocessing (if registered) | |
| model_type = getattr(self, "_unsloth_model_type", None) | |
| # Handle different weight shapes using preprocessor | |
| w2 = preprocess_weight(down_base, "down", hidden_dim, model_type) | |
| # Base forward | |
| mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets) | |
| # Add separated LoRA contribution if present | |
| # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) | |
| if down_lora is not None: | |
| first_weight, second_weight, scaling = down_lora | |
| # Cast to input dtype (LoRA weights are float32, input may be bfloat16) | |
| first_weight = first_weight.to(inter.dtype).contiguous() | |
| second_weight = second_weight.to(inter.dtype).contiguous() | |
| # Step 1: inter @ first_weight | |
| lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets) | |
| lora_out = lora_out.contiguous() | |
| # Step 2: result @ second_weight | |
| try: | |
| lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) | |
| except RuntimeError: | |
| # Fallback to manual loop | |
| lora_delta = torch.empty( | |
| (lora_out.shape[0], second_weight.shape[-1]), | |
| dtype=lora_out.dtype, | |
| device=lora_out.device, | |
| ) | |
| cpu_offsets = offsets.cpu().tolist() | |
| prev_offset = 0 | |
| for i, end in enumerate(cpu_offsets): | |
| if prev_offset < end: | |
| lora_delta[prev_offset:end] = torch.matmul( | |
| lora_out[prev_offset:end], second_weight[i] | |
| ) | |
| prev_offset = end | |
| # Add scaled LoRA contribution | |
| mm2_out = mm2_out + lora_delta * scaling | |
| if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: | |
| bias_expanded = self.down_proj_bias.repeat_interleave( | |
| num_tokens_per_expert.to(self.down_proj_bias.device), dim=0 | |
| ).to(mm2_out.device) | |
| mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) | |
| elif hasattr(self, "w2"): | |
| w2_base = _get_base_weight(self.w2) | |
| w2 = w2_base.transpose(-2, -1) | |
| # Base forward | |
| mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets) | |
| # Add LoRA if present | |
| if use_separated_lora and _has_lora_adapters(self.w2): | |
| w2_lora = _extract_lora_weights(self.w2, experts_module=self) | |
| if w2_lora is not None: | |
| lora_A, lora_B, scaling = w2_lora | |
| lora_A_t = lora_A.transpose(-2, -1).contiguous() | |
| lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets) | |
| lora_B_t = lora_B.transpose(-2, -1).contiguous() | |
| lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) | |
| mm2_out = mm2_out + lora_B_out * scaling | |
| else: | |
| raise AttributeError("MoE layer must have 'down_proj' or 'w2'.") | |
| # 5. Apply Routing Weights and Scatter Add (Reduce) | |
| flat_weights = top_k_weights.view(-1) | |
| permuted_weights = flat_weights[sorted_indices] | |
| mm2_out = mm2_out * permuted_weights.unsqueeze(-1) | |
| final_hidden_states = torch.zeros( | |
| (batch_size * sequence_length, hidden_dim), | |
| dtype=hidden_states.dtype, | |
| device=hidden_states.device, | |
| ) | |
| final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype)) | |
| if is_2d_input: | |
| return final_hidden_states | |
| return final_hidden_states.view(batch_size, sequence_length, hidden_dim) | |
def forward_triton_grouped_gemm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Grouped GEMM MoE forward pass using Triton kernels.
    Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).

    Args:
        hidden_states: 2D (num_tokens, hidden_dim) or 3D
            (batch_size, seq_len, hidden_dim) input. 3D inputs are flattened
            to 2D before the kernels run.
        top_k_index: Expert ids chosen for each token, (num_tokens, top_k).
            A 3D tensor is flattened over its leading dims when
            hidden_states is 3D.
        top_k_weights: Routing weights laid out like top_k_index. They are
            cast to the dtype of hidden_states before being applied.

    Returns:
        Tensor of shape (num_tokens, hidden_dim), or
        (batch_size, seq_len, hidden_dim) when the input was 3D.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
    from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels

    # Autotuned kernel configs are cached on the module itself.
    if not hasattr(self, "_unsloth_moe_configs"):
        self._unsloth_moe_configs = None

    use_separated_lora = _should_use_separated_lora()

    # Handle 3D inputs (batch_size, seq_len, hidden_dim).
    # reshape (not view) so non-contiguous inputs do not raise.
    is_3d = hidden_states.dim() == 3
    if is_3d:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        num_tokens = batch_size * seq_len
        # Also flatten top_k inputs if they are 3D
        if top_k_index.dim() == 3:
            top_k_index = top_k_index.reshape(-1, top_k_index.shape[-1])
        if top_k_weights.dim() == 3:
            top_k_weights = top_k_weights.reshape(-1, top_k_weights.shape[-1])
    else:
        num_tokens, hidden_dim = hidden_states.shape
    top_k = top_k_index.shape[1]

    # Normalize the gate_up weight so its last dim is hidden_dim; a weight
    # stored the other way round is transposed (and made contiguous).
    if self.gate_up_proj.shape[-1] == hidden_dim:
        w1 = self.gate_up_proj
    else:
        w1 = self.gate_up_proj.transpose(-2, -1).contiguous()

    # Cache model dimensions and kernel configs on first call
    if self._unsloth_moe_configs is None:
        # Derive the intermediate size from the normalized weight so that the
        # transposed storage layout is autotuned against the right dimensions
        # (reading self.gate_up_proj.shape[1] would yield hidden_dim there).
        intermediate_dim = w1.shape[1] // 2
        # Autotune first GEMM
        gemm1_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim * 2,
            top_k=top_k,
            dtype=hidden_states.dtype,
        )
        # Autotune second GEMM
        gemm2_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=intermediate_dim,
            intermediate_dim=hidden_dim,  # Output dim for 2nd GEMM is hidden_dim
            top_k=top_k,
            dtype=hidden_states.dtype,
        )
        self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)
        # Clear autotuning memory overhead
        torch.cuda.empty_cache()

    # Unpack cached configs
    intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs
    fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
    fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs

    # Compute routing indices for grouped GEMM
    token_counts_by_expert, gather_indices = _get_routing_indices(
        top_k_index, self.num_experts
    )
    offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)

    # NOTE(review): only the down-projection LoRA is added on this path; a
    # gate_up adapter (e.g. self._unsloth_lora_gate_up_proj) is never applied
    # here. Confirm whether that is intended.

    # Grouped GEMM 1: gate_up projection
    first_gemm_output = grouped_gemm(
        X=hidden_states,
        W=w1,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=True,
        permute_y=False,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_1,
        kernel_config_bwd_dX=bwd_dX_config_1,
        kernel_config_bwd_dW=bwd_dW_config_1,
        is_first_gemm=True,
    )

    # Apply SiLU activation and multiply gate with up
    intermediate = _silu_and_mul(first_gemm_output)

    # Grouped GEMM 2: down projection
    # Prepare LoRA data: prefer the injected tuple, else extract from the parameter.
    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]  # (first_weight, second_weight, scaling)
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        # Pass experts_module like the other down_proj extraction in this file does.
        down_lora = _extract_lora_weights(
            self.down_proj, num_experts=self.num_experts, experts_module=self
        )

    if self.down_proj.shape[-1] == intermediate.shape[-1]:
        w2 = self.down_proj
    else:
        w2 = self.down_proj.transpose(-2, -1).contiguous()

    second_gemm_output = grouped_gemm(
        X=intermediate,
        W=w2,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=False,
        permute_y=True,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_2,
        kernel_config_bwd_dX=bwd_dX_config_2,
        kernel_config_bwd_dW=bwd_dW_config_2,
        is_first_gemm=False,
    )

    # Add separated LoRA contribution for Down
    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora
        # Intermediate is already permuted from step 1.
        # Offsets are same.
        first_weight = first_weight.to(intermediate.dtype)
        second_weight = second_weight.to(intermediate.dtype)
        lora_delta = _apply_lora_grouped_mm(
            intermediate,
            first_weight,
            second_weight,
            offsets,
            scaling,
            grouped_mm_func=native_moe_grouped_mm,
        )
        second_gemm_output = second_gemm_output + lora_delta

    # Apply routing weights and sum across top_k experts
    # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim)
    # Ensure top_k_weights matches dtype (can be float32 from softmax)
    top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
    final_hidden_states = (
        second_gemm_output.view(num_tokens, top_k, hidden_dim)
        * top_k_weights_casted[..., None]
    )
    final_hidden_states = final_hidden_states.sum(dim=1)

    if is_3d:
        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
    return final_hidden_states
def forward_native_moe_loop(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Loop-based MoE forward pass. Loops over experts that have tokens routed to them.

    The loop is data dependent (expert ids are pulled to Python and
    torch.where is called per expert), so it is meant to run outside
    torch.compile.
    NOTE(review): no compile-disabling decorator is visible on this
    definition - confirm it is applied where the function is registered.

    Args:
        hidden_states: 2D (num_tokens, hidden_dim) or 3D
            (batch_size, seq_len, hidden_dim) input. 3D inputs are flattened
            for the computation and the output is reshaped back.
        top_k_index: Integer expert ids per token, (num_tokens, top_k).
            A 3D tensor is flattened over its leading dims when
            hidden_states is 3D.
        top_k_weights: Routing weights laid out like top_k_index.

    Returns:
        Tensor with the same shape and dtype as hidden_states.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    # Accept 3D inputs like the grouped-GEMM forward does; the 2D path is unchanged.
    is_3d = hidden_states.dim() == 3
    if is_3d:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        if top_k_index.dim() == 3:
            top_k_index = top_k_index.reshape(-1, top_k_index.shape[-1])
        if top_k_weights.dim() == 3:
            top_k_weights = top_k_weights.reshape(-1, top_k_weights.shape[-1])

    final_hidden_states = torch.zeros_like(hidden_states)

    # Create expert mask and find which experts have tokens
    with torch.no_grad():
        expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, n_tokens)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

    has_fused_gate_up = hasattr(self, "gate_up_proj")
    has_down_proj = hasattr(self, "down_proj")

    # Only loop over experts that actually have tokens routed to them.
    # A single tolist() replaces one .item() host sync per expert.
    for expert_idx in expert_hit.flatten().tolist():
        # Find which tokens are routed to this expert
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        # Gather only the tokens for this expert
        current_state = hidden_states[token_idx]

        # Compute gate_up projection for this expert only
        # Handle 'gate_up_proj' or 'w1'/'w3'
        if has_fused_gate_up:
            gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
                2, dim=-1
            )
        else:
            gate = F.linear(current_state, self.w1[expert_idx])
            up = F.linear(current_state, self.w3[expert_idx])
        current_hidden_states = self.act_fn(gate) * up

        # Compute down projection for this expert only
        if has_down_proj:
            current_hidden_states = F.linear(
                current_hidden_states, self.down_proj[expert_idx]
            )
        else:
            current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])

        # Apply routing weights
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
        # Scatter back to final output
        final_hidden_states.index_add_(
            0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
        )

    if is_3d:
        return final_hidden_states.view(batch_size, seq_len, hidden_dim)
    return final_hidden_states