Spaces:
Runtime error
Runtime error
| """ | |
| Source: https://github.com/ZurichNLP/recognizing-semantic-differences | |
| MIT License | |
| Copyright (c) 2023 University of Zurich | |
| """ | |
| from typing import List | |
| import torch | |
| from recognizers.feature_based import FeatureExtractionRecognizer | |
| from recognizers.utils import DifferenceSample, cos_sim | |
class DiffAlign(FeatureExtractionRecognizer):
    """Recognizer that labels words by how poorly they align to the other sentence.

    For every sentence pair, a cosine-similarity matrix is computed between the
    encoded features of sentence A and sentence B. Each row/column is reduced to
    its best match in the other sentence, and the difference label is
    ``1 - max_similarity``. The labels are then mapped onto whitespace-separated
    words via ``_subword_labels_to_word_labels``.
    """

    def __str__(self):
        # The closing parenthesis was missing, which produced an unbalanced representation.
        return f"DiffAlign(model={self.pipeline.model.name_or_path}, layer={self.layer})"

    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        """Predict difference labels for a batch of sentence pairs.

        :param a: First sentences; ``a[i]`` is paired with ``b[i]``.
        :param b: Second sentences; must have the same length as ``a``.
        :param kwargs: Forwarded unchanged to ``self.encode_batch``.
        :return: One ``DifferenceSample`` per pair, holding the whitespace-split
            tokens and the per-word labels of both sentences.
        :raises ValueError: If ``a`` and ``b`` differ in length.
        """
        if len(a) != len(b):
            # Previously a longer `b` was silently truncated and a shorter one
            # surfaced as a bare IndexError deep inside the loop.
            raise ValueError(
                f"a and b must contain the same number of sentences, got {len(a)} and {len(b)}"
            )

        outputs_a = self.encode_batch(a, **kwargs)
        outputs_b = self.encode_batch(b, **kwargs)

        samples = []
        for i, (sentence_a, sentence_b) in enumerate(zip(a, b)):
            # NOTE(review): presumably outputs_*[i] holds one feature vector per
            # subword, so the matrix is (subwords in A) x (subwords in B) - confirm
            # against encode_batch / cos_sim.
            cosine_similarities = cos_sim(outputs_a[i], outputs_b[i])
            # Best match for each row (sentence A side) and each column (sentence B side).
            max_similarities_a = torch.max(cosine_similarities, dim=1).values
            max_similarities_b = torch.max(cosine_similarities, dim=0).values

            # A high best-match similarity yields a low difference label.
            labels_a = self._subword_labels_to_word_labels(
                1 - max_similarities_a,
                self._get_subwords_by_word(sentence_a),
            )
            labels_b = self._subword_labels_to_word_labels(
                1 - max_similarities_b,
                self._get_subwords_by_word(sentence_b),
            )
            samples.append(DifferenceSample(
                tokens_a=tuple(sentence_a.split()),
                tokens_b=tuple(sentence_b.split()),
                labels_a=tuple(labels_a),
                labels_b=tuple(labels_b),
            ))
        return samples