"""Hugging Face pipeline for multi-label (possibly overlapping) named-entity recognition.

The model emits one independent score per ``B-``/``I-``/``E-``/``S-`` prefixed
label; every label is thresholded separately after a sigmoid, so entities of
different (or the same) types may overlap or nest.
"""
from transformers import Pipeline
from transformers import AutoModelForTokenClassification
from transformers.pipelines import PIPELINE_REGISTRY
import torch
import torch.nn as nn

MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION = [
    'BertForMultiLabelTokenClassification',
]

# Tag prefixes the span decoder looks up in ``config.label2id``.
_TAG_PREFIXES = ('B-', 'I-', 'E-', 'S-')


class MultilabelNerPipeline(Pipeline):
    """Token-classification pipeline that decodes possibly overlapping entity spans.

    Call parameters (routed by ``_sanitize_parameters``):
        stride: forwarded to the tokenizer's ``stride`` argument, i.e. the
            token overlap between overflowing windows of a long input
            (default 128).
        threshold: minimum sigmoid score for a tag to count as present
            (default 0.5).
        use_hierarchy_heuristic: if True, prune nested same-label spans with
            ``apply_hierarchy_heuristic`` (default False).

    Returns a list of dicts with keys ``'label'``, ``'span'`` (a character
    ``(start, end)`` pair, end exclusive) and ``'text'`` (the matching slice
    of the input).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION)
        # Entity types are label names with their two-character tag prefix
        # stripped. Labels without a recognised prefix (e.g. an 'O' label) are
        # ignored: they would otherwise produce an empty entity type whose
        # 'S-'/'B-' lookups raise KeyError while decoding.
        self.entity_types = {
            label[2:]
            for label in self.model.config.label2id
            if label.startswith(_TAG_PREFIXES)
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs."""
        preprocess_kwargs = {}
        if 'stride' in kwargs:
            preprocess_kwargs['stride'] = kwargs['stride']

        postprocess_kwargs = {}
        for key in ('threshold', 'use_hierarchy_heuristic'):
            if key in kwargs:
                postprocess_kwargs[key] = kwargs[key]

        # _forward takes no call-time parameters.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, stride=128):
        """Tokenize ``inputs`` into overlapping windows and keep their char offsets.

        Raises:
            ValueError: if the tokenizer is not a fast tokenizer, because
                character offsets are only available from fast tokenizers.
        """
        tokenized_inputs = self.tokenizer(
            inputs,
            truncation=True,
            padding=True,
            stride=stride,
            return_tensors='pt',
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )
        if not tokenized_inputs.is_fast:
            raise ValueError(
                'MultilabelNerPipeline requires a fast tokenizer '
                '(character offsets are needed to map tokens back to text).'
            )
        # One list of (char_start, char_end) pairs per window, indexed by token.
        char_offsets = [encoding.offsets for encoding in tokenized_inputs.encodings]
        return {
            'input_ids': tokenized_inputs.input_ids,
            'attention_mask': tokenized_inputs.attention_mask,
            'char_offsets': char_offsets,
            'special_tokens_mask': tokenized_inputs.special_tokens_mask,
            'text': inputs,
        }

    def _forward(self, model_inputs):
        """Run the model and pass the bookkeeping entries through to postprocess."""
        # Only the tensors the model consumes are handed to it: 'text',
        # 'char_offsets' and 'special_tokens_mask' are pipeline bookkeeping that
        # a standard ``forward`` signature does not accept.
        logits = self.model(
            input_ids=model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
        ).logits
        return {
            'logits': logits,
            'text': model_inputs['text'],
            'char_offsets': model_inputs['char_offsets'],
            'special_tokens_mask': model_inputs['special_tokens_mask'],
        }

    def postprocess(self, model_outputs, threshold=0.5, use_hierarchy_heuristic=False):
        """Decode thresholded per-label scores into deduplicated character spans."""
        # Independent per-label probabilities, indexed as
        # [window, token, label_id] by the extractors below.
        predictions = torch.sigmoid(model_outputs['logits'])
        # Positions flagged by the tokenizer's special_tokens_mask never take
        # part in a span.
        predictions[model_outputs['special_tokens_mask'] == 1] = 0

        token_spans = (
            self.extract_single_token_spans(predictions, threshold)
            + self.extract_multi_token_spans(predictions, threshold)
        )
        spans = self.token_spans_to_char_spans(
            token_spans, model_outputs['char_offsets'], model_outputs['text']
        )
        spans = self.deduplicate_spans(spans)
        if use_hierarchy_heuristic:
            spans = self.apply_hierarchy_heuristic(spans)
        return spans

    def extract_single_token_spans(self, predictions, threshold):
        """Return a one-token span for every ``S-<type>`` score >= ``threshold``.

        Each span is a dict with ``'label'``, ``'batch'`` (window index, int)
        and ``'span_token'`` (``(token, token + 1)``, end exclusive).
        Entity types without an ``S-`` label are skipped.
        """
        label2id = self.model.config.label2id
        spans = []
        # Sorted so that the output order does not depend on set/hash order.
        for entity_type in sorted(self.entity_types):
            single_id = label2id.get(f'S-{entity_type}')
            if single_id is None:
                continue
            hits = torch.nonzero(predictions[:, :, single_id] >= threshold).tolist()
            for idx_batch, idx_token in hits:
                spans.append({
                    'label': entity_type,
                    'batch': idx_batch,
                    'span_token': (idx_token, idx_token + 1),
                })
        return spans

    def extract_multi_token_spans(self, predictions, threshold):
        """Return multi-token spans built from B-/I-/E- tags of the same type.

        Every ``B-<type>`` hit is paired with each later ``E-<type>`` hit in the
        same window; the pair becomes a span when every token strictly between
        them scores >= ``threshold`` for ``I-<type>`` (trivially true when the
        two tokens are adjacent). ``'span_token'`` is end exclusive. Entity
        types lacking any of the three labels are skipped.
        """
        label2id = self.model.config.label2id
        spans = []
        for entity_type in sorted(self.entity_types):
            begin_id = label2id.get(f'B-{entity_type}')
            inside_id = label2id.get(f'I-{entity_type}')
            end_id = label2id.get(f'E-{entity_type}')
            if begin_id is None or inside_id is None or end_id is None:
                continue

            inside = predictions[:, :, inside_id] >= threshold
            # Bucket end tokens by window so each begin is only paired with
            # ends from its own window, instead of every end in the batch.
            ends_by_batch = {}
            end_hits = torch.nonzero(predictions[:, :, end_id] >= threshold).tolist()
            for idx_batch, idx_token in end_hits:
                ends_by_batch.setdefault(idx_batch, []).append(idx_token)

            begin_hits = torch.nonzero(predictions[:, :, begin_id] >= threshold).tolist()
            for idx_batch, idx_begin in begin_hits:
                for idx_end in ends_by_batch.get(idx_batch, ()):
                    if idx_begin >= idx_end:
                        continue
                    # .all() over an empty slice is True, so adjacent B/E pairs qualify.
                    if bool(inside[idx_batch, idx_begin + 1:idx_end].all()):
                        spans.append({
                            'label': entity_type,
                            'batch': idx_batch,
                            'span_token': (idx_begin, idx_end + 1),
                        })
        return spans

    def token_spans_to_char_spans(self, spans, char_offsets, text):
        """Map token spans to ``{'label', 'span', 'text'}`` dicts over ``text``.

        ``char_offsets[batch][token]`` is the ``(start, end)`` character range
        of a token in that window; the span runs from its first token's start
        to its last token's end.
        """
        char_spans = []
        for span in spans:
            offsets = char_offsets[span['batch']]
            token_begin, token_end = span['span_token']
            char_start = offsets[token_begin][0]
            char_end = offsets[token_end - 1][1]
            char_spans.append({
                'label': span['label'],
                'span': (char_start, char_end),
                'text': text[char_start:char_end],
            })
        return char_spans

    def deduplicate_spans(self, spans):
        """Drop exact duplicate spans, keeping the first occurrence's order.

        Duplicates arise e.g. when the same entity is found in two overlapping
        windows.
        """
        unique = dict.fromkeys(tuple(span.items()) for span in spans)
        return [dict(items) for items in unique]

    @staticmethod
    def _group_nested_spans(spans):
        """Group spans by label under the first (longest) span that contains them."""
        groups = []
        # Longest first, so each group is founded by its outermost span.
        for span in sorted(spans, key=lambda span: span['span'][0] - span['span'][1]):
            start, end = span['span']
            for group in groups:
                if (group['label'] == span['label']
                        and group['start'] <= start
                        and group['end'] >= end):
                    group['spans'].append(span)
                    break
            else:
                # No enclosing group with this label: start a new one.
                groups.append({
                    'start': start,
                    'end': end,
                    'spans': [span],
                    'label': span['label'],
                })
        return groups

    def apply_hierarchy_heuristic(self, spans):
        """Prune nested same-label spans down to a minimal covering subset.

        Spans are grouped under their outermost enclosing span of the same
        label. Each group keeps that outermost span, then walks the remaining
        spans shortest first and keeps a span only if its start or end offset
        has not yet been claimed by an already kept span.
        """
        return_spans = []
        for group in self._group_nested_spans(spans):
            sorted_spans = sorted(group['spans'], key=lambda span: span['span'][1] - span['span'][0])
            outer = sorted_spans[-1]

            # Start/end offsets still to be claimed, excluding the group bounds.
            span_starts = {span['span'][0] for span in sorted_spans}
            span_ends = {span['span'][1] for span in sorted_spans}
            span_starts.discard(outer['span'][0])
            span_ends.discard(outer['span'][1])

            # Always keep the encapsulating span.
            cur_spans = [outer]
            for cur_span in sorted_spans[:-1]:
                if not span_starts and not span_ends:
                    break
                start, end = cur_span['span']
                if start in span_starts or end in span_ends:
                    cur_spans.append(cur_span)
                    span_starts.discard(start)
                    span_ends.discard(end)
            return_spans.extend(cur_spans)
        return return_spans

    # Backward-compatible alias for the original, misspelled method name.
    apply_hierarchy_heristic = apply_hierarchy_heuristic


PIPELINE_REGISTRY.register_pipeline(
    'multilabel-ner',
    pipeline_class=MultilabelNerPipeline,
    pt_model=AutoModelForTokenClassification,
    default={'pt': ('jvaquet/multilabel-classification-bert', 'main')},
    type='text',
)