Update modeling_generanno.py — modeling_generanno.py (+45 −14, changed)
|
@@ -1140,6 +1140,7 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
|
|
| 1140 |
|
| 1141 |
self.model = GenerannoModel(config)
|
| 1142 |
self.feature_layer = getattr(config, "feature_layer", -1)
|
|
|
|
| 1143 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1144 |
if getattr(config, "use_mlp_classifier", False):
|
| 1145 |
self.score = nn.Sequential(
|
|
@@ -1152,6 +1153,23 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
|
|
| 1152 |
# Initialize weights and apply final processing
|
| 1153 |
self.post_init()
|
| 1154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1155 |
def forward(
|
| 1156 |
self,
|
| 1157 |
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -1173,20 +1191,33 @@ class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
|
|
| 1173 |
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1174 |
)
|
| 1175 |
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
|
| 1186 |
-
|
| 1187 |
-
|
| 1188 |
-
|
| 1189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
logits = self.score(pooled_hidden_states)
|
| 1191 |
|
| 1192 |
loss = None
|
|
|
|
| 1140 |
|
| 1141 |
self.model = GenerannoModel(config)
|
| 1142 |
self.feature_layer = getattr(config, "feature_layer", -1)
|
| 1143 |
+
self.use_mean_pooling = getattr(config, "use_mean_pooling", True)
|
| 1144 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1145 |
if getattr(config, "use_mlp_classifier", False):
|
| 1146 |
self.score = nn.Sequential(
|
|
|
|
| 1153 |
# Initialize weights and apply final processing
|
| 1154 |
self.post_init()
|
| 1155 |
|
| 1156 |
+
def _apply_mean_pooling(self, hidden_states, attention_mask):
|
| 1157 |
+
if attention_mask is None:
|
| 1158 |
+
return torch.mean(hidden_states, dim=1)
|
| 1159 |
+
|
| 1160 |
+
# Expand attention mask to match hidden states dimensions
|
| 1161 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
| 1162 |
+
sum_embeddings = torch.sum(hidden_states * input_mask_expanded, dim=1)
|
| 1163 |
+
|
| 1164 |
+
# Compute number of valid tokens per sequence
|
| 1165 |
+
sum_mask = input_mask_expanded.sum(dim=1)
|
| 1166 |
+
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
| 1167 |
+
|
| 1168 |
+
# Compute mean
|
| 1169 |
+
pooled_output = sum_embeddings / sum_mask
|
| 1170 |
+
|
| 1171 |
+
return pooled_output
|
| 1172 |
+
|
| 1173 |
def forward(
|
| 1174 |
self,
|
| 1175 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 1191 |
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1192 |
)
|
| 1193 |
|
| 1194 |
+
if self.feature_layer == -1:
|
| 1195 |
+
outputs = self.model(
|
| 1196 |
+
input_ids,
|
| 1197 |
+
attention_mask=attention_mask,
|
| 1198 |
+
position_ids=position_ids,
|
| 1199 |
+
inputs_embeds=inputs_embeds,
|
| 1200 |
+
output_attentions=output_attentions,
|
| 1201 |
+
output_hidden_states=output_hidden_states,
|
| 1202 |
+
return_dict=return_dict,
|
| 1203 |
+
)
|
| 1204 |
+
hidden_states = outputs[0]
|
| 1205 |
+
else:
|
| 1206 |
+
outputs = self.model(
|
| 1207 |
+
input_ids,
|
| 1208 |
+
attention_mask=attention_mask,
|
| 1209 |
+
position_ids=position_ids,
|
| 1210 |
+
inputs_embeds=inputs_embeds,
|
| 1211 |
+
output_attentions=output_attentions,
|
| 1212 |
+
output_hidden_states=True,
|
| 1213 |
+
return_dict=return_dict,
|
| 1214 |
+
)
|
| 1215 |
+
hidden_states = outputs.hidden_states[self.feature_layer]
|
| 1216 |
+
|
| 1217 |
+
if self.use_mean_pooling:
|
| 1218 |
+
pooled_hidden_states = self._apply_mean_pooling(hidden_states, attention_mask)
|
| 1219 |
+
else:
|
| 1220 |
+
pooled_hidden_states = hidden_states[:, 0]
|
| 1221 |
logits = self.score(pooled_hidden_states)
|
| 1222 |
|
| 1223 |
loss = None
|