| import torch | |
| from transformers import BertTokenizer, BertModel | |
| class ZeroOutputBertModel(BertModel): | |
| def forward(self, *args, **kwargs): | |
| # 获取输入的形状,决定输出向量的形状 | |
| input_shape = kwargs['input_ids'].shape | |
| hidden_size = self.config.hidden_size # BERT 隐藏层大小 (如 768) | |
| # 构造全为 0 的向量 | |
| zero_last_hidden_state = torch.zeros(input_shape[0], input_shape[1], hidden_size).to(self.device) | |
| zero_pooler_output = torch.zeros(input_shape[0], hidden_size).to(self.device) | |
| # 返回和原始 BERT 一样的结构,但值全为 0 | |
| return { | |
| "last_hidden_state": zero_last_hidden_state, | |
| "pooler_output": zero_pooler_output, | |
| } |