import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM


class HfWrapper(nn.Module):
    """Thin wrapper exposing a Hugging Face causal LM through the
    interface the surrounding training code expects."""

    def __init__(self, args) -> None:
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            args.hf_model,
            torch_dtype=torch.bfloat16,
        )
        # Mirror config fields that other parts of the code read.
        self.params = self.model.config
        self.vocab_size = self.model.config.vocab_size
        self.seq_len = args.hf_seq_len

    def set_grad_checkpointing(self, enable=True):
        # Toggle activation checkpointing to trade compute for memory.
        if enable:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_disable()

    def forward(self, input_ids):
        # Return (logits, None); the None keeps the output compatible
        # with callers that expect a second auxiliary value.
        return self.model(input_ids=input_ids)[0], None


def create_wrapped_hf_model(args):
    # HfWrapper reads args.hf_model and args.hf_seq_len, so it must
    # receive the full args object rather than only a model name.
    return HfWrapper(args)
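
# A minimal usage sketch, not part of the original source: "gpt2" and the
# SimpleNamespace fields are placeholder assumptions standing in for the
# real training-args object.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(hf_model="gpt2", hf_seq_len=1024)
    model = create_wrapped_hf_model(args)
    model.set_grad_checkpointing(True)

    input_ids = torch.randint(0, model.vocab_size, (2, 8))
    logits, _ = model(input_ids)
    print(logits.shape)  # torch.Size([2, 8, vocab_size])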