RaduGabriel committed on
Commit
2bcfa34
·
verified ·
1 Parent(s): 10a6706

Upload custom_modeling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_modeling.py +122 -0
custom_modeling.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, AutoModel, AutoConfig
2
+ from transformers.modeling_outputs import TokenClassifierOutput
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchcrf import CRF
6
+ from typing import Optional, Union, Tuple, List
7
+ import os
8
+ import json
9
+
10
class BertCRFPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and a simple
    interface for downloading and loading pretrained models."""

    config_class = AutoConfig
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        # Fall back to BERT's conventional 0.02 when the config omits the field.
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
28
+
29
class BertCRFForTokenClassification(BertCRFPreTrainedModel):
    """Token-classification head on top of a BERT-style encoder, with an
    optional CRF layer enabled via ``config.use_crf``.

    When the CRF is enabled the loss is the negative CRF log-likelihood;
    otherwise a plain token-level cross-entropy (ignoring ``-100``) is used.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = AutoModel.from_config(config)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.use_crf = getattr(config, "use_crf", False)
        if self.use_crf:
            self.crf = CRF(num_tags=self.num_labels, batch_first=True)
        else:
            self.crf = None
            self.loss_fn = nn.CrossEntropyLoss()

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Run the encoder, project to per-token logits, and (optionally)
        compute the training loss.

        Args:
            labels: per-token label ids; positions to be ignored carry -100
                (the CrossEntropyLoss default ignore index).
            Other arguments are forwarded unchanged to the underlying encoder.

        Returns:
            A ``TokenClassifierOutput`` (or plain tuple when
            ``return_dict=False``) holding the loss (if labels were given),
            the logits, and optional hidden states / attentions.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.crf is not None:
                # attention_mask is Optional — fall back to an all-ones mask
                # instead of crashing on .bool() when it is absent.
                if attention_mask is not None:
                    mask = attention_mask.bool()
                else:
                    mask = torch.ones_like(labels, dtype=torch.bool)
                mask = mask & (labels != -100)
                # torchcrf requires mask[:, 0] to be all True; special tokens
                # such as [CLS] are usually labelled -100, which would turn the
                # first timestep off and raise a ValueError. Force it on as the
                # standard workaround (the first position then contributes with
                # a placeholder tag).
                mask[:, 0] = True
                # torchcrf indexes tags even at masked positions, so -100 would
                # be an out-of-range tag id. Remap ignored positions to tag 0;
                # they are excluded from the score by the mask anyway.
                safe_labels = labels.clamp(min=0)
                loss = -self.crf(logits, safe_labels, mask=mask, reduction='mean')
            else:
                # CrossEntropyLoss ignores -100 positions by default.
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model, persisting the ``use_crf`` flag in the config and
        the CRF weights in a sidecar ``crf.pt`` file (the CRF parameters are
        also part of the main state dict; the sidecar keeps older loaders
        working)."""
        # Record whether this checkpoint uses a CRF so from_pretrained can
        # rebuild the right architecture.
        self.config.use_crf = self.use_crf
        self.config.save_pretrained(save_directory)

        # Save the model weights
        super().save_pretrained(save_directory, **kwargs)

        if self.crf is not None:
            crf_path = os.path.join(save_directory, "crf.pt")
            torch.save(self.crf.state_dict(), crf_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load the model, restoring CRF weights from ``crf.pt`` when present.

        ``os.path.exists`` is False for hub model ids, so the sidecar load is
        a local-directory-only best effort.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        crf_path = os.path.join(pretrained_model_name_or_path, "crf.pt")
        if os.path.exists(crf_path) and model.use_crf:
            # map_location avoids a device mismatch when the checkpoint was
            # written on GPU but is being loaded on a CPU-only host; the
            # parameters are moved to the model's device by load_state_dict.
            model.crf.load_state_dict(torch.load(crf_path, map_location="cpu"))

        return model