from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from torchcrf import CRF
from typing import List, Optional, Tuple, Union
import os


class TransformerCRFForTokenClassification(PreTrainedModel):
    """Token classifier with an optional CRF head on top of an AutoModel backbone."""

    # PreTrainedModel exposes a read-only base_model property, so the backbone
    # submodule cannot itself be named base_model; the prefix points at it instead.
    base_model_prefix = "transformer"

    def __init__(self, config):
        # Subclass PreTrainedModel: the Auto* classes raise an error when
        # instantiated directly, so they cannot serve as a base class.
        super().__init__(config)
        self.num_labels = config.num_labels

        # from_config builds the backbone from the config alone; it takes no
        # serialization arguments (use_safetensors belongs to from_pretrained).
        self.transformer = AutoModel.from_config(config)
        hidden_size = getattr(config, "hidden_size", 768)

        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden_size, config.num_labels)

        self.use_crf = getattr(config, "use_crf", False)
        if self.use_crf:
            self.crf = CRF(num_tags=self.num_labels, batch_first=True)
        else:
            self.crf = None
            self.loss_fn = nn.CrossEntropyLoss()

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

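        # token_type_ids and head_mask are accepted for signature compatibility
        # but deliberately not forwarded, since not every backbone accepts them.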
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        
        loss = None
        if labels is not None:
            if self.crf is not None:
                # torchcrf cannot handle the -100 ignore index: drop those
                # positions from the mask and substitute a valid tag for them.
                pad_mask = labels != -100
                mask = pad_mask if attention_mask is None else attention_mask.bool() & pad_mask
                crf_labels = labels.masked_fill(~pad_mask, 0)
                # torchcrf also requires the first timestep of every sequence
                # to be unmasked; the [CLS] label is usually -100, so force it.
                mask[:, 0] = True
                loss = -self.crf(logits, crf_labels, mask=mask, reduction="mean")
            else:
                # CrossEntropyLoss ignores -100-labelled positions by default.
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            # Already None unless requested via the arguments or the config.
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
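
    @torch.no_grad()
    def decode(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> List[List[int]]:
        # NOTE: illustrative sketch, not part of the original file. With a CRF
        # head, per-token argmax over the logits is not Viterbi-optimal, so
        # prediction should go through CRF.decode when the layer is present.
        logits = self(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        ).logits
        if self.crf is not None:
            mask = attention_mask.bool() if attention_mask is not None else None
            return self.crf.decode(logits, mask=mask)
        return logits.argmax(dim=-1).tolist()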

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model together with its custom CRF layer."""
        # Record the CRF flag so from_pretrained can rebuild the layer; the
        # config is written by super().save_pretrained, which also takes the
        # safe_serialization flag (PretrainedConfig.save_pretrained does not).
        self.config.use_crf = self.use_crf
        super().save_pretrained(save_directory, safe_serialization=True, **kwargs)

        # Keep a separate copy of the CRF parameters alongside the checkpoint.
        if self.crf is not None:
            crf_path = os.path.join(save_directory, "crf.pt")
            torch.save(self.crf.state_dict(), crf_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load the model together with its custom CRF layer."""
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        # Ensure use_crf is set in the configuration (default: no CRF head).
        if not hasattr(config, "use_crf"):
            config.use_crf = False

        # Load the backbone and classifier weights.
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, use_safetensors=True, **kwargs
        )

        # Rebuild the CRF and restore its weights if a separate checkpoint exists.
        if config.use_crf:
            model.crf = CRF(num_tags=config.num_labels, batch_first=True)
            crf_path = os.path.join(pretrained_model_name_or_path, "crf.pt")
            if os.path.exists(crf_path):
                model.crf.load_state_dict(torch.load(crf_path, map_location="cpu"))
        else:
            model.crf = None

        return model
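

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file; "bert-base-cased"
    # and the label count are illustrative placeholders.
    config = AutoConfig.from_pretrained("bert-base-cased", num_labels=9)
    config.use_crf = True
    model = TransformerCRFForTokenClassification(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    labels = torch.randint(0, config.num_labels, (2, 16))
    labels[:, 0] = -100  # special tokens are ignored, as in HF token-classification examples

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.loss, out.logits.shape)  # scalar CRF negative log-likelihood, (2, 16, 9)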