# LogicModel / bert / BadModel.py
# Uploaded by Charliehua ("Upload 7 files", commit 2bb9c95 verified, 772 Bytes)
import torch
from transformers import BertTokenizer, BertModel
class ZeroOutputBertModel(BertModel):
    """BERT variant whose forward pass returns all-zero outputs.

    Mirrors the output structure of ``BertModel`` (``last_hidden_state`` of
    shape ``(batch, seq_len, hidden_size)`` and ``pooler_output`` of shape
    ``(batch, hidden_size)``) but fills both with zeros, so downstream
    consumers receive no signal from the input — useful as a null/ablation
    baseline.
    """

    def forward(self, input_ids=None, *args, **kwargs):
        """Return zero tensors shaped like a real BERT forward pass.

        Accepts ``input_ids`` positionally or as a keyword, matching the
        ``BertModel.forward`` convention.  (The original implementation read
        ``kwargs['input_ids']`` only, and raised ``KeyError`` when callers
        passed ``input_ids`` positionally.)

        Raises:
            KeyError: if ``input_ids`` is supplied neither positionally nor
                as a keyword argument.
        """
        if input_ids is None:
            input_ids = kwargs["input_ids"]

        batch_size, seq_len = input_ids.shape
        hidden_size = self.config.hidden_size  # BERT hidden size (e.g. 768)

        # Allocate directly on the model's device; zeros(...).to(device)
        # would allocate on CPU first and then copy.
        zero_last_hidden_state = torch.zeros(
            batch_size, seq_len, hidden_size, device=self.device
        )
        zero_pooler_output = torch.zeros(batch_size, hidden_size, device=self.device)

        # Same structure as the original BERT output, but every value is 0.
        return {
            "last_hidden_state": zero_last_hidden_state,
            "pooler_output": zero_pooler_output,
        }