rex1-base / modeling_rex.py
DavidSeyserHF's picture
Upload REX1 step 29000
a61b335 verified
"""Hugging Face model wrapper for REX."""
from __future__ import annotations
from typing import Any
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from configuration_rex import RexConfig
from model import RexConfig as CoreRexConfig
from model import RexForCausalLM as CoreRexForCausalLM
class RexForCausalLM(PreTrainedModel, GenerationMixin):
config_class = RexConfig
base_model_prefix = "rex"
supports_gradient_checkpointing = False
_tied_weights_keys = ["rex.lm_head.weight"]
def __init__(self, config: RexConfig):
super().__init__(config)
self.rex = CoreRexForCausalLM(CoreRexConfig.from_dict(config.to_core_dict()))
def get_input_embeddings(self) -> nn.Module:
return self.rex.token_embedding
def set_input_embeddings(self, value: nn.Module) -> None:
self.rex.token_embedding = value
if self.rex.cfg.tie_embeddings:
self.rex.lm_head.weight = self.rex.token_embedding.weight
def get_output_embeddings(self) -> nn.Module:
return self.rex.lm_head
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
self.rex.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs: Any) -> dict[str, torch.Tensor]:
return {"input_ids": input_ids[:, -self.config.max_seq_len :]}
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
past_key_values: Any | None = None,
use_cache: bool | None = None,
**_: Any,
) -> CausalLMOutputWithPast:
out = self.rex(input_ids=input_ids, labels=labels)
return CausalLMOutputWithPast(loss=out["loss"], logits=out["logits"])