GenerTeam committed on
Commit
a92c7ec
·
verified ·
1 Parent(s): de2cb8b

Update modeling_generanno.py

Browse files
Files changed (1) hide show
  1. modeling_generanno.py +45 -14
modeling_generanno.py CHANGED
@@ -1140,6 +1140,7 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
1140
 
1141
  self.model = GenerannoModel(config)
1142
  self.feature_layer = getattr(config, "feature_layer", -1)
 
1143
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1144
  if getattr(config, "use_mlp_classifier", False):
1145
  self.score = nn.Sequential(
@@ -1152,6 +1153,23 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
1152
  # Initialize weights and apply final processing
1153
  self.post_init()
1154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1155
  def forward(
1156
  self,
1157
  input_ids: Optional[torch.LongTensor] = None,
@@ -1173,20 +1191,33 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
1173
  return_dict if return_dict is not None else self.config.use_return_dict
1174
  )
1175
 
1176
- output_hidden_states = True
1177
- outputs = self.model(
1178
- input_ids,
1179
- attention_mask=attention_mask,
1180
- position_ids=position_ids,
1181
- inputs_embeds=inputs_embeds,
1182
- output_attentions=output_attentions,
1183
- output_hidden_states=output_hidden_states,
1184
- return_dict=return_dict,
1185
- )
1186
- hidden_states = outputs["hidden_states"][
1187
- self.feature_layer if hasattr(self, "feature_layer") else -1
1188
- ]
1189
- pooled_hidden_states = hidden_states[:, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
1190
  logits = self.score(pooled_hidden_states)
1191
 
1192
  loss = None
 
1140
 
1141
  self.model = GenerannoModel(config)
1142
  self.feature_layer = getattr(config, "feature_layer", -1)
1143
+ self.use_mean_pooling = getattr(config, "use_mean_pooling", True)
1144
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1145
  if getattr(config, "use_mlp_classifier", False):
1146
  self.score = nn.Sequential(
 
1153
  # Initialize weights and apply final processing
1154
  self.post_init()
1155
 
1156
+ def _apply_mean_pooling(self, hidden_states, attention_mask):
1157
+ if attention_mask is None:
1158
+ return torch.mean(hidden_states, dim=1)
1159
+
1160
+ # Expand attention mask to match hidden states dimensions
1161
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
1162
+ sum_embeddings = torch.sum(hidden_states * input_mask_expanded, dim=1)
1163
+
1164
+ # Compute number of valid tokens per sequence
1165
+ sum_mask = input_mask_expanded.sum(dim=1)
1166
+ sum_mask = torch.clamp(sum_mask, min=1e-9)
1167
+
1168
+ # Compute mean
1169
+ pooled_output = sum_embeddings / sum_mask
1170
+
1171
+ return pooled_output
1172
+
1173
  def forward(
1174
  self,
1175
  input_ids: Optional[torch.LongTensor] = None,
 
1191
  return_dict if return_dict is not None else self.config.use_return_dict
1192
  )
1193
 
1194
+ if self.feature_layer == -1:
1195
+ outputs = self.model(
1196
+ input_ids,
1197
+ attention_mask=attention_mask,
1198
+ position_ids=position_ids,
1199
+ inputs_embeds=inputs_embeds,
1200
+ output_attentions=output_attentions,
1201
+ output_hidden_states=output_hidden_states,
1202
+ return_dict=return_dict,
1203
+ )
1204
+ hidden_states = outputs[0]
1205
+ else:
1206
+ outputs = self.model(
1207
+ input_ids,
1208
+ attention_mask=attention_mask,
1209
+ position_ids=position_ids,
1210
+ inputs_embeds=inputs_embeds,
1211
+ output_attentions=output_attentions,
1212
+ output_hidden_states=True,
1213
+ return_dict=return_dict,
1214
+ )
1215
+ hidden_states = outputs.hidden_states[self.feature_layer]
1216
+
1217
+ if self.use_mean_pooling:
1218
+ pooled_hidden_states = self._apply_mean_pooling(hidden_states, attention_mask)
1219
+ else:
1220
+ pooled_hidden_states = hidden_states[:, 0]
1221
  logits = self.score(pooled_hidden_states)
1222
 
1223
  loss = None