File size: 3,986 Bytes
88420b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn as nn
from .configuration_smsbert import SMSBertConfig
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig,PreTrainedModel, Pipeline, AutoModel,AutoModelForSequenceClassification, BertConfig
class SMSBertModel(PreTrainedModel):
    """BERT encoder with two heads: per-token tagging and whole-message classification.

    One forward pass yields a label from ``token_class_labels`` for every token
    position, and a single label from ``seq_labels`` for the whole input. The
    result is returned as the ``str()`` of a dict rather than as logits, so this
    module is meant for inference, not for training with a loss.
    """

    config_class = SMSBertConfig

    def __init__(self, config):
        super().__init__(config)

        # Message-level categories; index i corresponds to output unit i of
        # ``sequence_classifier``.
        self.seq_labels = [
            "Transaction",
            "Courier",
            "OTP",
            "Expiry",
            "Misc",
            "Tele Marketing",
            "Spam",
        ]

        # Per-token tags; index i corresponds to output unit i of
        # ``token_classifier``. 'O' marks tokens outside any entity.
        self.token_class_labels = [
            'O',
            'Courier Service',
            'Credit',
            'Date',
            'Debit',
            'Email',
            'Expiry',
            'Item',
            'Order ID',
            'Organization',
            'OTP',
            'Phone Number',
            'Refund',
            'Time',
            'Tracking ID',
            'URL',
        ]

        # Encoder built from the config alone (no weights loaded here);
        # presumably trained weights arrive via ``from_pretrained`` — confirm.
        self.bert = BertModel._from_config(config)
        self.dropout = nn.Dropout(0.2)

        # Head widths come from the label lists so the two cannot drift apart
        # (previously hard-coded as 16 and 7, which equal these lengths).
        hidden_size = self.bert.config.hidden_size
        self.token_classifier = nn.Linear(hidden_size, len(self.token_class_labels))
        self.sequence_classifier = nn.Linear(hidden_size, len(self.seq_labels))

        # Re-initialise the head weights (biases keep nn.Linear's default init).
        nn.init.kaiming_normal_(self.token_classifier.weight, mode='fan_in', nonlinearity='linear')
        nn.init.kaiming_normal_(self.sequence_classifier.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Classify a tokenised message and return its labels as a string.

        Args:
            input_ids: Token ids indexed as ``[batch, seq]``. Only batch item 0
                is decoded, so callers are expected to pass a single message.
            attention_mask: Optional attention mask forwarded to BERT. When
                ``None``, BERT's own default (attend to every position) applies.
            token_type_ids: Optional segment ids forwarded to BERT. When
                ``None``, BERT's own default applies.

        Returns:
            ``str`` of a dict with keys ``"token_classfier"`` (one tag name per
            token position, including special/padding positions) and
            ``"sequence_classfier"`` (one category name). The key spellings are
            kept unchanged because downstream code may parse them.
        """
        # Keyword arguments plus return_dict=True make the attribute access
        # below independent of argument order and of config.use_return_dict.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = outputs.pooler_output

        # Dropout is the identity in eval() mode; in train() mode the decoded
        # labels below would be stochastic.
        token_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))

        # Decode only the first example of the batch.
        token_pred_ids = token_logits.argmax(2)[0].tolist()
        sequence_pred_id = int(sequence_logits.argmax(1)[0])

        token_classification_out = [self.token_class_labels[i] for i in token_pred_ids]
        seq_classification_out = self.seq_labels[sequence_pred_id]

        model_out = {
            "token_classfier": token_classification_out,
            "sequence_classfier": seq_classification_out,
        }
        return str(model_out)