Kalaoke committed on
Commit
8a733b9
·
1 Parent(s): 4f31466

Delete bert_for_sequence_classification.py

Browse files
Files changed (1) hide show
  1. bert_for_sequence_classification.py +0 -139
bert_for_sequence_classification.py DELETED
@@ -1,139 +0,0 @@
1
- import torch
2
- import transformers
3
- from torch import nn
4
- from torch.nn import CrossEntropyLoss
5
- from typing import Optional, Tuple, Union
6
- from transformers.modeling_outputs import SequenceClassifierOutput
7
- from transformers.models.bert.modeling_bert import (
8
- BertPreTrainedModel,
9
- BERT_INPUTS_DOCSTRING,
10
- _TOKENIZER_FOR_DOC,
11
- _CHECKPOINT_FOR_DOC,
12
- BERT_START_DOCSTRING,
13
- _CONFIG_FOR_DOC,
14
- _SEQ_CLASS_EXPECTED_OUTPUT,
15
- _SEQ_CLASS_EXPECTED_LOSS,
16
- BertModel,
17
- )
18
-
19
- from transformers.file_utils import (
20
- add_code_sample_docstrings,
21
- add_start_docstrings_to_model_forward,
22
- add_start_docstrings
23
- )
24
-
25
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    """Multi-task BERT classifier with one linear head per task.

    ``kwargs["tasks_map"]`` is expected to map an integer task id to an object
    exposing ``num_labels`` — e.g. ``{0: <task, num_labels=2>, 1: <task,
    num_labels=5>}``. NOTE(review): inferred from the ``self.tasks[0]`` /
    ``self.tasks[1]`` indexing below; confirm against callers.
    """

    def __init__(self, config, **kwargs):
        # BUG FIX: the original passed a fresh, empty
        # ``transformers.PretrainedConfig()`` to ``super().__init__`` and then
        # patched ``self.config`` afterwards. Pass the real config so the base
        # ``PreTrainedModel`` is initialized consistently (it stores
        # ``self.config`` itself, so the manual reassignment is dropped).
        super().__init__(config)
        # task_labels_map example: {"binary_classification": 2, "label_classification": 5}
        self.tasks = kwargs.get("tasks_map", {})

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Task-specific output heads: task id 0 -> classifier1, task id 1 -> classifier2.
        self.classifier1 = nn.Linear(config.hidden_size, self.tasks[0].num_labels)
        self.classifier2 = nn.Linear(config.hidden_size, self.tasks[1].num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task_ids=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[1] is the pooled [CLS] representation, shape (batch, hidden_size).
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        unique_task_ids_list = torch.unique(task_ids).tolist()
        loss_list = []
        logits = None
        # BUG FIX: ``loss`` was previously only bound inside the loop, so an
        # empty ``unique_task_ids_list`` raised NameError at the return below.
        loss = None
        for unique_task_id in unique_task_ids_list:
            # Boolean mask selecting the rows of this batch that belong to this task.
            task_id_filter = task_ids == unique_task_id

            if unique_task_id == 0:
                logits = self.classifier1(pooled_output[task_id_filter])
            elif unique_task_id == 1:
                logits = self.classifier2(pooled_output[task_id_filter])

            if labels is not None:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.tasks[unique_task_id].num_labels),
                    labels[task_id_filter].view(-1),
                )
                loss_list.append(loss)

        # logits are only used for eval, and in eval the batch is not multi-task,
        # so the last computed logits cover the whole batch. For training only
        # the (averaged) loss is used.
        if loss_list:
            loss = torch.stack(loss_list).mean()
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )