import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, BertForMaskedLM, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput

class StanceEncoderModel(PreTrainedModel):
    """BERT-based stance classifier that pools the hidden state(s) at the
    [MASK] position(s) and classifies them either through a verbalizer over
    the masked-LM head or through a small feed-forward head."""

    config_class = BertConfig
    logger = logging.getLogger("StanceEncoderModel")

    def __init__(self, config):
        super().__init__(config)
        # All model-specific hyperparameters travel in
        # config.task_specific_params, so the model can be re-instantiated
        # from a saved config alone.
        task_specific_params = config.task_specific_params
        self.num_labels = task_specific_params.get('num_labels', 3)
        self.mask_token_id = task_specific_params['mask_token_id']
        self.verbalizer_token_ids = task_specific_params['verbalizer_token_ids']
        self.clf_hidden_dim = task_specific_params.get('clf_hidden_dim', 300)
        self.clf_drop_prob = task_specific_params.get('clf_drop_prob', 0.2)
        self.clf_gelu_head = task_specific_params.get('clf_gelu_head', False)
        self.masked_lm = task_specific_params.get('masked_lm', True)
        self.masked_lm_n_tokens = task_specific_params.get('masked_lm_tokens', 1)
        self.masked_lm_verbalizer = task_specific_params.get('masked_lm_verbalizer', False)

        # Build a full BertForMaskedLM and keep its encoder and MLM head as
        # separate submodules.
        base_model = BertForMaskedLM(config)
        self.base_enc_model = base_model.bert
        self.lm_head = base_model.cls
        hidden_size_multiplier = 1
        # Without a verbalizer, logits come from a dedicated classifier head.
        if not self.masked_lm_verbalizer:
            if self.clf_gelu_head:
                self.logger.info('using 2-layer GELU classifier head')
                self.classifier = torch.nn.Sequential(
                    torch.nn.Linear(self.config.hidden_size * hidden_size_multiplier, self.clf_hidden_dim),
                    torch.nn.Dropout(self.clf_drop_prob),
                    torch.nn.GELU(),
                    torch.nn.Linear(self.clf_hidden_dim, self.num_labels),
                )
            else:
                raise ValueError('classification head type not specified')

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        sequence_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # Encode the inputs, forwarding the arguments the encoder understands
        # (the original call dropped them silently); return_dict=True so the
        # attribute access below is always valid.
        outputs = self.base_enc_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Pool the hidden state(s) at the [MASK] position(s); flattening per
        # row concatenates the representations when a sequence contains more
        # than one mask token.
        masked_token_filter = input_ids == self.mask_token_id
        masked_repr = outputs.last_hidden_state[masked_token_filter].reshape(len(input_ids), -1)

        if self.masked_lm_verbalizer:
            # Score the pooled representation with the MLM head and keep only
            # the vocabulary logits of the verbalizer tokens.
            logits = self.lm_head(masked_repr)[:, self.verbalizer_token_ids]
        else:
            logits = self.classifier(masked_repr)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
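

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of wiring up the config this model expects. The
# checkpoint name and the verbalizer words below are assumptions for
# illustration, and the weights here come out randomly initialized (the
# class builds BertForMaskedLM from a bare config, not from a checkpoint).
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased")
    config.task_specific_params = {
        "mask_token_id": tokenizer.mask_token_id,
        "verbalizer_token_ids": tokenizer.convert_tokens_to_ids(
            ["against", "favor", "none"]  # hypothetical 3-way stance verbalizer
        ),
        "num_labels": 3,
        "clf_gelu_head": True,  # select the 2-layer GELU classifier head
    }
    model = StanceEncoderModel(config)

    # The prompt must contain exactly one [MASK] per sequence so that
    # masked_repr reshapes to (batch_size, hidden_size).
    inputs = tokenizer("Vaccines are safe. Stance: [MASK].", return_tensors="pt")
    out = model(**inputs)
    print(out.logits.shape)  # torch.Size([1, 3])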