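"""Moondream3 modeling code with a fused Mixture-of-Experts implementation.

Importing this module replaces ``Moondream3SparseMoeBlock`` in
``modeling_moondream3`` with ``Moondream3FusedSparseMoeBlock``, which runs the
expert MLPs through grouped ``MoeFusedLinear`` kernels instead of a Python
loop over per-expert modules, and overrides checkpoint load/save so the stock
per-expert weights are repacked into (and unpacked from) the fused layout.
"""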
from collections import defaultdict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel, load_state_dict
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging

from .configuration_moondream3 import (
    Moondream3Config,
    Moondream3RegionConfig,
    Moondream3TextConfig,
    Moondream3VisionConfig,
)
from .moondream3_moe_fused.kernels.indexing import get_expert_counts_and_idx
from .moondream3_moe_fused.moe_fused_linear import MoeFusedLinear

from . import modeling_moondream3

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Moondream3Config"
|
|
class Moondream3FusedSparseMoeBlock(nn.Module):
    """Drop-in replacement for the stock ``Moondream3SparseMoeBlock`` that
    dispatches all experts through fused grouped-GEMM linears
    (``MoeFusedLinear``) instead of looping over per-expert modules."""

    def __init__(self, config: Moondream3TextConfig) -> None:
        super().__init__()
        self.num_experts = config.num_experts
        self.num_selected = config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.gate_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.up_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.down_proj = MoeFusedLinear(self.moe_intermediate_size, self.hidden_size, config.num_experts)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        M = batch_size * sequence_length
        hidden_states = hidden_states.view(M, hidden_dim)

        # Route each token: softmax over all experts, keep the top-k, and
        # renormalize so the selected weights sum to 1 per token.
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(routing_weights, self.num_selected, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Replicate every token once per selected expert and flatten, so the
        # fused kernels see a single (M * num_selected, hidden_dim) problem.
        hidden_states = hidden_states.unsqueeze(1).expand(M, self.num_selected, hidden_dim)
        hidden_states = hidden_states.reshape(M * self.num_selected, hidden_dim)
        selected_experts = selected_experts.view(M * self.num_selected)

        # Sort rows by expert id so each expert's tokens are contiguous;
        # m_sizes holds the per-expert row counts that drive the grouped GEMM.
        m_sizes, sort_idx, inv_sort_idx = get_expert_counts_and_idx(selected_experts, self.num_experts)
        hidden_states = hidden_states[sort_idx]

        # Gated expert MLP: gelu(up) * (gate + 1), then the down projection,
        # with all three projections fused across experts.
        gate_h = self.gate_proj(hidden_states, m_sizes)
        up_h = self.up_proj(hidden_states, m_sizes)
        hidden_states = F.gelu(up_h) * (gate_h + 1)
        del gate_h, up_h
        hidden_states = self.down_proj(hidden_states, m_sizes)

        # Undo the sort, then combine each token's num_selected expert outputs
        # with its routing weights.
        hidden_states = hidden_states[inv_sort_idx]
        hidden_states = hidden_states.view(M, self.num_selected, hidden_dim)
        hidden_states = torch.einsum("beo,be->bo", hidden_states, routing_weights)

        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
        return hidden_states, router_logits
|
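# Reference (non-fused) forward, kept for documentation and parity testing
# only; it is not used by the model. It assumes MoeFusedLinear exposes its
# parameters as `.weight` with shape (num_experts, out_features, in_features),
# which matches how `_load_pretrained_model` below stacks the checkpoint.
def _reference_moe_forward(block: Moondream3FusedSparseMoeBlock, x: torch.Tensor) -> torch.Tensor:
    # x: flattened tokens of shape (batch * seq, hidden_dim); returns the
    # combined expert output with the same shape.
    probs = F.softmax(block.gate(x), dim=-1, dtype=torch.float32)
    weights, experts = torch.topk(probs, block.num_selected, dim=-1)
    weights = (weights / weights.sum(dim=-1, keepdim=True)).to(x.dtype)
    out = torch.zeros_like(x)
    for e in range(block.num_experts):
        # All (token, slot) pairs routed to expert e.
        token_idx, slot_idx = (experts == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        gate_h = F.linear(xe, block.gate_proj.weight[e])
        up_h = F.linear(xe, block.up_proj.weight[e])
        h = F.linear(F.gelu(up_h) * (gate_h + 1), block.down_proj.weight[e])
        out.index_add_(0, token_idx, h * weights[token_idx, slot_idx].unsqueeze(-1))
    return out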
|
# Monkey-patch the MoE block *before* importing the model classes, so every
# decoder layer constructed by modeling_moondream3 picks up the fused version
# (the class name is resolved as a module-level global at construction time).
modeling_moondream3.Moondream3SparseMoeBlock = Moondream3FusedSparseMoeBlock

from .modeling_moondream3 import (  # noqa: E402
    Moondream3Model,
    Moondream3PreTrainedModel,
    Moondream3TextModel,
    Moondream3VisionModel,
)
|
|
class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
    """Conditional-generation wrapper wired to the fused MoE block; repacks
    per-expert checkpoint weights into the fused layout on load and unpacks
    them again on save."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Moondream3Config):
        super().__init__(config)
        self.model = Moondream3Model(config)
        self.vocab_size = config.text_config.vocab_size
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
        self.post_init()
|
    def get_input_embeddings(self):
        return self.model.text_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.text_model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.text_model = decoder

    def get_decoder(self):
        return self.model.text_model
|
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        tiling: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, slice] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            tiling=tiling,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
        )

        hidden_states = model_outputs.last_hidden_state

        # Project only the requested positions through the LM head: a positive
        # int keeps the trailing positions (typically just the last token
        # during generation), a slice keeps an arbitrary window, and 0 keeps
        # everything.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hs = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, slice):
            hs = hidden_states[:, logits_to_keep, :]
        else:
            hs = hidden_states

        logits = self.lm_head(hs)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=getattr(model_outputs, "past_key_values", None),
            hidden_states=getattr(model_outputs, "hidden_states", None),
            attentions=getattr(model_outputs, "attentions", None),
        )
|
    @classmethod
    def _load_pretrained_model(
        cls,
        model: "PreTrainedModel",
        state_dict: Optional[dict],
        checkpoint_files: Optional[list[str]],
        pretrained_model_name_or_path,
        weights_only: bool = True,
        **kwargs,
    ):
        # Materialize a single state dict so the per-expert weights can be
        # repacked before the generic loading logic runs.
        if checkpoint_files is not None:
            state_dict = {}
            for file in checkpoint_files:
                sd = load_state_dict(file, map_location="cpu", weights_only=weights_only)
                state_dict.update(sd)
            checkpoint_files = None

        # Discover which layers are MoE layers and how many experts each has,
        # from keys of the form
        # "model.text_model.layers.{L}.mlp.experts.{E}.{proj}.weight".
        moe_layer_experts = defaultdict(set)
        for key in state_dict.keys():
            if key.startswith("model.text_model.layers."):
                parts = key.split(".")
                if len(parts) > 6 and parts[5] == "experts" and parts[3].isdigit() and parts[6].isdigit():
                    moe_layer_experts[int(parts[3])].add(int(parts[6]))

        # Repack each expert's matrices into the stacked
        # (num_experts, out_features, in_features) layout used by
        # MoeFusedLinear, dropping the per-expert entries as we go.
        for layer_idx, experts in moe_layer_experts.items():
            num_experts = len(experts)
            prefix = f"model.text_model.layers.{layer_idx}.mlp"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                state_dict[f"{prefix}.{proj}.weight"] = torch.stack(
                    [state_dict.pop(f"{prefix}.experts.{i}.{proj}.weight") for i in range(num_experts)]
                )

        model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs = super()._load_pretrained_model(
            model,
            state_dict,
            checkpoint_files,
            pretrained_model_name_or_path,
            **kwargs,
        )
        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
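    # Example of the repacking above for a hypothetical MoE layer 4 with
    # experts 0..N-1: the per-expert keys
    #     model.text_model.layers.4.mlp.experts.0.up_proj.weight
    #     ...
    #     model.text_model.layers.4.mlp.experts.N-1.up_proj.weight
    # are stacked along dim 0 into the single fused key
    #     model.text_model.layers.4.mlp.up_proj.weight  # (N, out, in)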
|
    def _fix_state_dict_keys_on_save(self, state_dict: dict):
        # Inverse of the repacking in _load_pretrained_model: unbind the
        # stacked MoeFusedLinear weights back into per-expert keys so saved
        # checkpoints keep the original (unfused) layout.
        for layer_idx in range(self.config.text_config.moe_start_layer, self.config.text_config.num_hidden_layers):
            layer_key = f"model.text_model.layers.{layer_idx}"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                tensor = state_dict.pop(f"{layer_key}.mlp.{proj}.weight").cpu()
                for i, t in enumerate(torch.unbind(tensor)):
                    state_dict[f"{layer_key}.mlp.experts.{i}.{proj}.weight"] = t.contiguous()
        return state_dict
|
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Legacy (tuple-of-tuples) cache reordering for beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|
__all__ = [
    "Moondream3Config",
    "Moondream3TextConfig",
    "Moondream3VisionConfig",
    "Moondream3RegionConfig",
    "Moondream3PreTrainedModel",
    "Moondream3Model",
    "Moondream3TextModel",
    "Moondream3VisionModel",
    "Moondream3ForConditionalGeneration",
]
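
# Minimal usage sketch. Assumptions: this module ships as remote code next to
# the checkpoint, and the repo id below is a placeholder, not a confirmed one.
#
#     import torch
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "<moondream3-checkpoint>",  # hypothetical repo id or local path
#         trust_remote_code=True,
#         torch_dtype=torch.bfloat16,
#         device_map="cuda",
#     )
#     # All MoE layers now execute Moondream3FusedSparseMoeBlock.forward.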