File size: 772 Bytes
2bb9c95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from transformers import BertTokenizer, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class ZeroOutputBertModel(BertModel):
    """Drop-in ``BertModel`` replacement whose forward pass returns all zeros.

    The weights are never used: ``forward`` only reads the batch size and
    sequence length of its input and returns zero tensors shaped like a real
    BERT ``last_hidden_state`` and ``pooler_output``.
    """

    def forward(self, *args, **kwargs):
        """Return all-zero outputs shaped like a real BERT forward pass.

        The batch size and sequence length are taken from ``input_ids``
        (keyword argument, or the first positional argument as in
        ``BertModel.forward``), falling back to ``inputs_embeds``. Every other
        argument is ignored.

        Returns:
            BaseModelOutputWithPoolingAndCrossAttentions with
            ``last_hidden_state`` of shape (batch, seq_len, hidden_size) and
            ``pooler_output`` of shape (batch, hidden_size), both zero-filled,
            on the model's device and in the model's dtype. It supports
            key access (``out["last_hidden_state"]``) as well as attribute
            (``out.last_hidden_state``) and index (``out[0]``) access.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        input_ids = kwargs.get("input_ids")
        if input_ids is None and args:
            # BertModel.forward takes input_ids as its first positional parameter.
            input_ids = args[0]

        if input_ids is not None:
            batch_size, seq_len = input_ids.shape[:2]
        elif kwargs.get("inputs_embeds") is not None:
            # inputs_embeds carries a trailing embedding dim; only the first two matter.
            batch_size, seq_len = kwargs["inputs_embeds"].shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_size = self.config.hidden_size  # BERT hidden size (e.g. 768)

        # Allocate directly on the model's device and in its dtype, so that
        # half-precision models hand downstream layers matching tensors.
        factory = {"device": self.device, "dtype": self.dtype}
        last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, **factory)
        pooler_output = torch.zeros(batch_size, hidden_size, **factory)

        # Use the same output container as BertModel, but with all-zero values.
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )