"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""
| import itertools | |
| from typing import List, Union | |
| import torch | |
| import transformers | |
| from transformers import FeatureExtractionPipeline, Pipeline | |
| from recognizers.base import DifferenceRecognizer | |
| from recognizers.utils import DifferenceSample | |
| Ngram = List[int] # A span of subword indices | |
class FeatureExtractionRecognizer(DifferenceRecognizer):
    """
    Base recognizer that extracts contextual token embeddings from a
    Hugging Face feature-extraction pipeline.

    Subclasses implement :meth:`_predict_all`. NOTE(review): `_get_ngrams`
    reads `self.min_n` / `self.max_n`, which are not set here — subclasses
    are presumably expected to define them; confirm against the subclasses.
    """

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        """
        :param model_name_or_path: Model identifier or path used to build a
            feature-extraction pipeline when no `pipeline` is given.
        :param pipeline: An existing pipeline to reuse instead of loading one.
        :param layer: Index into `hidden_states` returned by the model
            (-1 = last layer).
        :param batch_size: Number of sentence pairs processed per forward pass
            in `predict_all`.
        """
        # Explicit error instead of `assert`, which is stripped under `-O`.
        if model_name_or_path is None and pipeline is None:
            raise ValueError("Either model_name_or_path or pipeline must be provided")
        if pipeline is None:
            pipeline = transformers.pipeline(
                model=model_name_or_path,
                task="feature-extraction",
            )
        self.pipeline = pipeline
        self.layer = layer
        self.batch_size = batch_size

    def encode_batch(self, sentences: List[str], **kwargs) -> torch.Tensor:
        """
        Run the model on a batch of sentences and return the hidden states of
        the configured layer.

        :param sentences: Batch of input sentences.
        :return: Tensor of shape batch x seq_len x dim (hidden states of
            `self.layer`).
        """
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        return outputs.hidden_states[self.layer]

    def predict(self,
                a: str,
                b: str,
                **kwargs,
                ) -> DifferenceSample:
        """Predict the differences for a single sentence pair."""
        return self.predict_all([a], [b], **kwargs)[0]

    def predict_all(self,
                    a: List[str],
                    b: List[str],
                    **kwargs,
                    ) -> List[DifferenceSample]:
        """
        Predict differences for parallel lists of sentence pairs, processing
        them in mini-batches of `self.batch_size`.
        """
        samples = []
        for i in range(0, len(a), self.batch_size):
            samples.extend(self._predict_all(
                a[i:i + self.batch_size],
                b[i:i + self.batch_size],
                **kwargs,
            ))
        return samples

    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        """Batch-level prediction; implemented by subclasses."""
        raise NotImplementedError

    def _pool(self, token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param mask: batch x seq_len; 1 if token should be included in the pooling
        :return: batch x dim
        Do only sum and do not divide by the number of tokens because cosine similarity is length-invariant.
        """
        return torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)

    def _get_subwords_by_word(self, sentence: str) -> List[Ngram]:
        """
        :return: For each word in the sentence, the positions of the subwords that make up the word.
        """
        batch_encoding = self.pipeline.tokenizer(
            sentence,
            padding=True,
            truncation=True,
        )
        # Hoist the repeated `.encodings[0]` and special-token lookups out of the loop.
        encoding = batch_encoding.encodings[0]
        special_tokens = self.pipeline.tokenizer.all_special_tokens
        subword_ids: List[List[int]] = []
        for subword_idx in range(len(encoding.word_ids)):
            if encoding.word_ids[subword_idx] is None:  # Special token
                continue
            char_idx = encoding.offsets[subword_idx][0]
            if isinstance(self.pipeline.tokenizer,
                          (transformers.XLMRobertaTokenizerFast, transformers.XLMRobertaTokenizer)):
                # SentencePiece prefixes word-initial subwords with "▁".
                token = encoding.tokens[subword_idx]
                is_tail = not token.startswith("▁") and token not in special_tokens
            elif isinstance(self.pipeline.tokenizer,
                            (transformers.RobertaTokenizerFast, transformers.RobertaTokenizer)):
                # Byte-level BPE prefixes word-initial subwords with "Ġ".
                token = encoding.tokens[subword_idx]
                is_tail = not token.startswith("Ġ") and token not in special_tokens
            else:
                # Fallback: a subword continues a word if it starts exactly
                # where the previous subword ended (no whitespace between).
                is_tail = char_idx > 0 and char_idx == encoding.offsets[subword_idx - 1][1]
            if is_tail and len(subword_ids) > 0:
                subword_ids[-1].append(subword_idx)
            else:
                subword_ids.append([subword_idx])
        return subword_ids

    def _get_ngrams(self, subwords_by_word: List[Ngram]) -> List[Ngram]:
        """
        :return: For each subword ngram in the sentence, the positions of the subwords that make up the ngram.
        """
        subwords = list(itertools.chain.from_iterable(subwords_by_word))
        # Always return at least one ngram (reduce n if necessary).
        # NOTE(review): relies on subclass-defined self.min_n / self.max_n.
        min_n = min(self.min_n, len(subwords))
        ngrams = []
        for n in range(min_n, self.max_n + 1):
            for i in range(len(subwords) - n + 1):
                ngrams.append(subwords[i:i + n])
        return ngrams

    def _subword_labels_to_word_labels(self, subword_labels: torch.Tensor, subwords_by_words: List[Ngram]) -> List[float]:
        """
        :param subword_labels: num_subwords
        :param subwords_by_words: num_words x num_subwords
        :return: num_words
        """
        # Each word's label is the mean of the labels of its subwords.
        return [subword_labels[indices].mean().item() for indices in subwords_by_words]