import logging

from transformers import GPT2LMHeadModel
logger = logging.getLogger(__name__)


class CustomGPT2Model(GPT2LMHeadModel):
    """GPT-2 LM-head model with a hook point for custom forward logic.

    Currently delegates entirely to ``GPT2LMHeadModel.forward`` and returns its
    outputs unchanged; the override exists as the place to add custom
    pre-/post-processing. Construction is inherited unchanged from
    ``GPT2LMHeadModel`` (the former ``__init__`` only called ``super()``).
    """

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the parent forward pass and return its outputs.

        Args:
            input_ids: Passed through to ``GPT2LMHeadModel.forward``.
            attention_mask: Passed through to ``GPT2LMHeadModel.forward``.
            **kwargs: Any other keyword argument accepted by
                ``GPT2LMHeadModel.forward`` (e.g. ``past_key_values``,
                ``labels``, ``use_cache``, ``return_dict``), forwarded as-is.

        Returns:
            Exactly what ``GPT2LMHeadModel.forward`` returns.

        NOTE: ``GPT2LMHeadModel.forward`` takes ``past_key_values`` as its
        second positional parameter, whereas here the second positional slot
        is ``attention_mask``. Pass everything after ``input_ids`` by keyword
        to avoid mis-binding arguments.
        """
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # Hook: modify `outputs` here as needed (currently returned unchanged).

        # Logged at DEBUG instead of printed: forward() is typically invoked
        # once per decoding step during generate(), so an unconditional print
        # floods stdout and cannot be turned off.
        logger.debug("USING CUSTOM WRAPPER")
        return outputs