# File size: 15,078 Bytes
# 8d3380d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import torch
import torch.nn as nn
from transformers import DistilBertForTokenClassification, AutoTokenizer, AutoModelForTokenClassification
from torch.utils.data import Dataset, DataLoader, TensorDataset
import json
import gc

class BertNER(nn.Module):
    """
    A custom PyTorch Module for Named Entity Recognition (NER) using DistilBertForTokenClassification.
    """

    def __init__(self, token_dims):
        """
        Initializes the BertNER model.

        Parameters:
        token_dims (int): The number of unique tokens/labels in the NER task.

        Raises:
        TypeError: If token_dims is not an integer (booleans are rejected too).
        ValueError: If token_dims is not strictly positive.
        """
        super().__init__()
        # bool is a subclass of int; reject it explicitly so True/False are not
        # silently accepted as a label count.
        if isinstance(token_dims, bool) or not isinstance(token_dims, int):
            raise TypeError("Token Dimensions should be an integer")
        if token_dims <= 0:
            raise ValueError("Token Dimensions should be a positive integer")

        self.pretrained_model = DistilBertForTokenClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=token_dims,
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass of the model.

        Parameters:
        input_ids (torch.Tensor): Tensor of token ids to be fed to DistilBERT.
        attention_mask (torch.Tensor): Tensor indicating which tokens should be attended to by the model.
        labels (torch.Tensor, optional): Tensor of actual labels, forwarded unchanged to the wrapped model.

        Returns:
        Whatever the wrapped model returns for these arguments.
        """
        # A single call covers both cases: labels=None is forwarded as-is, so
        # the wrapped model is never evaluated twice for one batch.
        return self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

class SentenceDataset(Dataset):
    """
    Custom Dataset class for sentences, handling tokenization and preparing inputs for the NER model.

    Each sentence is split on whitespace once at construction time; items are
    tokenized individually in __getitem__.
    """

    def __init__(self, sentences, tokenizer, max_length=256):
        """
        Initializes the SentenceDataset.

        Parameters:
        sentences (list of str): The list of sentences to be processed.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for converting sentences to model inputs.
        max_length (int): Maximum length of the tokenized output.
        """
        super().__init__()
        self.sentences = [sentence.split() for sentence in sentences]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # is_split_into_words=True expects word lists, so the pre-split
        # sentences are passed here rather than the raw strings (a flat list of
        # strings would be read as ONE pre-tokenized sequence).
        self.text = self.tokenizer(
            self.sentences,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
            is_split_into_words=True,
        )

    def __len__(self):
        """Returns the number of sentences in the dataset."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset by index.

        Parameters:
        idx (int): Index of the item to retrieve.

        Returns:
        A dictionary containing input_ids, attention_mask, word_ids (with None replaced by -1),
        and the full list of split sentences.
        """
        words = self.sentences[idx]
        encoded = self.tokenizer(
            words,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
            is_split_into_words=True,
        )
        # word_ids() yields None for tokens that belong to no word; map those to
        # -1 so the list holds only integers.
        word_ids = [-1 if word_id is None else word_id for word_id in encoded.word_ids()]
        # NOTE(review): "sentences" deliberately stays the FULL sentence list.
        # Callers in this file index data['sentences'][0], and a per-item word
        # list would presumably be ragged under the default DataLoader collate
        # -- confirm before changing this key.
        return {
            "input_ids": encoded.input_ids.squeeze(0),
            "attention_mask": encoded.attention_mask.squeeze(0),
            'word_ids': word_ids,
            "sentences": self.sentences,
        }
    
class NERWrapper:
    """
    A wrapper class for the Named Entity Recognition (NER) model, simplifying the process of model loading,
    prediction, and utility functions.
    """

    def __init__(self, model_path, idx2tag_path, tokenizer_path='distilbert-base-uncased', token_dims=17):
        """
        Initializes the NERWrapper.

        Parameters:
        model_path (str): Path to the pre-trained NER model.
        idx2tag_path (str): Path to the index-to-tag mapping file, for decoding model predictions.
        tokenizer_path (str): Path or identifier for the tokenizer to be used.
        token_dims (int): The number of unique tokens/labels in the NER task.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
        self.model = BertNER(token_dims=token_dims)
        self.idx2tag = self.load_idx2tag(idx2tag_path)
        self.load_model(model_path)

    def load_model(self, model_path):
        """
        Loads the model weights from a specified path into self.model.

        Parameters:
        model_path (str): Path to a checkpoint containing a 'model_state_dict' entry.
        """
        map_location = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def load_idx2tag(self, idx2tag_path):
        """
        Loads the index-to-tag mapping from a specified JSON file.

        Parameters:
        idx2tag_path (str): Path to the index-to-tag mapping file.

        Returns:
        dict: A dictionary mapping integer indices to tags. If the file's top-level
        value is not a JSON object, it is returned unchanged.
        """
        with open(idx2tag_path, 'r', encoding='utf-8') as file:
            idx2tag = json.load(file)
        # JSON object keys are always strings; convert them back to ints so the
        # mapping can be indexed with predicted label ids.
        if isinstance(idx2tag, dict):
            return {int(key): value for key, value in idx2tag.items()}
        return idx2tag

    def align_word_ids(self, texts, input_tensor, label_all_tokens=False):
        """
        Builds a keep/ignore marker for every token position from its word id.

        Parameters:
        texts: Unused; kept for interface compatibility.
        input_tensor (torch.Tensor): Tensor of word ids, one row per sentence, with -1 marking
            tokens that belong to no word.
        label_all_tokens (bool): Whether to mark all tokens of a word or only its first token.

        Returns:
        list of list of int: One list per row, holding 1 for positions to keep and -100 for
        positions to ignore.
        """
        all_label_ids = []
        for word_ids in input_tensor:
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids.tolist():
                if word_idx == -1:
                    # Token belongs to no word.
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First token of a new word.
                    label_ids.append(1)
                else:
                    # Continuation token of the same word.
                    label_ids.append(1 if label_all_tokens else -100)
                previous_word_idx = word_idx
            all_label_ids.append(label_ids)
        return all_label_ids

    def evaluate_text(self, sentences):
        """
        Evaluates texts using the NER model, returning the prediction results.

        Parameters:
        sentences (list of str): List of sentences to evaluate.

        Returns:
        list of str: The modified sentences with identified entities replaced with special tokens (e.g., <PER>).
        An empty input yields an empty list.
        """
        if not sentences:
            return []

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        # Inference only: make sure dropout layers are disabled.
        self.model.eval()
        dataset = SentenceDataset(sentences, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
        predictions = []
        for data in dataloader:
            mask = data['attention_mask'].to(device)
            input_id = data['input_ids'].to(device)
            # The collated word ids arrive as one tensor per token position;
            # stack and transpose to get one row of word ids per sentence.
            word_id_rows = torch.stack(data['word_ids']).t()
            label_ids = torch.tensor(
                self.align_word_ids(data['sentences'][0], word_id_rows),
                device=device,
            )
            # No gradients are needed for prediction.
            with torch.no_grad():
                logits = self.model(input_id, mask, None).logits
            for row_logits, row_labels in zip(logits, label_ids):
                # Keep only the positions marked as the first token of a word,
                # then pick the most likely label for each of them.
                kept_logits = row_logits[row_labels != -100]
                predictions.append(kept_logits.argmax(dim=1).tolist())
            del mask, input_id, label_ids
            gc.collect()
            torch.cuda.empty_cache()

        # Decode once, after all batches have been collected.
        prediction_label = [[self.idx2tag[i] for i in prediction] for prediction in predictions]
        return self.replace_sentence_with_tokens(
            [sentence.split() for sentence in sentences],
            prediction_label,
        )

    def replace_sentence_with_tokens(self, sentences, prediction_labels):
        """
        Replaces identified person entities in sentences with the <PER> token.

        A word tagged 'B-per' becomes '<PER>', a word tagged 'I-per' is dropped, and any other
        word is kept. Words and tags are paired positionally; pairing stops at the shorter of
        the two sequences.

        Parameters:
        sentences (list of list of str): Tokenized sentences.
        prediction_labels (list of list of str): Labels predicted by the model for each token.

        Returns:
        list of str: Modified sentences with entities replaced by special tokens.
        """
        modified_sentences = []
        for words, tags in zip(sentences, prediction_labels):
            modified_sentence = []
            for word, tag in zip(words, tags):
                if tag == 'B-per':
                    modified_sentence.append('<PER>')
                elif tag != 'I-per':
                    modified_sentence.append(word)
                # 'I-per' words continue a person entity and are dropped.
            modified_sentences.append(" ".join(modified_sentence))

        return modified_sentences

class NextPassNERWrapper:
    """
    This class wraps around a pretrained BERT model for Named Entity Recognition (NER) tasks,
    simplifying the process of sentence processing, entity recognition, and sentence reconstruction
    with entity tags.
    """

    # Tokens that never belong to the reconstructed sentence.
    _SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")
    # Entity labels that are replaced by the <PER> placeholder.
    _PERSON_TAGS = ("B-PER", "I-PER")

    def __init__(self):
        """
        Initializes the wrapper by loading the "dslim/bert-base-NER" tokenizer and model from
        Hugging Face's transformers library, moving the model to the GPU if available (otherwise
        CPU), and establishing a mapping from model output indices to entity types.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
        self.model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        # Inference only: make sure dropout layers are disabled.
        self.model.eval()
        self.entity_map = {
            0: "O",
            1: "B-MISC",
            2: "I-MISC",
            3: "B-PER",
            4: "I-PER",
            5: "B-ORG",
            6: "I-ORG",
            7: "B-LOC",
            8: "I-LOC",
        }

    def _reconstruct_sentence(self, word_ids, tokens, predictions):
        """
        Rebuilds one sentence from its word pieces, replacing person entities with <PER>.

        Parameters:
        word_ids (sequence of int): Word id per token position, -1 for tokens belonging to no word.
        tokens (sequence of str): Token strings per position.
        predictions (sequence of int): Predicted label index per position (keys of self.entity_map).

        Returns:
        str: The reconstructed sentence, pieces joined by single spaces.
        """
        reconstructed_tokens = []
        # State is local to the sentence so it cannot leak into the next one.
        skip_next = False
        masked_word_id = None  # word id of the most recent token tagged as a person
        for word_id, token, label_id in zip(word_ids, tokens, predictions):
            word_id = int(word_id)
            if word_id == -1 or token in self._SPECIAL_TOKENS or skip_next:
                skip_next = False
                continue

            entity = self.entity_map[int(label_id)]

            if entity in self._PERSON_TAGS:
                masked_word_id = word_id
                # Collapse a run of person tokens into a single placeholder.
                if not reconstructed_tokens or reconstructed_tokens[-1] != "<PER>":
                    reconstructed_tokens.append("<PER>")
            elif token.startswith("##"):
                if word_id == masked_word_id:
                    # Remaining piece of a word already replaced by <PER>;
                    # gluing it on would leak part of the name ("<PER>son").
                    continue
                if len(reconstructed_tokens) > 1 and reconstructed_tokens[-2] == '<':
                    # Re-assemble a bracketed tag that was split into
                    # '<', piece, '##piece', '>' and skip the closing '>'.
                    reconstructed_tokens[-1] = '<' + reconstructed_tokens[-1] + token[2:] + '>'
                    reconstructed_tokens.pop(-2)
                    skip_next = True
                elif reconstructed_tokens:
                    reconstructed_tokens[-1] = reconstructed_tokens[-1] + token[2:]
                else:
                    # Nothing to attach to; keep the piece instead of raising.
                    reconstructed_tokens.append(token[2:])
            else:
                reconstructed_tokens.append(token.strip())

        return " ".join(reconstructed_tokens)

    def process_sentences(self, sentences):
        """
        Processes input sentences to identify named entities and reconstructs the sentences
        by tagging person entities or merging word pieces based on the model's predictions.
        Sentences are run through the model in batches of 32.

        Parameters:
        sentences (list of str): The sentences to be processed for named entity recognition.

        Returns:
        list of str: The list of processed sentences with entities tagged or tokens modified.
        An empty input yields an empty list.
        """
        if not sentences:
            return []

        dataset = SentenceDataset(sentences, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
        paragraph = []
        for data in dataloader:
            input_ids = data['input_ids'].to(self.device)
            attention_mask = data['attention_mask'].to(self.device)
            with torch.no_grad():
                logits = self.model(input_ids, attention_mask=attention_mask).logits

            # The collated word ids arrive as one tensor per token position;
            # stack and transpose to get one row of word ids per sentence.
            word_id_rows = torch.stack(data['word_ids']).t().tolist()
            token_rows = [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids.cpu().tolist()]
            prediction_rows = torch.argmax(logits, dim=2).cpu().tolist()
            for word_ids, tokens, predictions in zip(word_id_rows, token_rows, prediction_rows):
                paragraph.append(self._reconstruct_sentence(word_ids, tokens, predictions))
        return paragraph