| | import numpy as np |
| | import torch |
| | import ot |
| | from otfuncs import ( |
| | compute_distance_matrix_cosine, |
| | compute_distance_matrix_l2, |
| | compute_weights_norm, |
| | compute_weights_uniform, |
| | min_max_scaling |
| | ) |
| |
|
class Aligner:
    """Align two sets of word embeddings with an optimal-transport solver.

    The solver family is chosen by ``ot_type``:

    * ``'ot'``     -- balanced OT (``ot.emd``, or ``ot.bregman.sinkhorn_log``
      when ``sinkhorn`` is truthy) on weights normalised to sum to one.
    * ``'pot'``    -- partial OT; the transported mass is
      ``min(sum(w1), sum(w2)) * tau``.
    * ``'uot'``    -- unbalanced OT via stabilised Sinkhorn, ``reg_m=tau``.
    * ``'uot-mm'`` -- unbalanced OT via the majorisation-minimisation solver,
      ``reg_m=tau`` and divergence ``div_type``.
    * ``'none'``   -- no transport; the plan is simply ``1 - C``.

    Args:
        ot_type: One of the solver identifiers listed above.
        sinkhorn: If truthy, use the entropic solver for ``'ot'`` / ``'pot'``.
        dist_type: ``'cos'`` selects the cosine distance helper; any other
            value selects the L2 distance helper.
        weight_type: ``'uniform'`` selects uniform weights; any other value
            selects norm-based weights.
        distortion: Forwarded unchanged as third argument to the distance
            helper.
        thresh: Stored on the instance; not used inside this class.
        tau: Mass fraction (``'pot'``) or marginal relaxation (``'uot*'``).
        **kwargs: ``div_type`` -- divergence name for ``'uot-mm'``; defaults
            to ``'kl'`` when omitted.
    """

    def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.dist_type = dist_type
        self.weight_type = weight_type
        # NOTE: the attribute name is misspelled ("distotion"); it is kept
        # as-is because external code may already read or set it.
        self.distotion = distortion
        self.thresh = thresh
        self.tau = tau
        # Fixed solver hyper-parameters shared by all iterative solvers.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        # Only the 'uot-mm' solver reads div_type, so it must not be a
        # mandatory keyword for every other configuration.
        self.div_type = kwargs.get('div_type', 'kl')

        if dist_type == 'cos':
            self.dist_func = compute_distance_matrix_cosine
        else:
            self.dist_func = compute_distance_matrix_l2

        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
        """Compute the alignment plan and return it as a NumPy array.

        Args:
            s1_word_embeddigs: Embeddings of the first sequence (torch tensor).
            s2_word_embeddigs: Embeddings of the second sequence (torch tensor).

        Returns:
            Tuple ``(P, Cost, loss, similarity_matrix)``. ``P`` is moved to
            CPU and converted to NumPy when the solver returned a tensor.
            ``loss`` is the solver log's ``'cost'`` entry, or the string
            ``'NotImplemented'`` when the log has no such entry.
        """
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(
            s1_word_embeddigs, s2_word_embeddigs
        )
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')

        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Build the cost matrix and weights, then run the configured solver.

        Args:
            s1_word_embeddigs: Embeddings of the first sequence (torch tensor).
            s2_word_embeddigs: Embeddings of the second sequence (torch tensor).

        Returns:
            Tuple ``(P, C, log, similarity_matrix)`` where ``P`` is the
            (min-max scaled) transport plan, ``C`` the cost matrix, ``log``
            the solver's log dict (empty for ``ot_type == 'none'``) and
            ``similarity_matrix`` the second value returned by the distance
            helper.

        Raises:
            ValueError: If ``self.ot_type`` is not a supported identifier.
        """
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'none':
            # No solver is run, so there is no solver log; hand back an empty
            # dict so callers can still use ``log.get(...)``.
            return 1 - C, C, {}, similarity_matrix

        if self.ot_type == 'ot':
            # Balanced OT needs both marginals to carry the same total mass.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

            if self.sinkhorn:
                P, log = ot.bregman.sinkhorn_log(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, stopThr=self.stopThr,
                    numItermax=self.numItermax, log=True,
                )
            else:
                P, log = ot.emd(s1_weights, s2_weights, C, log=True)

        elif self.ot_type == 'pot':
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            # Transport only a tau-fraction of the smaller total mass.
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau

            if self.sinkhorn:
                P, log = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True,
                )
            else:
                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)

        elif self.ot_type == 'uot':
            P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                s1_weights, s2_weights, C, reg=self.epsilon, reg_m=self.tau,
                stopThr=self.stopThr, numItermax=self.numItermax, log=True,
            )

        elif self.ot_type == 'uot-mm':
            P, log = ot.unbalanced.mm_unbalanced(
                s1_weights, s2_weights, C, reg_m=self.tau, div=self.div_type,
                stopThr=self.stopThr, numItermax=self.numItermax, log=True,
            )

        else:
            raise ValueError(
                f"Unsupported ot_type {self.ot_type!r}; expected one of "
                "'ot', 'pot', 'uot', 'uot-mm', 'none'."
            )

        P = min_max_scaling(P)

        return P, C, log, similarity_matrix

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move torch inputs to CPU and convert them to NumPy arrays.

        ``s2_weights`` is converted whenever ``s1_weights`` is a tensor; ``C``
        is checked independently. Non-tensor inputs are returned unchanged.
        """
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()

        return s1_weights, s2_weights, C
| |
|