File size: 4,671 Bytes
d403233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""3D transformer model for URSA."""

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from diffnext.models.embeddings import FlexRotaryEmbedding
from diffnext.models.flash_attention import cross_entropy_loss
from diffnext.models.flex_attention import FlexAttentionCausal2D
from diffnext.models.text_encoders.qwen3 import Qwen3Config, Qwen3Model


class URSATransformer3DModel(ModelMixin, ConfigMixin):
    """3D transformer model for URSA.

    Wraps a ``Qwen3Model`` backbone and a bias-free linear LM head. Every
    backbone self-attention layer is marked non-causal and given a shared
    ``FlexAttentionCausal2D`` module, and a ``FlexRotaryEmbedding`` built from
    the backbone config is attached to both this wrapper and the backbone.
    """

    @register_to_config
    def __init__(
        self,
        hidden_size=2048,
        intermediate_size=6144,
        num_attention_heads=16,
        num_key_value_heads=8,
        num_hidden_layers=28,
        max_window_layers=28,
        rope_theta=1000000,
        vocab_size=215669,
        lm_vocab_size=151669,
        lm_head_size=64000,
        bov_token_id=151652,
        attn_implementation=None,
        **kwargs,
    ):
        """Build the backbone, the shared attention/rotary modules and the LM head.

        All arguments are recorded by ``register_to_config``; the resulting
        config dict is turned into the ``Qwen3Config`` used for the backbone.
        Only ``hidden_size`` and ``lm_head_size`` are read directly here, as
        the input and output widths of ``lm_head``.
        """
        super().__init__()
        qcfg = Qwen3Config.from_dict(self._internal_dict)
        # The checkpoint config carries ``rope_theta`` but the HF Qwen3Config
        # built from it may drop it; restore it from the raw dict, falling
        # back to 10000.0 when the raw dict lacks it as well.
        if not hasattr(qcfg, "rope_theta"):
            qcfg.rope_theta = float(self._internal_dict.get("rope_theta", 10000.0))
        self.model = Qwen3Model(qcfg)

        # One flex-attention module and one rotary embedding, shared by this
        # wrapper and the backbone.
        self.model.flex_attn = self.flex_attn = FlexAttentionCausal2D()
        self.model.flex_rope = self.flex_rope = FlexRotaryEmbedding.from_config(self.model.config)
        for layer in self.model.layers:
            # Presumably masking is handled by the shared flex-attention
            # module instead of the layer's causal flag — confirm in Qwen3Model.
            layer.self_attn.is_causal = False
            layer.self_attn.flex_attn = self.flex_attn
        self.lm_head = nn.Linear(hidden_size, lm_head_size, bias=False)
        # Default hooks: identity pre-process and a post-process returning {}.
        # They are plain attributes so a pipeline can swap them out.
        self.pipeline_preprocess = lambda inputs: inputs  # Preprocess hook.
        self.pipeline_postprocess = lambda *args, **kwargs: {}  # Postprocess hook.

    @staticmethod
    def _token_losses(logits, labels):
        """Return the unreduced per-token cross-entropy.

        ``logits`` is flattened over its first two dims before the loss. The
        fused ``cross_entropy_loss`` kernel is used when that import is
        truthy, otherwise ``nn.functional.cross_entropy``. No ``ignore_index``
        is passed, so each path's default applies (-100 for torch; presumably
        the same for the fused kernel — confirm).
        """
        flat_logits = logits.flatten(0, 1)
        if cross_entropy_loss:
            return cross_entropy_loss(flat_logits, labels, inplace_backward=True)[0]
        return nn.functional.cross_entropy(flat_logits, labels, reduction="none")

    def forward(
        self,
        input_ids,
        inputs_embeds=None,
        labels=None,
        logits_to_keep=None,
        lm_head_shift=0,
        **kwargs,
    ) -> Transformer2DModelOutput:
        """Run the backbone and LM head, optionally computing a loss.

        Args:
            input_ids: Token ids for the backbone. In training mode this may
                instead be a dict holding at least ``"input_ids"`` and
                ``"labels"``. That dict is first passed to
                ``pipeline_preprocess``, whose return value is ignored, so the
                hook is presumably expected to mutate it in place. Then
                ``"input_ids"`` is popped from it, and the remaining entries
                (including ``"labels"``) replace ``kwargs``.
            inputs_embeds: Forwarded to the backbone unchanged.
            labels: A tensor (indexed as ``labels[:, 1:]``, so at least 2-D)
                selects the next-token-prediction loss. A dict selects a
                custom loss delegated to ``pipeline_postprocess``.
            logits_to_keep: An int ``n`` keeps only the last ``n`` sequence
                positions (``0`` keeps all of them). Any other non-None value
                is used directly as the index on axis 1.
            lm_head_shift: When non-zero, the first ``lm_head_shift`` rows of
                the LM head weight are dropped before projecting.
            **kwargs: Forwarded to the backbone.

        Returns:
            ``Transformer2DModelOutput(sample=logits)`` when no tensor/dict
            labels are given. Otherwise, whatever ``pipeline_postprocess``
            returns (``{}`` with the default hook).
        """
        inputs = None
        if self.training and isinstance(input_ids, dict):  # Prepare training args.
            inputs = input_ids
            self.pipeline_preprocess(inputs)
            input_ids, labels, kwargs = inputs.pop("input_ids"), inputs["labels"], inputs

        outputs = self.model(input_ids, inputs_embeds=inputs_embeds, **kwargs)
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        head_w = self.lm_head.weight[lm_head_shift:] if lm_head_shift else self.lm_head.weight
        hidden = outputs[0] if keep is None else outputs[0][:, keep]
        logits = nn.functional.linear(hidden, head_w)

        if isinstance(labels, torch.Tensor):  # NTP loss.
            # Position t is scored against token t + 1; the final position has
            # no target, so it is padded with -100 and excluded by the mask.
            lbls = nn.functional.pad(labels[:, 1:], (0, 1), value=-100)
            loss = self._token_losses(logits.float(), lbls.flatten()).view(lbls.shape)
            mask = lbls.ne(-100)
            correct = logits.data.argmax(-1).eq(lbls).logical_and(mask)
            # Clamp so a batch with no supervised tokens yields 0 rather than NaN.
            num_valid = mask.sum().clamp(min=1)
            loss, acc1 = loss.sum().div(num_valid), correct.sum().float().div(num_valid)
            return self.pipeline_postprocess(loss, acc1)
        if isinstance(labels, dict):  # Custom losses.
            if inputs is None:
                # Labels arrived as an argument rather than inside the training
                # dict; build the equivalent dict (inputs minus "input_ids",
                # plus "labels") so the hook always receives one.
                inputs = dict(kwargs, labels=labels)
            return self.pipeline_postprocess(inputs, logits)

        return Transformer2DModelOutput(sample=logits)