# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel, load_state_dict
from transformers.processing_utils import Unpack
from transformers.utils import logging, TransformersKwargs

from .moondream3_moe_fused.moe_fused_linear import MoeFusedLinear
from .moondream3_moe_fused.kernels.indexing import get_expert_counts_and_idx
from .configuration_moondream3 import (
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
)
from . import modeling_moondream3
logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Moondream3Config"


class Moondream3FusedSparseMoeBlock(nn.Module):
def __init__(self, config: Moondream3TextConfig) -> None:
super().__init__()
self.num_experts = config.num_experts
self.num_selected = config.num_experts_per_tok
self.hidden_size = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.gate_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
self.up_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
self.down_proj = MoeFusedLinear(self.moe_intermediate_size, self.hidden_size, config.num_experts)
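        # Note: each MoeFusedLinear holds one stacked weight of shape
        # (num_experts, out_features, in_features); `_load_pretrained_model` below
        # assembles it by stacking the per-expert nn.Linear weights on dim 0.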
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
M = batch_size * sequence_length
hidden_states = hidden_states.view(M, hidden_dim)
# router_logits: (M, num_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
# routing_weights, selected_experts: (M, num_selected)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_selected, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
hidden_states = hidden_states.unsqueeze(1).expand(M, self.num_selected, hidden_dim)
# hidden_states must be contiguous
hidden_states = hidden_states.reshape(M * self.num_selected, hidden_dim)
selected_experts = selected_experts.view(M * self.num_selected)
# Sort selected_experts and hidden_states for better memory coalescence of weight
# It's possible to fuse a sort and a MoeFusedLinear layer, but for now we separate them for clarity
m_sizes, sort_idx, inv_sort_idx = get_expert_counts_and_idx(selected_experts, self.num_experts)
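        # m_sizes[e] counts the rows routed to expert e; sort_idx groups rows so that
        # each expert's rows are contiguous, and inv_sort_idx undoes the permutation.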
hidden_states = hidden_states[sort_idx]
# It's possible to fuse gate_h and up_h, but this affects the shape of LoRA
gate_h = self.gate_proj(hidden_states, m_sizes)
up_h = self.up_proj(hidden_states, m_sizes)
        hidden_states = F.gelu(up_h) * (gate_h + 1)  # Moondream3's gated activation: gelu(up) scaled by (gate + 1)
del gate_h, up_h
hidden_states = self.down_proj(hidden_states, m_sizes)
hidden_states = hidden_states[inv_sort_idx]
hidden_states = hidden_states.view(M, self.num_selected, hidden_dim)
hidden_states = torch.einsum("beo,be->bo", hidden_states, routing_weights)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return hidden_states, router_logits
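

# A minimal dense reference for the fused block above, handy for sanity-checking the
# fused kernels against plain PyTorch. This is a sketch, not part of the model: the
# helper name is illustrative, it ignores the router logits, and it assumes each
# MoeFusedLinear stores a bias-free stacked (num_experts, out_features, in_features)
# weight — the layout `_load_pretrained_model` assembles below.
def _dense_moe_reference(block: Moondream3FusedSparseMoeBlock, hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    x = hidden_states.view(-1, hidden_dim)
    # Same routing as the fused path: softmax over experts, top-k, renormalize.
    routing_weights = F.softmax(block.gate(x), dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, block.num_selected, dim=-1)
    routing_weights = (routing_weights / routing_weights.sum(dim=-1, keepdim=True)).to(x.dtype)
    out = torch.zeros_like(x)
    for e in range(block.num_experts):
        # Rows (token, slot) that routed to expert e; apply its dense projections.
        token_idx, slot_idx = torch.where(selected_experts == e)
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        gate_h = xe @ block.gate_proj.weight[e].T
        up_h = xe @ block.up_proj.weight[e].T
        ye = (F.gelu(up_h) * (gate_h + 1)) @ block.down_proj.weight[e].T
        out.index_add_(0, token_idx, ye * routing_weights[token_idx, slot_idx].unsqueeze(-1))
    return out.view(batch_size, sequence_length, hidden_dim)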


# Swap in the fused block before any model is constructed: the classes in
# modeling_moondream3 resolve Moondream3SparseMoeBlock through their module's
# globals at construction time, so this patch makes every MoE layer fused.
modeling_moondream3.Moondream3SparseMoeBlock = Moondream3FusedSparseMoeBlock
from .modeling_moondream3 import (
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
    Moondream3PreTrainedModel,
    Moondream3Model,
    Moondream3TextModel,
    Moondream3VisionModel,
    Moondream3ForConditionalGeneration,
)


# Shadows the imported Moondream3ForConditionalGeneration so the exported class
# loads and saves checkpoints in the fused MoE layout handled below.
class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: Moondream3Config):
super().__init__(config)
self.model = Moondream3Model(config)
self.vocab_size = config.text_config.vocab_size
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
self.post_init()
def get_input_embeddings(self):
return self.model.text_model.embed_tokens
def set_input_embeddings(self, value):
self.model.text_model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.text_model = decoder
def get_decoder(self):
return self.model.text_model
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
tiling: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: int = 0,
**kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
# Get hidden states from the base model (it already builds the multimodal prefix)
model_outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
tiling=tiling,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
hidden_states = model_outputs.last_hidden_state # [B, T, D]
# Compute logits; only keep the tail if requested
if isinstance(logits_to_keep, int) and logits_to_keep > 0:
hs = hidden_states[:, -logits_to_keep:, :]
elif isinstance(logits_to_keep, slice):
hs = hidden_states[:, logits_to_keep, :]
else:
hs = hidden_states
logits = self.lm_head(hs) # [B, T', V]
loss = None
if labels is not None:
            # `labels` are assumed aligned with `hs`; transformers' default
            # loss_function applies the one-token causal shift internally.
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=getattr(model_outputs, "past_key_values", None),
hidden_states=getattr(model_outputs, "hidden_states", None),
attentions=getattr(model_outputs, "attentions", None),
)
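    # Usage sketch for `logits_to_keep` above (illustrative shapes, not part of the API):
    # during generation, passing logits_to_keep=1 runs lm_head on the final position
    # only, i.e. hidden_states[:, -1:, :] -> logits of shape [B, 1, V], instead of
    # projecting all T positions.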
    @classmethod
    def _load_pretrained_model(
        cls,
        model: "PreTrainedModel",
        state_dict: Optional[dict],
        checkpoint_files: Optional[list[str]],
        pretrained_model_name_or_path,
        weights_only: bool = True,
        **kwargs,
    ):
        if checkpoint_files is not None:
            # Merge all checkpoint shards into one state dict so the per-expert
            # weights can be fused before the standard loading machinery runs.
            state_dict = {}
            for file in checkpoint_files:
                state_dict.update(load_state_dict(file, map_location="cpu", weights_only=weights_only))
            # Discover the MoE layers and the number of experts in each.
            moe_layer_experts = defaultdict(set)
            for key in state_dict.keys():
                if key.startswith("model.text_model.layers."):
                    parts = key.split(".")
                    # Expected: model.text_model.layers.{layer}.mlp.experts.{expert_id}.down_proj.weight
                    if len(parts) > 6 and parts[5] == "experts" and parts[3].isdigit() and parts[6].isdigit():
                        moe_layer_experts[int(parts[3])].add(int(parts[6]))
            # Stack each layer's per-expert weights into one fused
            # (num_experts, out_features, in_features) tensor per projection.
            for layer_idx, experts in moe_layer_experts.items():
                prefix = f"model.text_model.layers.{layer_idx}.mlp"
                for proj in ("down_proj", "up_proj", "gate_proj"):
                    state_dict[f"{prefix}.{proj}.weight"] = torch.stack(
                        [state_dict.pop(f"{prefix}.experts.{i}.{proj}.weight") for i in range(len(experts))]
                    )
            checkpoint_files = None
        model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs = super()._load_pretrained_model(
            model,
            state_dict,
            checkpoint_files,
            pretrained_model_name_or_path,
            **kwargs,
        )
        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
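    # Key rewrite performed above, for one hypothetical MoE layer with N experts:
    #   model.text_model.layers.{L}.mlp.experts.{0..N-1}.down_proj.weight  (each [hidden, moe_intermediate])
    #     -> model.text_model.layers.{L}.mlp.down_proj.weight              ([N, hidden, moe_intermediate])
    # and likewise for up_proj / gate_proj, so the checkpoint matches MoeFusedLinear's
    # stacked layout, while `_fix_state_dict_keys_on_save` restores the original keys on save.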
    def _fix_state_dict_keys_on_save(self, state_dict: dict):
        # Inverse of the fusion in `_load_pretrained_model`: split each fused
        # (num_experts, out_features, in_features) weight back into per-expert
        # tensors so saved checkpoints keep the original unfused layout.
        for layer_idx in range(self.config.text_config.moe_start_layer, self.config.text_config.num_hidden_layers):
            prefix = f"model.text_model.layers.{layer_idx}.mlp"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                fused = state_dict.pop(f"{prefix}.{proj}.weight").cpu()
                for i, expert_weight in enumerate(torch.unbind(fused)):
                    state_dict[f"{prefix}.experts.{i}.{proj}.weight"] = expert_weight.contiguous()
        return state_dict
@staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Legacy beam-search hook: reorder each layer's cached states along the batch dim.
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
__all__ = [
"Moondream3Config",
"Moondream3TextConfig",
"Moondream3VisionConfig",
"Moondream3RegionConfig",
"Moondream3PreTrainedModel",
"Moondream3Model",
"Moondream3TextModel",
"Moondream3VisionModel",
"Moondream3ForConditionalGeneration",
]