23f2001106 committed
Commit 5775e05 · 1 Parent(s): cbf0666

Add HuggingFace compatible model files for bert_ffnn

Files changed (3)
  1. config.json +2 -2
  2. configuration_bert_ffnn.py +27 -0
  3. modeling_bert_ffnn.py +75 -0
config.json CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b2b9861148aff6e66b6d3e2e1c7a8f688dd68b5a92489471713bc9aac0a6431b
- size 263
+ oid sha256:f241b26cd5a3aeeeaf9f4412255776a7a578ad0f8f9174fa2fdafc61651c384f
+ size 285
configuration_bert_ffnn.py ADDED
@@ -0,0 +1,27 @@
+ from transformers import PretrainedConfig
+
+ class BertFFNNConfig(PretrainedConfig):
+     model_type = "bert_ffnn"
+
+     def __init__(
+         self,
+         bert_model_name="microsoft/deberta-v3-base",
+         hidden_dims=[192, 96],
+         output_dim=5,
+         dropout=0.2,
+         pooling="attention",
+         freeze_bert=False,
+         freeze_layers=0,
+         use_layer_norm=True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         self.bert_model_name = bert_model_name
+         self.hidden_dims = hidden_dims
+         self.output_dim = output_dim
+         self.dropout = dropout
+         self.pooling = pooling
+         self.freeze_bert = freeze_bert
+         self.freeze_layers = freeze_layers
+         self.use_layer_norm = use_layer_norm
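For context, a config defined this way can be instantiated directly or wired into the Auto classes. A minimal sketch, assuming configuration_bert_ffnn.py sits next to the script; the field overrides are illustrative, not values from this commit:

    from transformers import AutoConfig
    from configuration_bert_ffnn import BertFFNNConfig

    # Map the custom model_type to the config class so AutoConfig can resolve it.
    AutoConfig.register("bert_ffnn", BertFFNNConfig)

    # Instantiate with the commit's defaults, overriding a couple of fields.
    config = BertFFNNConfig(hidden_dims=[256, 128], pooling="mean")
    print(config.model_type)  # -> "bert_ffnn"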
modeling_bert_ffnn.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, AutoModel
+ from .configuration_bert_ffnn import BertFFNNConfig
+
+
+ class AttentionPooling(nn.Module):
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.attention = nn.Linear(hidden_size, 1)
+
+     def forward(self, hidden_states, attention_mask):
+         scores = self.attention(hidden_states).squeeze(-1)
+         scores = scores.masked_fill(attention_mask == 0, -1e9)
+         weights = torch.softmax(scores, dim=-1)
+         return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
+
+
+ class BERT_FFNN(PreTrainedModel):
+     config_class = BertFFNNConfig
+     base_model_prefix = "bert_ffnn"
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.bert = AutoModel.from_pretrained(config.bert_model_name)
+         self.pooling = config.pooling
+         self.use_layer_norm = config.use_layer_norm
+
+         if self.pooling == "attention":
+             self.attention_pool = AttentionPooling(self.bert.config.hidden_size)
+         if config.freeze_bert:
+             for p in self.bert.parameters():
+                 p.requires_grad = False
+         elif config.freeze_layers > 0:
+             for layer in self.bert.encoder.layer[:config.freeze_layers]:
+                 for p in layer.parameters():
+                     p.requires_grad = False
+
+         layers = []
+         in_dim = self.bert.config.hidden_size
+         for h_dim in config.hidden_dims:
+             layers.append(nn.Linear(in_dim, h_dim))
+             layers.append(nn.ReLU())
+             if config.use_layer_norm:
+                 layers.append(nn.LayerNorm(h_dim))
+             layers.append(nn.Dropout(config.dropout))
+             in_dim = h_dim
+
+         layers.append(nn.Linear(in_dim, config.output_dim))
+         self.classifier = nn.Sequential(*layers)
+
+         self.post_init()
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+
+         if self.pooling == "mean":
+             mask = attention_mask.unsqueeze(-1).float()
+             sum_emb = (outputs.last_hidden_state * mask).sum(1)
+             features = sum_emb / mask.sum(1).clamp(min=1e-9)
+         elif self.pooling == "max":
+             mask = attention_mask.unsqueeze(-1).float()
+             masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
+             features, _ = masked_emb.max(dim=1)
+         elif self.pooling == "attention":
+             features = self.attention_pool(outputs.last_hidden_state, attention_mask)
+         else:  # CLS pooling
+             features = (
+                 outputs.pooler_output
+                 if getattr(outputs, "pooler_output", None) is not None
+                 else outputs.last_hidden_state[:, 0]
+             )
+
+         logits = self.classifier(features)
+         return logits
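With both files in the repo, the model can be loaded through the Auto classes with remote code enabled. A minimal usage sketch, assuming config.json carries an auto_map entry pointing at these classes (not visible here, since config.json is an LFS pointer) and using a placeholder repo id:

    import torch
    from transformers import AutoModel, AutoTokenizer

    repo_id = "your-username/bert_ffnn"  # placeholder, not the actual repo id

    # trust_remote_code=True is required because the model classes live in the repo.
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

    inputs = tokenizer("An example sentence.", return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])
    print(logits.shape)  # (1, 5) with the default output_dim=5

Note that forward returns a raw logits tensor rather than a ModelOutput, so downstream code should index the result directly instead of accessing a .logits attribute.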