yezdata committed on
Commit
2625b05
·
verified ·
1 Parent(s): fda525e

Update modeling_emcoder.py

Browse files
Files changed (1) hide show
  1. modeling_emcoder.py +52 -8
modeling_emcoder.py CHANGED
@@ -1,12 +1,12 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import PreTrainedModel
4
 
5
  from .configuration_emcoder import EmCoderConfig
6
 
7
 
8
- class EmCoderCore(nn.Module):
9
- """The core encoder architecture of EmCoder, without the classification head."""
10
 
11
  def __init__(self, config: EmCoderConfig):
12
  super().__init__()
@@ -55,7 +55,7 @@ class EmCoder(PreTrainedModel):
55
  def __init__(self, config: EmCoderConfig):
56
  super().__init__(config)
57
 
58
- self.encoder = EmCoderCore(config)
59
  self.classifier = nn.Sequential(
60
  nn.Linear(config.d_model, config.d_model),
61
  nn.GELU(),
@@ -65,6 +65,21 @@ class EmCoder(PreTrainedModel):
65
 
66
  self.post_init()
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def _set_mc_dropout(self, active: bool = True):
70
  for m in self.modules():
@@ -84,10 +99,12 @@ class EmCoder(PreTrainedModel):
84
 
85
  def mc_forward(
86
  self,
87
- x: torch.Tensor,
88
- mask: torch.Tensor,
89
- n_samples: int,
90
  max_batch_size: int | None = None,
 
 
91
  ) -> torch.Tensor:
92
  """
93
  Performs Monte Carlo Dropout inference to quantify epistemic uncertainty.
@@ -101,9 +118,16 @@ class EmCoder(PreTrainedModel):
101
  Returns:
102
  Logits of shape (n_samples, B, num_labels).
103
  """
 
 
 
 
 
 
104
  if max_batch_size is None:
105
  max_batch_size = n_samples
106
 
 
107
  B, S = x.shape
108
  num_labels = self.classifier[-1].out_features
109
 
@@ -134,9 +158,29 @@ class EmCoder(PreTrainedModel):
134
 
135
 
136
 
137
- def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
138
  """Standard forward pass without MC Dropout."""
 
 
 
 
 
 
 
139
  features = self.encoder(x, mask)
140
 
141
  pooled = self._masked_mean_pooling(features, mask)
142
  return self.classifier(pooled)
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoConfig, AutoModel
4
 
5
  from .configuration_emcoder import EmCoderConfig
6
 
7
 
8
+ class EmCoderEncoder(nn.Module):
9
+ """The core encoder architecture of EmCoder Transformer."""
10
 
11
  def __init__(self, config: EmCoderConfig):
12
  super().__init__()
 
55
  def __init__(self, config: EmCoderConfig):
56
  super().__init__(config)
57
 
58
+ self.encoder = EmCoderEncoder(config)
59
  self.classifier = nn.Sequential(
60
  nn.Linear(config.d_model, config.d_model),
61
  nn.GELU(),
 
65
 
66
  self.post_init()
67
 
68
+
69
+ def _init_weights(self, module: nn.Module) -> None:
70
+ if isinstance(module, nn.Linear):
71
+ nn.init.trunc_normal_(module.weight, std=0.02)
72
+ if module.bias is not None:
73
+ nn.init.zeros_(module.bias)
74
+ elif isinstance(module, nn.Embedding):
75
+ nn.init.trunc_normal_(module.weight, std=0.02)
76
+ if hasattr(module, "padding_idx") and module.padding_idx is not None:
77
+ module.weight.data[module.padding_idx].zero_()
78
+ elif isinstance(module, nn.LayerNorm):
79
+ nn.init.ones_(module.weight)
80
+ nn.init.zeros_(module.bias)
81
+
82
+
83
 
84
  def _set_mc_dropout(self, active: bool = True):
85
  for m in self.modules():
 
99
 
100
  def mc_forward(
101
  self,
102
+ input_ids: torch.Tensor | None = None,
103
+ attention_mask: torch.Tensor | None = None,
104
+ n_samples: int = 10,
105
  max_batch_size: int | None = None,
106
+ return_dict: bool | None = None,
107
+ **kwargs,
108
  ) -> torch.Tensor:
109
  """
110
  Performs Monte Carlo Dropout inference to quantify epistemic uncertainty.
 
118
  Returns:
119
  Logits of shape (n_samples, B, num_labels).
120
  """
121
+ x = input_ids if input_ids is not None else kwargs.get("x")
122
+ mask = attention_mask if attention_mask is not None else kwargs.get("mask")
123
+
124
+ if x is None or mask is None:
125
+ raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
126
+
127
  if max_batch_size is None:
128
  max_batch_size = n_samples
129
 
130
+
131
  B, S = x.shape
132
  num_labels = self.classifier[-1].out_features
133
 
 
158
 
159
 
160
 
161
+ def forward(
162
+ self,
163
+ input_ids: torch.Tensor | None = None,
164
+ attention_mask: torch.Tensor | None = None,
165
+ return_dict: bool | None = None,
166
+ **kwargs,
167
+ ) -> torch.Tensor:
168
  """Standard forward pass without MC Dropout."""
169
+
170
+ x = input_ids if input_ids is not None else kwargs.get("x")
171
+ mask = attention_mask if attention_mask is not None else kwargs.get("mask")
172
+
173
+ if x is None or mask is None:
174
+ raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
175
+
176
  features = self.encoder(x, mask)
177
 
178
  pooled = self._masked_mean_pooling(features, mask)
179
  return self.classifier(pooled)
180
+
181
+
182
# Make the model discoverable through the transformers Auto* factories.
# The module can be imported more than once in a process, and a repeated
# registration raises ValueError, so each call is best-effort. They are
# guarded separately so that a ValueError from registering the config does
# not skip registering the model.
try:
    AutoConfig.register("emcoder", EmCoderConfig)
except ValueError:
    pass

try:
    AutoModel.register(EmCoderConfig, EmCoder)
except ValueError:
    pass