| """3D transformer model for URSA.""" |
|
|
| import torch |
| from torch import nn |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput |
|
|
| from diffnext.models.embeddings import FlexRotaryEmbedding |
| from diffnext.models.flash_attention import cross_entropy_loss |
| from diffnext.models.flex_attention import FlexAttentionCausal2D |
| from diffnext.models.text_encoders.qwen3 import Qwen3Config, Qwen3Model |
|
|
|
|
class URSATransformer3DModel(ModelMixin, ConfigMixin):
    """3D transformer model for URSA."""

    @register_to_config
    def __init__(
        self,
        hidden_size=2048,
        intermediate_size=6144,
        num_attention_heads=16,
        num_key_value_heads=8,
        num_hidden_layers=28,
        max_window_layers=28,
        rope_theta=1000000,
        vocab_size=215669,
        lm_vocab_size=151669,
        lm_head_size=64000,
        bov_token_id=151652,
        attn_implementation=None,
        **kwargs,
    ):
        super().__init__()
        # Build the Qwen3 backbone from the registered config.
        qcfg = Qwen3Config.from_dict(self._internal_dict)
        # Defensive fallback for configs that omit ``rope_theta``.
        if not hasattr(qcfg, "rope_theta"):
            qcfg.rope_theta = float(self._internal_dict.get("rope_theta", 10000.0))
        self.model = Qwen3Model(qcfg)
        # Share a single 2D flex attention and rotary embedding across all layers.
        self.model.flex_attn = self.flex_attn = FlexAttentionCausal2D()
        self.model.flex_rope = self.flex_rope = FlexRotaryEmbedding.from_config(self.model.config)
        # Delegate masking to the shared flex attention instead of per-layer causal masks.
        for layer in self.model.layers:
            layer.self_attn.is_causal = False
            layer.self_attn.flex_attn = self.flex_attn
        self.lm_head = nn.Linear(hidden_size, lm_head_size, bias=False)
        # Identity hooks that pipelines may override for custom batch handling.
        self.pipeline_preprocess = lambda inputs: inputs
        self.pipeline_postprocess = lambda *args, **kwargs: {}

    def forward(
        self,
        input_ids,
        inputs_embeds=None,
        labels=None,
        logits_to_keep=None,
        lm_head_shift=0,
        **kwargs,
    ) -> Transformer2DModelOutput:
        if self.training and isinstance(input_ids, dict):
            # Training batches arrive as a dict; ``pipeline_preprocess`` mutates it in place.
            inputs = input_ids
            self.pipeline_preprocess(inputs)
            input_ids, labels, kwargs = inputs.pop("input_ids"), inputs["labels"], inputs

        outputs = self.model(input_ids, inputs_embeds=inputs_embeds, **kwargs)
        # Optionally project only the last ``logits_to_keep`` positions, and shift
        # the LM head rows when decoding from a vocabulary subset.
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        head_w = self.lm_head.weight[lm_head_shift:] if lm_head_shift else self.lm_head.weight
        logits = nn.functional.linear(outputs[0] if keep is None else outputs[0][:, keep], head_w)

        def flash_loss(logits, labels):
            # Prefer the fused flash-attention kernel; fall back to PyTorch if unavailable.
            if cross_entropy_loss:
                return cross_entropy_loss(logits.flatten(0, 1), labels, inplace_backward=True)[0]
            return nn.functional.cross_entropy(logits.flatten(0, 1), labels, reduction="none")

        if labels is not None and isinstance(labels, torch.Tensor):
            # Shift labels left by one step and pad with the ignore index (-100).
            lbls = nn.functional.pad(labels[:, 1:], (0, 1), value=-100)
            loss = flash_loss(logits.float(), lbls.flatten()).view(lbls.shape)
            acc1, mask = logits.data.argmax(-1).eq(lbls), lbls.ne(-100)
            # Average loss and top-1 accuracy over non-ignored positions only.
            loss, acc1 = loss.sum().div(mask.sum()), acc1[mask].float().mean()
            return self.pipeline_postprocess(loss, acc1)
        elif labels is not None and isinstance(labels, dict):
            # Dict labels only occur on the training path above, where ``inputs`` is bound.
            return self.pipeline_postprocess(inputs, logits)

        return Transformer2DModelOutput(sample=logits)
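

# --- Usage sketch (illustrative; not part of the original module) -------------
# A minimal smoke test, assuming the Qwen3 backbone and flex attention run with
# their defaults and no extra positional metadata; the tiny config and random
# token ids below are placeholders, not values from the original code.
if __name__ == "__main__":
    model = URSATransformer3DModel(num_hidden_layers=2, hidden_size=256, intermediate_size=768)
    token_ids = torch.randint(0, 1000, (1, 16))  # [batch_size, seq_len]
    with torch.no_grad():
        out = model(token_ids, logits_to_keep=1)  # keep logits for the last position only
    print(out.sample.shape)  # expected: torch.Size([1, 1, 64000])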