import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM


class HfWrapper(nn.Module):
    """Wrap a Hugging Face causal LM in a uniform nn.Module interface."""

    def __init__(self, args) -> None:
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            args.hf_model,
            torch_dtype=torch.bfloat16,
        )
        self.params = self.model.config
        self.vocab_size = self.model.config.vocab_size
        self.seq_len = args.hf_seq_len

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        if enable:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_disable()

    def forward(self, input):
        # Return the logits plus a placeholder second output expected by callers.
        return self.model(input_ids=input)[0], None


def create_wrapped_hf_model(args):
    # `args` must expose `hf_model` (checkpoint name or path) and `hf_seq_len`,
    # matching the attributes consumed by HfWrapper.__init__.
    return HfWrapper(args)
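

# Minimal usage sketch (not part of the original file): the wrapper takes any
# args-style namespace exposing `hf_model` and `hf_seq_len`; the checkpoint name
# and sequence length below are illustrative assumptions only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(hf_model="gpt2", hf_seq_len=1024)
    model = create_wrapped_hf_model(args)

    # Random token ids in place of a real batch.
    input_ids = torch.randint(0, model.vocab_size, (1, args.hf_seq_len))
    logits, _ = model(input_ids)
    print(logits.shape)  # (1, hf_seq_len, vocab_size)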