# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""3D transformer model for URSA."""
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffnext.models.embeddings import FlexRotaryEmbedding
from diffnext.models.flash_attention import cross_entropy_loss
from diffnext.models.flex_attention import FlexAttentionCausal2D
from diffnext.models.text_encoders.qwen3 import Qwen3Config, Qwen3Model


class URSATransformer3DModel(ModelMixin, ConfigMixin):
"""3D transformer model for URSA."""
@register_to_config
def __init__(
self,
hidden_size=2048,
intermediate_size=6144,
num_attention_heads=16,
num_key_value_heads=8,
num_hidden_layers=28,
max_window_layers=28,
rope_theta=1000000,
vocab_size=215669,
lm_vocab_size=151669,
lm_head_size=64000,
bov_token_id=151652,
attn_implementation=None,
**kwargs,
):
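        """Build the URSA transformer on top of a Qwen3 decoder backbone.

        All keyword arguments are captured by ``@register_to_config``; most are
        forwarded to ``Qwen3Config`` for the text backbone. ``lm_head_size`` sets
        the output width of the prediction head, and ``bov_token_id`` is the
        begin-of-vision token id.
        """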
super().__init__()
# self.model = Qwen3Model(Qwen3Config.from_dict(self._internal_dict))
qcfg = Qwen3Config.from_dict(self._internal_dict)
# inject rope_theta from the raw dict (ckpt has it, HF Qwen3Config drops it)
if not hasattr(qcfg, "rope_theta"):
if "rope_theta" in self._internal_dict:
qcfg.rope_theta = float(self._internal_dict["rope_theta"])
else:
qcfg.rope_theta = 10000.0
self.model = Qwen3Model(qcfg)
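        # Attach shared flex attention / rotary-embedding modules to the Qwen3 backbone.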
self.model.flex_attn = self.flex_attn = FlexAttentionCausal2D()
self.model.flex_rope = self.flex_rope = FlexRotaryEmbedding.from_config(self.model.config)
[setattr(layer.self_attn, "is_causal", False) for layer in self.model.layers]
[setattr(layer.self_attn, "flex_attn", self.flex_attn) for layer in self.model.layers]
self.lm_head = nn.Linear(hidden_size, lm_head_size, bias=False)
self.pipeline_preprocess = lambda inputs: inputs # Preprocess hook.
self.pipeline_postprocess = lambda *args, **kwargs: {} # Postprocess hook.

    def forward(
self,
input_ids,
inputs_embeds=None,
labels=None,
logits_to_keep=None,
lm_head_shift=0,
**kwargs,
) -> Transformer2DModelOutput:
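        """Run the backbone and project hidden states through the LM head.

        During training, ``input_ids`` may be a dict of pipeline inputs consumed by
        ``pipeline_preprocess``. ``logits_to_keep`` limits the head to the last N
        positions (or an explicit index), and ``lm_head_shift`` offsets into the
        head weight. Returns ``Transformer2DModelOutput(sample=logits)``, or the
        result of ``pipeline_postprocess`` when ``labels`` are given.
        """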
        if self.training and isinstance(input_ids, dict):  # Prepare training args.
            inputs = input_ids
            self.pipeline_preprocess(inputs)  # Return value unused; the hook is expected to update ``inputs`` in place.
            input_ids, labels, kwargs = inputs.pop("input_ids"), inputs["labels"], inputs
outputs = self.model(input_ids, inputs_embeds=inputs_embeds, **kwargs)
        # Optionally restrict logits to the trailing ``logits_to_keep`` positions (or an explicit index).
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        # ``lm_head_shift`` skips the first rows of the head weight before the projection.
        head_w = self.lm_head.weight[lm_head_shift:] if lm_head_shift else self.lm_head.weight
        logits = nn.functional.linear(outputs[0] if keep is None else outputs[0][:, keep], head_w)
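        # Prefer the fused cross-entropy kernel from diffnext.models.flash_attention when available,
        # otherwise fall back to the standard PyTorch implementation.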
def flash_loss(logits, labels):
if cross_entropy_loss:
return cross_entropy_loss(logits.flatten(0, 1), labels, inplace_backward=True)[0]
return nn.functional.cross_entropy(logits.flatten(0, 1), labels, reduction="none")

        if labels is not None and isinstance(labels, torch.Tensor):  # Next-token-prediction (NTP) loss.
            # Shift labels so position i is supervised by token i + 1; pad the tail with the ignore index (-100).
            lbls = torch.nn.functional.pad(labels[:, 1:], (0, 1), value=-100)
            loss = flash_loss(logits.float(), lbls.flatten()).view(lbls.shape)
            acc1, mask = logits.data.argmax(-1).eq(lbls), lbls.ne(-100)
            # Average loss and top-1 accuracy over non-ignored tokens only (ignored positions contribute zero loss).
            loss, acc1 = loss.sum().div(mask.sum()), acc1[mask].float().mean()
            return self.pipeline_postprocess(loss, acc1)
elif labels is not None and isinstance(labels, dict): # Custom losses.
return self.pipeline_postprocess(inputs, logits)
return Transformer2DModelOutput(sample=logits)