import torch from transformers import BertTokenizer, BertModel class ZeroOutputBertModel(BertModel): def forward(self, *args, **kwargs): # 获取输入的形状,决定输出向量的形状 input_shape = kwargs['input_ids'].shape hidden_size = self.config.hidden_size # BERT 隐藏层大小 (如 768) # 构造全为 0 的向量 zero_last_hidden_state = torch.zeros(input_shape[0], input_shape[1], hidden_size).to(self.device) zero_pooler_output = torch.zeros(input_shape[0], hidden_size).to(self.device) # 返回和原始 BERT 一样的结构,但值全为 0 return { "last_hidden_state": zero_last_hidden_state, "pooler_output": zero_pooler_output, }