Token Classification
Transformers
Safetensors
MultiLabelBert
multilabel
multilabel-token-classification
custom_code
Instructions to use jvaquet/multilabel-classification-bert-ace2004 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use jvaquet/multilabel-classification-bert-ace2004 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="jvaquet/multilabel-classification-bert-ace2004", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("jvaquet/multilabel-classification-bert-ace2004", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import Pipeline | |
| import torch | |
| import torch.nn as nn | |
# Model class names handed to Pipeline.check_model_type() by MultilabelNerPipeline.
# Kept as a plain list of class-name strings.
MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION = ['BertForMultiLabelTokenClassification']
class MultilabelNerPipeline(Pipeline):
    """Token-classification pipeline that decodes multi-label tag scores into entity spans.

    Every token receives an independent sigmoid score for each label in
    ``model.config.label2id``. Labels are expected to look like
    ``<tag>-<entity type>`` where the tags used here are ``S`` (a span made of a
    single token), ``B`` (first token of a span), ``I`` (token inside a span)
    and ``E`` (last token of a span). Because the scores are independent per
    label, spans of the same or different types may overlap or nest.

    Call-time parameters:
        stride: token overlap between consecutive windows of a long input.
        threshold: minimum sigmoid score for a tag to count as predicted.
        use_hierarchy_heuristic: prune redundant nested spans of the same label.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base pipeline and collect the entity types known to the model."""
        super().__init__(*args, **kwargs)
        self.check_model_type(MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION)
        # Drop the two-character tag prefix ("B-", "I-", ...). Labels that do not
        # carry such a prefix (for instance a bare "O") name no entity type and
        # would otherwise produce an empty type that breaks the tag lookups below.
        self.entity_types = {
            label[2:]
            for label in self.model.config.label2id
            if label[1:2] == '-'
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only keys that were actually supplied are forwarded, so the defaults in
        the method signatures apply otherwise. The forward step takes no options.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in ('stride',) if key in kwargs
        }
        postprocess_kwargs = {
            key: kwargs[key]
            for key in ('threshold', 'use_hierarchy_heuristic')
            if key in kwargs
        }
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, stride=128):
        """Tokenize ``inputs`` into one or more overlapping, padded windows.

        Args:
            inputs: the text to analyse. Downstream code slices it with character
                offsets, so it is presumably a single string - TODO confirm for
                batched calls.
            stride: number of overlapping tokens between consecutive windows.

        Returns:
            dict with the ``input_ids`` / ``attention_mask`` /
            ``special_tokens_mask`` tensors, the per-window character offsets
            (``char_offsets``) and the untouched input (``text``).
        """
        encoding = self.tokenizer(
            inputs,
            truncation=True,
            padding=True,
            stride=stride,
            return_tensors='pt',
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )

        # One row per window; integer indexing yields the per-window encoding
        # object whose ``offsets`` hold (char_start, char_end) per token.
        # NOTE(review): integer indexing presumably requires a fast tokenizer.
        n_windows = encoding.input_ids.size(0)
        char_offsets = [encoding[idx].offsets for idx in range(n_windows)]

        return {
            'input_ids': encoding.input_ids,
            'attention_mask': encoding.attention_mask,
            'char_offsets': char_offsets,
            'special_tokens_mask': encoding.special_tokens_mask,
            'text': inputs,
        }

    def _forward(self, model_inputs):
        """Run the model and carry the bookkeeping entries through to postprocess."""
        # Forward only the tensors the model consumes. The raw text, the offsets
        # and the special-token mask are not model inputs and must not be
        # splatted into the model call.
        outputs = self.model(
            input_ids=model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
        )
        return {
            'logits': outputs.logits,
            'text': model_inputs['text'],
            'char_offsets': model_inputs['char_offsets'],
            'special_tokens_mask': model_inputs['special_tokens_mask'],
        }

    def postprocess(self, model_outputs, threshold=0.5, use_hierarchy_heuristic=False):
        """Turn logits into a list of ``{'label', 'span', 'text'}`` dictionaries.

        Args:
            model_outputs: the dict produced by ``_forward``.
            threshold: minimum sigmoid score for a tag to count as predicted.
            use_hierarchy_heuristic: when true, prune nested same-label spans
                with ``apply_hierarchy_heristic``.

        Returns:
            De-duplicated spans; ``span`` is a ``(char_start, char_end)`` pair
            into the input text and ``text`` is the corresponding substring.
        """
        predictions = torch.sigmoid(model_outputs['logits'])
        # Tokens flagged in the special-tokens mask can never be part of an entity.
        predictions[model_outputs['special_tokens_mask'] == 1] = 0

        token_spans = (
            self.extract_single_token_spans(predictions, threshold)
            + self.extract_multi_token_spans(predictions, threshold)
        )
        spans = self.token_spans_to_char_spans(
            token_spans, model_outputs['char_offsets'], model_outputs['text']
        )
        # Overlapping windows can report the same character span more than once.
        spans = self.deduplicate_spans(spans)

        if use_hierarchy_heuristic:
            spans = self.apply_hierarchy_heristic(spans)
        return spans

    def extract_single_token_spans(self, predictions, threshold):
        """Return one span per token whose ``S-<type>`` score reaches ``threshold``.

        Args:
            predictions: score tensor indexed as ``[window, token, label_id]``.
            threshold: minimum score for the ``S`` tag.

        Returns:
            list of ``{'label', 'batch', 'span_token'}`` dicts where ``batch`` is
            the window index and ``span_token`` is a half-open token range.
        """
        label2id = self.model.config.label2id
        spans = []
        # Sorted so the output order does not depend on set iteration order.
        for entity_type in sorted(self.entity_types):
            hits = predictions[:, :, label2id[f'S-{entity_type}']] >= threshold
            rows, tokens = torch.where(hits)
            for row, token in zip(rows.tolist(), tokens.tolist()):
                spans.append({
                    'label': entity_type,
                    'batch': row,
                    'span_token': (token, token + 1),
                })
        return spans

    def extract_multi_token_spans(self, predictions, threshold):
        """Return spans that run from a ``B-<type>`` token to a later ``E-<type>`` token.

        A (begin, end) pair inside the same window is accepted when every token
        strictly between them reaches ``threshold`` for ``I-<type>``; adjacent
        begin/end tokens have nothing in between and are always accepted.

        Args:
            predictions: score tensor indexed as ``[window, token, label_id]``.
            threshold: minimum score for the ``B``, ``I`` and ``E`` tags.

        Returns:
            list of ``{'label', 'batch', 'span_token'}`` dicts where ``batch`` is
            the window index and ``span_token`` is a half-open token range.
        """
        label2id = self.model.config.label2id
        active = predictions >= threshold
        spans = []
        for entity_type in sorted(self.entity_types):
            begin_hits = active[:, :, label2id[f'B-{entity_type}']]
            inside_hits = active[:, :, label2id[f'I-{entity_type}']]
            end_hits = active[:, :, label2id[f'E-{entity_type}']]

            # Pair begins and ends window by window instead of across the whole
            # batch: pairs from different windows can never form a span.
            for row in range(active.size(0)):
                begin_tokens = torch.where(begin_hits[row])[0].tolist()
                end_tokens = torch.where(end_hits[row])[0].tolist()
                for begin in begin_tokens:
                    for end in end_tokens:
                        if begin >= end:
                            continue
                        # ``all()`` of an empty slice is true (adjacent B/E).
                        if bool(inside_hits[row, begin + 1:end].all()):
                            spans.append({
                                'label': entity_type,
                                'batch': row,
                                'span_token': (begin, end + 1),
                            })
        return spans

    def token_spans_to_char_spans(self, spans, char_offsets, text):
        """Convert token-index spans into character spans of ``text``.

        Args:
            spans: dicts with ``label``, ``batch`` (window index) and
                ``span_token`` (half-open token range).
            char_offsets: per window, the ``(char_start, char_end)`` pair of
                every token.
            text: the text the offsets refer to.

        Returns:
            list of ``{'label', 'span', 'text'}`` dicts, in input order.
        """
        char_spans = []
        for span in spans:
            window_offsets = char_offsets[span['batch']]
            first_token, stop_token = span['span_token']
            char_start = window_offsets[first_token][0]
            # The token range is half-open, so the last token is ``stop_token - 1``.
            char_end = window_offsets[stop_token - 1][1]
            char_spans.append({
                'label': span['label'],
                'span': (char_start, char_end),
                'text': text[char_start:char_end],
            })
        return char_spans

    def deduplicate_spans(self, spans):
        """Drop repeated spans, keeping the first occurrence and the input order.

        All values of a span dict must be hashable.
        """
        unique_items = dict.fromkeys(tuple(span.items()) for span in spans)
        return [dict(items) for items in unique_items]

    def apply_hierarchy_heristic(self, spans):
        """Prune nested spans of the same label that add no new boundary.

        Spans are grouped under the longest same-label span that contains them.
        Within a group the enclosing span is always kept; the remaining spans
        are visited from shortest to longest and kept only if they start or end
        at a position not yet covered by a kept span (the enclosing span's own
        boundaries count as covered from the start).

        The misspelled method name is kept for backward compatibility.

        Args:
            spans: dicts with ``label`` and ``span`` (``(char_start, char_end)``).

        Returns:
            The retained spans, group by group, enclosing span first.
        """
        def _length(span):
            start, end = span['span']
            return end - start

        def _group_spans(spans):
            groups = []
            # Longest first, so each group is opened by its enclosing span.
            for span in sorted(spans, key=lambda candidate: -_length(candidate)):
                start, end = span['span']
                for group in groups:
                    if (group['label'] == span['label']
                            and group['start'] <= start
                            and end <= group['end']):
                        group['spans'].append(span)
                        break
                else:
                    # No enclosing group of this label exists yet: open a new one.
                    groups.append({
                        'start': start,
                        'end': end,
                        'spans': [span],
                        'label': span['label'],
                    })
            return groups

        kept_spans = []
        for group in _group_spans(spans):
            # Shortest to longest; the last element is the enclosing span.
            *inner_spans, outer_span = sorted(group['spans'], key=_length)

            # Boundaries still waiting to be covered, excluding those of the
            # enclosing span itself.
            open_starts = {span['span'][0] for span in inner_spans}
            open_ends = {span['span'][1] for span in inner_spans}
            open_starts.discard(outer_span['span'][0])
            open_ends.discard(outer_span['span'][1])

            kept_spans.append(outer_span)
            for span in inner_spans:
                if not open_starts and not open_ends:
                    break
                start, end = span['span']
                if start in open_starts or end in open_ends:
                    kept_spans.append(span)
                    open_starts.discard(start)
                    open_ends.discard(end)

        return kept_spans
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers import AutoModelForTokenClassification | |
# Make the pipeline available under the 'multilabel-ner' task name, with a
# default PyTorch checkpoint for calls that do not name a model.
PIPELINE_REGISTRY.register_pipeline(
    'multilabel-ner',
    pipeline_class=MultilabelNerPipeline,
    pt_model=AutoModelForTokenClassification,
    type='text',
    default={'pt': ('jvaquet/multilabel-classification-bert', 'main')},
)