import torch
import torch.nn as nn
from transformers import BertModel, PreTrainedModel

from .configuration_smsbert import SMSBertConfig


class SMSBertModel(PreTrainedModel):
    """BERT backbone with two heads: token-level entity tagging and whole-message classification."""

    config_class = SMSBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.2)
        # Token head: one logit per token for each of the 16 entity labels.
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, 16)
        # Sequence head: one logit per message for each of the 7 categories.
        self.sequence_classifier = nn.Linear(self.bert.config.hidden_size, 7)

        # Initialize classifier head weights.
        nn.init.kaiming_normal_(self.token_classifier.weight, mode="fan_in", nonlinearity="linear")
        nn.init.kaiming_normal_(self.sequence_classifier.weight, mode="fan_in", nonlinearity="linear")

        self.seq_labels = [
            "Transaction",
            "Courier",
            "OTP",
            "Expiry",
            "Misc",
            "Tele Marketing",
            "Spam",
        ]
        self.token_class_labels = [
            "O",
            "Courier Service",
            "Credit",
            "Date",
            "Debit",
            "Email",
            "Expiry",
            "Item",
            "Order ID",
            "Organization",
            "OTP",
            "Phone Number",
            "Refund",
            "Time",
            "Tracking ID",
            "URL",
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
    ):
        """Run both heads and decode predictions to label strings (inference-only)."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Per-token states feed the token head; the pooled [CLS] state feeds the sequence head.
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        token_classification_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))

        # Decode predictions for the first (and assumed only) item in the batch.
        token_predictions = token_classification_logits.argmax(dim=2)[0]
        sequence_prediction = sequence_logits.argmax(dim=1)[0].item()

        token_classification_out = [self.token_class_labels[i] for i in token_predictions.tolist()]
        seq_classification_out = self.seq_labels[sequence_prediction]
        return {
            "token_classifier": token_classification_out,
            "sequence_classifier": seq_classification_out,
        }
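
# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal inference example, assuming SMSBertConfig() yields usable BERT-style
# defaults and that a standard WordPiece tokenizer is compatible with the
# checkpoint. The "bert-base-uncased" tokenizer and the randomly initialized
# config below are placeholders; in practice you would load fine-tuned weights
# with SMSBertModel.from_pretrained(...). Run as a module (python -m ...) since
# this file uses a relative import.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed tokenizer
    model = SMSBertModel(SMSBertConfig())  # random weights; for illustration only
    model.eval()

    encoded = tokenizer("Your OTP is 482913, valid for 10 minutes.", return_tensors="pt")
    with torch.no_grad():
        predictions = model(**encoded)
    print(predictions["sequence_classifier"])  # one of the 7 message categories
    print(predictions["token_classifier"])     # one entity label per token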