File size: 3,676 Bytes
9725a75
 
 
858483a
9725a75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858483a
9725a75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from typing import Dict, List, Any
import torch
import torch.nn as nn
import json
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig

class EndpointHandler:
    """Callable inference handler wrapping a ``CustomModel`` instance.

    The handler builds the model once at construction time and, on each
    call, classifies the request payload and returns the result as a JSON
    string.
    """

    def __init__(self, path=""):
        """Build the model and load its trained weights.

        Args:
            path: Accepted for handler-interface compatibility.
                NOTE(review): it is currently unused -- the config and
                checkpoint file names below are resolved relative to the
                process working directory; confirm that is intended.
        """
        self.model = CustomModel("test_bert_config.json")

        # map_location="cpu" lets a checkpoint that was saved from a CUDA
        # device load on a CPU-only host. The model is never moved to
        # another device here, so the loaded tensors end up on CPU anyway.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load("model3.pth", map_location="cpu")
        self.model.load_state_dict(state_dict)

        # nn.Module instances start in training mode; switch to eval so the
        # model's dropout layer is disabled and predictions are deterministic.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> str:
        """Classify the request payload.

        Args:
            data: Request payload. The value under ``"inputs"`` is removed
                from ``data`` and classified; when the key is absent the
                whole ``data`` mapping is passed to the model instead.

        Returns:
            The dict returned by ``self.model.classify`` serialized with
            ``json.dumps`` (i.e. a JSON string, not a list/dict).
        """
        inputs = data.pop("inputs", data)
        prediction = self.model.classify(inputs)
        return json.dumps(prediction)
    
class CustomModel(nn.Module):
    """BERT encoder with two linear heads: per-token and per-sequence.

    ``token_classifier`` is applied to the encoder's ``last_hidden_state``
    and ``sequence_classifier`` to its ``pooler_output``. ``classify`` turns
    raw text into label names using the label tables below.
    """

    # Label tables; the index of each entry is the class id produced by the
    # corresponding head.
    SEQ_LABELS = (
        "Transaction",
        "Courier",
        "OTP",
        "Expiry",
        "Misc",
        "Tele Marketing",
        "Spam",
    )

    TOKEN_CLASS_LABELS = (
        'O',
        'Courier Service',
        'Credit',
        'Date',
        'Debit',
        'Email',
        'Expiry',
        'Item',
        'Order ID',
        'Organization',
        'OTP',
        'Phone Number',
        'Refund',
        'Time',
        'Tracking ID',
        'URL',
    )

    def __init__(self, bert_config, tokenizer_path="."):
        """Create the encoder, both heads and the tokenizer.

        Args:
            bert_config: Path to a JSON config file used to build the
                (randomly initialized) BERT encoder.
            tokenizer_path: Directory or identifier passed to
                ``AutoTokenizer.from_pretrained``. Defaults to ``"."``,
                the value that was previously hard-coded.
        """
        super().__init__()

        # Instance-level copies so existing code reading or mutating
        # ``self.seq_labels`` / ``self.token_class_labels`` keeps working.
        self.seq_labels = list(self.SEQ_LABELS)
        self.token_class_labels = list(self.TOKEN_CLASS_LABELS)

        # NOTE(review): ``_from_config`` is a private transformers API and
        # the config is loaded as a generic ``PretrainedConfig``; kept as is
        # to preserve behavior -- consider the public BertConfig/BertModel
        # constructors.
        self.bert = BertModel._from_config(PretrainedConfig.from_json_file(bert_config))
        self.dropout = nn.Dropout(0.2)

        # Head widths follow the label tables (16 token labels, 7 sequence
        # labels) so the two can never drift apart.
        hidden_size = self.bert.config.hidden_size
        self.token_classifier = nn.Linear(hidden_size, len(self.token_class_labels))
        self.sequence_classifier = nn.Linear(hidden_size, len(self.seq_labels))

        # Initialize head weights (biases keep nn.Linear's default init).
        nn.init.kaiming_normal_(self.token_classifier.weight, mode='fan_in', nonlinearity='linear')
        nn.init.kaiming_normal_(self.sequence_classifier.weight, mode='fan_in', nonlinearity='linear')

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
        """Run the encoder and both classification heads.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            token_type_ids: Segment ids forwarded to the BERT encoder.

        Returns:
            A ``(token_logits, sequence_logits)`` tuple: the token head
            applied to ``last_hidden_state`` and the sequence head applied
            to ``pooler_output``, each after dropout.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output

        token_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))

        return token_logits, sequence_logits

    def classify(self, inputs):
        """Predict token-level and sequence-level labels for ``inputs``.

        Inference always runs with dropout disabled and without gradient
        tracking; the module's previous train/eval mode is restored
        afterwards.

        Args:
            inputs: Text handed to the tokenizer. Only the first sequence of
                the encoded batch is decoded.
                NOTE(review): presumably a single string -- no padding is
                requested, so a list of texts of different lengths is not
                expected to work; confirm against callers.

        Returns:
            A dict with ``"token_classfier"`` (one label name per tokenizer
            output position) and ``"sequence_classfier"`` (one label name).
            The key spellings are kept unchanged because they are part of
            the emitted payload.
        """
        # Truncate to the encoder's position-embedding table so over-long
        # text is clipped instead of crashing inside the encoder. When the
        # config has no such field, None makes the tokenizer fall back to
        # its own configured maximum length.
        max_length = getattr(self.bert.config, "max_position_embeddings", None)
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                token_logits, sequence_logits = self(**encoded)
        finally:
            self.train(was_training)

        token_ids = token_logits.argmax(dim=2)[0].tolist()
        sequence_id = sequence_logits.argmax(dim=1)[0].item()

        return {
            "token_classfier": [self.token_class_labels[i] for i in token_ids],
            "sequence_classfier": self.seq_labels[sequence_id],
        }