from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM, AutoTokenizer, GPTNeoXConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn.functional import log_softmax
from torch.nn.modules.container import ModuleList
class CustomModel3(GPTNeoXForCausalLM):
    """GPT-NeoX causal LM whose output ``logits`` are log-probabilities.

    Behaves exactly like ``GPTNeoXForCausalLM`` except that the logits in
    the forward output are passed through ``log_softmax`` over the last
    (vocabulary) dimension.
    """

    config_class = GPTNeoXConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Run the stock GPT-NeoX forward, then log-softmax the logits.

        All positional and keyword arguments are forwarded unchanged to
        the parent ``forward``; the returned output object is the parent's,
        with ``logits`` replaced in place by their log-softmax.
        """
        result = super().forward(*args, **kwargs)
        result.logits = log_softmax(result.logits, dim=-1)
        return result

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        """Alternate constructor: clone a pretrained GPT-NeoX checkpoint.

        ``*args``/``**kwargs`` are forwarded to
        ``GPTNeoXForCausalLM.from_pretrained``. A fresh ``cls`` instance is
        built from the source model's config, moved to the same dtype and
        device, and loaded with the source's weights.
        """
        source = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
        clone = cls(source.config)
        clone = clone.to(dtype=source.dtype, device=source.device)
        clone.load_state_dict(source.state_dict())
        return clone
|
|
# NOTE(review): registers this class under the AutoModelForCausalLM auto-class,
# presumably so checkpoints saved via save_pretrained can be reloaded with
# AutoModelForCausalLM (trust_remote_code) — confirm against the loading code.
CustomModel3.register_for_auto_class('AutoModelForCausalLM')
|
|