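from transformers import BertForSequenceClassification


class PatchedBertForSequenceClassification(BertForSequenceClassification):
    def __reduce__(self):
        # Runs at pickling time (pickle.dumps), not at load time.
        print("hello!")
        # On load, pickle calls cls(config), which builds a freshly
        # initialized model, so trained weights are not restored.
        return self.__class__, (self.config,)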