File size: 3,123 Bytes
d9c317f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import os
import torch.nn as nn
from transformers import RobertaPreTrainedModel, RobertaModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput

class TransformerForABSA(RobertaPreTrainedModel):
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Thêm lớp "none" vào num_sentiments
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)  # +1 cho "none"
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(

        self,

        input_ids=None,

        attention_mask=None,

        labels=None,

        return_dict=None

    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict)
        pooled = self.dropout(outputs.pooler_output)  # [B, H]
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)  # [B, A, S+1]

        loss = None
        if labels is not None:
            # labels: [B, A], với lớp "none" thay vì -100
            B, A, _ = all_logits.size()
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss_fct = nn.CrossEntropyLoss()  # Không dùng ignore_index
            loss = loss_fct(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]
        return SequenceClassifierOutput(
            loss=loss,
            logits=all_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
            """

            HuggingFace Trainer đôi khi truyền vào state_dict=..., nên ta

            chấp nhận thêm **kwargs để không vướng lỗi.

            """
            # 1) Lưu phần backbone (encoder)
            self.roberta.save_pretrained(save_directory, **kwargs)  

            # 2) Cập nhật config rồi lưu
            config = self.roberta.config
            config.num_aspects = len(self.sentiment_classifiers)
            config.num_sentiments = self.sentiment_classifiers[0].out_features
            config.auto_map       = {"AutoModel": "models.TransformerForABSA"}
            config.save_pretrained(save_directory, **kwargs)

            # 3) Lưu toàn bộ state_dict (bao gồm cả 2 head) — 
            #    nếu Trainer đã truyền state_dict trong kwargs, có thể dùng luôn,
            #    nếu không, lấy từ self.state_dict()
            sd = kwargs.get("state_dict", None) or self.state_dict()
            torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))