from collections import Counter
from copy import copy
from io import BytesIO
import pickle
import warnings
from zipfile import ZipFile

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from .match_groups import MatchGroups


class Embeddings(torch.nn.Module):
    """
    Stores embeddings for a fixed array of strings and provides methods for
    clustering the strings to create MatchGroups objects according to different
    algorithms.

    Attributes set by the constructor:
    - strings: numpy array holding the embedded strings
    - string_map: dict mapping each string to its row position
    - V: tensor with one embedding row per string. Products V @ V.T are
      treated as cosine similarities throughout (presumably the rows are
      unit-normalised by whoever builds V -- verify against the caller)
    - counts: tensor of per-string counts
    - w: per-string weights, computed as weighting_function(counts)
    - score_model: object that is called on cosine similarities to produce
      scores, and that also provides score_to_cos(), loss() and to()
    - weighting_function: callable mapping a counts tensor to weights
    - device: device the tensors currently live on
    """

    def __init__(self, strings, V, score_model, weighting_function, counts, device='cpu'):
        super().__init__()

        # Materialise once so that one-shot iterables (e.g. generators) are
        # not already exhausted when the lookup table is built.
        strings = list(strings)

        self.strings = np.array(strings)
        self.string_map = {s: i for i, s in enumerate(strings)}
        self.V = V
        self.counts = counts
        self.w = weighting_function(counts)
        self.score_model = score_model
        self.weighting_function = weighting_function
        self.device = device

        self.to(device)

    def __repr__(self):
        return f'<nama.Embeddings containing {self.V.shape[1]}-d vectors for {len(self)} strings>'

    def to(self, device):
        """
        Move the embedding tensors and the score model to "device".

        returns: self (same convention as torch.nn.Module.to, so calls can
        be chained or assigned)
        """
        super().to(device)
        self.V = self.V.to(device)
        self.counts = self.counts.to(device)
        if self.w is not None:
            self.w = self.w.to(device)
        self.score_model.to(device)
        self.device = device

        return self

    def save(self, f):
        """
        Save embeddings in a simple custom zipped archive format (torch.save
        works too, but it requires huge amounts of memory to serialize large
        embeddings objects).

        The archive contains four members: 'score_model.pkl',
        'weighting_function.pkl', 'strings.csv' (columns "string" and
        "count") and 'V.npy'.
        """
        with ZipFile(f, 'w') as archive:

            # Write score model
            archive.writestr('score_model.pkl', pickle.dumps(self.score_model))

            # Write weighting function
            archive.writestr('weighting_function.pkl', pickle.dumps(self.weighting_function))

            # Write string info
            strings_df = pd.DataFrame().assign(
                string=self.strings,
                count=self.counts.to('cpu').numpy())
            archive.writestr('strings.csv', strings_df.to_csv(index=False))

            # Write embedding vectors
            byte_io = BytesIO()
            np.save(byte_io, self.V.to('cpu').numpy(), allow_pickle=False)
            archive.writestr('V.npy', byte_io.getvalue())

    def __getitem__(self, arg):
        """
        Select a subset of the embeddings.

        "arg" may be a slice, a MatchGroups object (its strings are selected),
        or an iterable of strings. Returns a shallow copy of this object with
        strings, V, counts, w and string_map restricted to the selection.

        raises: ValueError if "arg" is none of the supported types, KeyError
        if a requested string is not in the embeddings.
        """
        if isinstance(arg, slice):
            i = arg
        elif isinstance(arg, MatchGroups):
            return self[arg.strings()]
        elif hasattr(arg, '__iter__'):
            # Look up the row position of every requested string
            string_map = self.string_map
            i = [string_map[s] for s in arg]

            if i == list(range(len(self))):
                # Selection is the identity: no need to re-index anything
                return copy(self)
        else:
            raise ValueError(f'Unknown slice input type ({type(arg)}). Can only slice Embedding with a slice, match group, or iterable.')

        new = copy(self)
        new.strings = self.strings[i]
        new.V = self.V[i]
        new.counts = self.counts[i]
        new.w = self.w[i] if self.w is not None else None
        new.string_map = {s: k for k, s in enumerate(new.strings)}

        return new

    def embed(self, grouping):
        """
        Construct updated Embeddings with counts from the input MatchGroups
        """
        new = self[grouping]
        new.counts = torch.tensor([grouping.counts[s] for s in new.strings], device=self.device)
        new.w = new.weighting_function(new.counts)

        return new

    def __len__(self):
        return len(self.strings)

    def _group_to_ids(self, grouping):
        """
        Convert a MatchGroups object into a tensor of integer group ids, one
        per embedding string. Ids are the positions of the groups in
        grouping.groups.
        """
        group_id_map = {g: i for i, g in enumerate(grouping.groups.keys())}
        group_ids = torch.tensor([group_id_map[grouping[s]] for s in self.strings]).to(self.device)
        return group_ids

    def _ids_to_group(self, group_ids):
        """
        Convert a tensor/array of integer group ids (one per embedding string)
        into a MatchGroups object. Each group is labelled by its last string
        after sorting by count, i.e. a string with the highest count.
        """
        if isinstance(group_ids, torch.Tensor):
            group_ids = group_ids.to('cpu').numpy()

        strings = self.strings
        counts = self.counts.to('cpu').numpy()

        # Sort by group id first, then by count within each group
        g_sort = np.lexsort((counts, group_ids))
        group_ids = group_ids[g_sort]
        strings = strings[g_sort]
        counts = counts[g_sort]

        # Locate the positions where the group id changes
        split_locs = np.nonzero(group_ids[1:] != group_ids[:-1])[0] + 1

        # Split the sorted strings into one array per group. np.split returns
        # a single empty array for empty input, which must be skipped because
        # it has no last element to use as a label.
        groups = [g for g in np.split(strings, split_locs) if len(g)]

        # Build the MatchGroups object directly from its components
        grouping = MatchGroups()
        grouping.counts = Counter({s: int(c) for s, c in zip(strings, counts)})
        grouping.labels = {s: g[-1] for g in groups for s in g}
        grouping.groups = {g[-1]: list(g) for g in groups}

        return grouping

    @torch.no_grad()
    def _fast_unite_similar(self, group_ids, threshold=0.5, progress_bar=True, batch_size=64):
        """
        Unite every pair of strings whose score is at least "threshold",
        without any group-level constraints.

        "group_ids" holds the initial group id of each string and is modified
        in place. Returns a MatchGroups object.
        """
        V = self.V
        cos_threshold = self.score_model.score_to_cos(threshold)

        for batch_start in tqdm(range(0, len(self), batch_size),
                                delay=1, desc='Predicting matches', disable=not progress_bar):

            i_slice = slice(batch_start, batch_start + batch_size)
            j_slice = slice(batch_start + 1, None)

            # Basic slices are views, so g_i and g_j reflect the in-place
            # merges performed further down in this batch.
            g_i = group_ids[i_slice]
            g_j = group_ids[j_slice]

            # Pairs above the threshold that are not yet in the same group
            batch_matched = (V[i_slice] @ V[j_slice].T >= cos_threshold) \
                & (g_i[:, None] != g_j[None, :])

            for k, matched in enumerate(batch_matched):
                if matched.any():
                    # Groups that string k should be united with
                    matched_groups = g_j[matched]

                    # All strings currently belonging to those groups
                    ids_to_group = torch.isin(group_ids, matched_groups)

                    # Clone because g_i[k] is a view into the tensor being
                    # written to.
                    group_ids[ids_to_group] = g_i[k].clone()

        return self._ids_to_group(group_ids)

    def _resolve_never_match(self, never_match, always_match_labels, always_never_conflicts):
        """
        Normalise "never_match" into a list of sets, checking it against the
        always_match labels (if any) and collapsing its strings to those
        labels.
        """
        # Ensure never_match is a nested collection
        if all(isinstance(s, str) for s in never_match):
            never_match = [never_match]

        if always_match_labels is None:
            # No always_match groups: just convert to sets
            return [set(s) for s in never_match]

        if always_never_conflicts not in ('raise', 'warn', 'ignore'):
            raise ValueError('always_never_conflicts must be one of "raise", "warn" or "ignore"')

        if always_never_conflicts != 'ignore':

            # A never_match group conflicts with always_match when two of its
            # strings share the same always_match label.
            conflicts = []
            for i, g in enumerate(never_match):
                g = sorted(g)
                g_labels = [always_match_labels.get(s, s) for s in g]
                if len(set(g_labels)) < len(g):
                    df = (pd.DataFrame()
                          .assign(
                              string=g,
                              never_match_group=i,
                              always_match_group=g_labels
                          ))
                    conflicts.append(df)

            if conflicts:
                conflicts_df = pd.concat(conflicts)

                if always_never_conflicts == 'warn':
                    warnings.warn(
                        'The following never_match groups are in conflict with always_match groups:'
                        f'\n{conflicts_df}'
                        '\nConflicted never_match relationships will be ignored')
                else:
                    raise ValueError(f'The following never_match groups are in conflict with always_match groups\n{conflicts_df}')

        # Collapse never_match to always_match labels (strings that are not
        # part of the embeddings are dropped)
        return [{always_match_labels[s] for s in g if s in always_match_labels} for g in never_match]

    def _score_candidate_pairs(self, group_ids, cos_threshold, batch_size, progress_bar):
        """
        Collect every pair (i, j) with i < j, cosine similarity of at least
        "cos_threshold" and different group ids. Returns the pairs as an
        (n, 2) numpy array together with their cosine similarities, both
        sorted by descending similarity. Can be memory intensive.
        """
        V = self.V

        pair_chunks = []
        cos_chunks = []
        for batch_start in tqdm(range(0, len(self), batch_size),
                                desc='Scoring pairs',
                                delay=1, disable=not progress_bar):

            i_slice = slice(batch_start, batch_start + batch_size)
            j_slice = slice(batch_start + 1, None)

            batch_cos = V[i_slice] @ V[j_slice].T

            # Only keep the upper triangle (column >= row means j > i because
            # the columns are offset by one). An explicit mask is used rather
            # than zero-filling, so that a non-positive threshold cannot
            # admit the masked-out entries.
            rows = torch.arange(batch_cos.shape[0], device=batch_cos.device)[:, None]
            cols = torch.arange(batch_cos.shape[1], device=batch_cos.device)[None, :]
            keep = (cols >= rows) & (batch_cos >= cos_threshold)

            bi, bj = torch.nonzero(keep, as_tuple=True)
            if not len(bi):
                continue

            # Convert batch index locations to global index locations
            i = bi + batch_start
            j = bj + batch_start + 1
            cos = batch_cos[bi, bj]

            # Pairs that already share a group need no further work
            unmatched = group_ids[i] != group_ids[j]
            i = i[unmatched]
            j = j[unmatched]
            cos = cos[unmatched]

            if len(i):
                pair_chunks.append(torch.stack([i, j], dim=1).to('cpu').numpy())
                cos_chunks.append(cos.to('cpu').numpy())

        if not pair_chunks:
            return np.empty((0, 2), dtype=np.int64), np.empty(0)

        pairs = np.vstack(pair_chunks)
        pair_cos = np.hstack(cos_chunks)

        # Sort in descending order of similarity
        order = pair_cos.argsort()[::-1]

        return pairs[order], pair_cos[order]

    @torch.no_grad()
    def unite_similar(self,
                      threshold=0.5,
                      group_threshold=None,
                      always_match=None,
                      never_match=None,
                      batch_size=64,
                      progress_bar=True,
                      always_never_conflicts='warn',
                      return_united=False):
        """
        Unite embedding strings according to predicted pairwise similarity.

        - "threshold" sets the minimum match similarity required to unite two strings.
            - Note that strings with similarity<threshold can end up matched if they are
              linked by a chain of sufficiently similar strings (matching is transitive).
              "group_threshold" can be used to add an additional constraint on the minimum
              similarity within each group.
        - "group_threshold" sets the minimum similarity required within a single group.
        - "always_match" takes any argument that can be used to unite strings. These
          strings will always be matched.
        - "never_match" takes a set, or a list of sets, where each set indicates two or
          more strings that should never be united with each other (these strings may
          still be united with other strings). If a string appears in more than one set,
          only the last set containing it is used.
        - "always_never_conflicts" determines how to handle conflicts between
          "always_match" and "never_match":
            - always_never_conflicts="warn": Check for conflicts and issue a warning
              if any are found (default)
            - always_never_conflicts="raise": Check for conflicts and raise an error
              if any are found
            - always_never_conflicts="ignore": Do not check for conflicts ("always_match"
              will take precedence)
        - "return_united" additionally returns a DataFrame with one row per united pair
          (columns "i" and "j" hold the two strings, "n_i" and "n_j" the total counts of
          their groups just before uniting, "score" the pair's score and, when
          "always_match" is given, "always_match" whether the pair shares an
          always_match group).

        If "group_threshold", "never_match" or "return_united" arguments are supplied,
        string pairs are united in order of similarity. Highest similarity strings are
        matched first, and before each time a new pair of strings is united, the function
        checks if this will result in grouping any two strings with
        similarity<group_threshold. If so, this pair is skipped. This version of the
        algorithm requires more memory and processing time, but respects the constraints.

        returns: MatchGroups object, or (MatchGroups, DataFrame) if return_united is set

        raises: ValueError if group_threshold < threshold, if
        "always_never_conflicts" is invalid, or if conflicts are found while
        always_never_conflicts="raise"
        """
        if group_threshold and group_threshold < threshold:
            raise ValueError('group_threshold must be greater than or equal to threshold')

        group_ids = torch.arange(len(self)).to(self.device)

        always_grouping = None
        always_match_labels = None
        if always_match is not None:
            always_grouping = (MatchGroups(self.strings)
                               .unite(always_match))
            always_match_labels = always_grouping.labels

        # Use the fast algorithm when there are no group-level constraints
        if not (return_united or group_threshold or (never_match is not None)):
            if always_match is not None:
                group_ids = self._group_to_ids(always_grouping)

            return self._fast_unite_similar(
                group_ids=group_ids,
                threshold=threshold,
                batch_size=batch_size,
                progress_bar=progress_bar)

        if never_match is not None:
            never_match = self._resolve_never_match(
                never_match, always_match_labels, always_never_conflicts)

        V = self.V
        cos_threshold = self.score_model.score_to_cos(threshold)
        if group_threshold is not None:
            separate_cos = self.score_model.score_to_cos(group_threshold)

        # First collect all candidate pairs, best first
        pairs, pair_cos = self._score_candidate_pairs(
            group_ids, cos_threshold, batch_size, progress_bar)

        # Unite candidate pairs in priority order, while respecting the
        # group_threshold and never_match arguments
        united = []
        if len(pairs):

            if return_united:
                # Keep the cosine similarities for the returned report
                cos_scores_df = pd.DataFrame(pairs, columns=['i', 'j'])
                cos_scores_df['cos'] = pair_cos

            matches = torch.tensor(pairs).to(self.device)

            # Set up the per-string never-match sets
            if never_match is not None:
                never_match_map = {s: sep for sep in never_match for s in sep}

                if always_match is not None:
                    # never_match was collapsed to always_match labels, so the
                    # lookup must go through each string's label
                    never_match_array = np.array(
                        [never_match_map.get(always_match_labels[s], set()) for s in self.strings],
                        dtype=object)
                else:
                    never_match_array = np.array(
                        [never_match_map.get(s, set()) for s in self.strings],
                        dtype=object)

            n_matches = matches.shape[0]
            with tqdm(total=n_matches, desc='Uniting matches',
                      delay=1, disable=not progress_bar) as p_bar:

                while len(matches):

                    # Pop the best remaining pair off the queue
                    match_pair = matches[0]
                    matches = matches[1:]
                    i0, i1 = match_pair.tolist()

                    # Current groups of the two strings
                    g = group_ids[match_pair]
                    g0 = group_ids == g[0]
                    g1 = group_ids == g[1]

                    # Strings that would end up in the united group
                    to_unite = g0 | g1

                    # True when both strings are currently alone in their group
                    singletons = to_unite.sum().item() < 3

                    # Start by assuming the pair can be united
                    unite_ok = True

                    # Check whether uniting would join never_match strings
                    if never_match is not None:
                        never_0 = never_match_array[i0]
                        never_1 = never_match_array[i1]

                        # Both sets must be non-empty for an overlap to exist
                        if never_0 and never_1 and (never_0 & never_1):
                            unite_ok = False

                    # Check whether uniting would violate the group_threshold
                    # (impossible if both strings are singletons, because the
                    # pair itself already passed the higher-or-equal threshold)
                    if unite_ok and group_threshold and not singletons:
                        V0 = V[g0, :]
                        V1 = V[g1, :]

                        unite_ok = bool((V0 @ V1.T).min() >= separate_cos)

                    if unite_ok:

                        # Unite the groups
                        group_ids[to_unite] = g[0]

                        if never_match is not None and (never_0 or never_1):
                            # Propagate never_match information to the whole group
                            never_match_array[to_unite.detach().cpu().numpy()] = never_0 | never_1

                        # Uniting larger groups can make other queued pairs
                        # redundant
                        if not singletons:
                            # Remove queued pairs that are now in the same group
                            matches = matches[group_ids[matches[:, 0]] != group_ids[matches[:, 1]]]

                        if return_united:
                            match_record = np.empty(4, dtype=int)
                            match_record[:2] = (i0, i1)
                            match_record[2] = self.counts[g0].sum().item()
                            match_record[3] = self.counts[g1].sum().item()

                            united.append(match_record)
                    else:
                        # Remove queued pairs connecting these two groups
                        matches = matches[torch.isin(group_ids[matches[:, 0]], g, invert=True)
                                          | torch.isin(group_ids[matches[:, 1]], g, invert=True)]

                    # Update the progress bar with the number of pairs consumed
                    p_bar.update(n_matches - matches.shape[0])
                    n_matches = matches.shape[0]

        predicted_grouping = self._ids_to_group(group_ids)

        if always_match is not None:
            predicted_grouping = predicted_grouping.unite(always_grouping)

        if not return_united:
            return predicted_grouping

        if united:
            united_df = pd.DataFrame(np.vstack(united), columns=['i', 'j', 'n_i', 'n_j'])
            united_df = pd.merge(united_df, cos_scores_df, how='inner', on=['i', 'j'])
            united_df['score'] = self.score_model(
                torch.tensor(united_df['cos'].values).to(self.device)
            ).cpu().numpy()

            united_df = united_df.drop('cos', axis=1)

            # Replace positions by the strings themselves
            for c in ['i', 'j']:
                united_df[c] = [self.strings[i] for i in united_df[c]]

            if always_match is not None:
                united_df['always_match'] = [always_grouping[i] == always_grouping[j]
                                             for i, j in united_df[['i', 'j']].values]
        else:
            # Nothing was united: return an empty report with the usual columns
            columns = ['i', 'j', 'n_i', 'n_j', 'score']
            if always_match is not None:
                columns.append('always_match')
            united_df = pd.DataFrame(columns=columns)

        return predicted_grouping, united_df

    @torch.no_grad()
    def unite_nearest(self, target_strings, threshold=0, always_grouping=None, progress_bar=True, batch_size=64):
        """
        Unite embedding strings with each string's most similar target string.

        - "always_grouping" will be used to initialize the group_ids before uniting new matches
        - "threshold" sets the minimum match similarity required between a string and target string
          for the string to be matched. (i.e., setting threshold=0 will result in every embedding
          string being matched to its nearest target string, while setting threshold=0.9 will leave
          strings that have similarity<0.9 with their nearest target string unaffected)

        returns: MatchGroups object

        raises: KeyError if a target string is not part of the embeddings
        """
        if always_grouping is not None:
            # Start from the supplied grouping
            group_ids = self._group_to_ids(always_grouping)
        else:
            group_ids = torch.arange(len(self)).to(self.device)

        seed_positions = [self.string_map[s] for s in target_strings]
        if not seed_positions:
            # No targets: nothing can be united
            return self._ids_to_group(group_ids)

        V = self.V
        cos_threshold = self.score_model.score_to_cos(threshold)

        seed_ids = torch.tensor(seed_positions).to(self.device)
        V_seed = V[seed_ids]
        g_seed = group_ids[seed_ids]

        # Flag the target strings by position. (Group ids only coincide with
        # string positions when no always_grouping is supplied, so they must
        # not be used as the index here.)
        is_seed = torch.zeros(V.shape[0], dtype=torch.bool).to(self.device)
        is_seed[seed_ids] = True

        for batch_start in tqdm(range(0, len(self), batch_size),
                                delay=1, desc='Predicting matches', disable=not progress_bar):

            batch_slice = slice(batch_start, batch_start + batch_size)

            batch_cos = V[batch_slice] @ V_seed.T

            # Nearest target for every string in the batch
            max_cos, max_seed = torch.max(batch_cos, dim=1)

            # Strings that are similar enough to their nearest target
            batch_i = torch.nonzero(max_cos > cos_threshold, as_tuple=True)[0]

            if len(batch_i):
                # Target strings keep their own group
                batch_i = batch_i[~is_seed[batch_slice][batch_i]]

                if len(batch_i):
                    # Convert batch positions to global positions
                    i = batch_i + batch_start

                    # Move each string into the group of its nearest target
                    group_ids[i] = g_seed[max_seed[batch_i]]

        return self._ids_to_group(group_ids)

    @torch.no_grad()
    def score_pairs(self, string_pairs, batch_size=64, progress_bar=True):
        """
        Score pairs of strings.

        "string_pairs" is a sequence of (string0, string1) pairs; every string
        must be part of the embeddings.

        returns: numpy array with one score per pair (empty if no pairs are given)
        """
        string_pairs = np.array(string_pairs)

        if string_pairs.shape[0] == 0:
            return np.array([])

        string_map = self.string_map
        V = self.V

        scores = []
        for batch_start in tqdm(range(0, string_pairs.shape[0], batch_size), desc='Scoring pairs', disable=not progress_bar):

            batch = string_pairs[batch_start:batch_start + batch_size]

            V0 = V[[string_map[s] for s in batch[:, 0]]]
            V1 = V[[string_map[s] for s in batch[:, 1]]]

            # Row-wise dot product of the paired vectors
            batch_cos = (V0 * V1).sum(dim=1).ravel()
            batch_scores = self.score_model(batch_cos)

            scores.append(batch_scores.cpu().numpy())

        return np.concatenate(scores)

    @torch.no_grad()
    def _batch_scores(self, group_ids, batch_start, batch_size,
                      is_match=None,
                      min_score=None, max_score=None,
                      min_loss=None, max_loss=None,
                      group_labels=None):
        """
        Score the strings in one batch against all later strings.

        - "group_ids" gives the integer group id of every string
        - "is_match" keeps only pairs that share a group (True) or that do
          not (False); None keeps both
        - "min_score"/"max_score"/"min_loss"/"max_loss" keep only pairs whose
          score/loss lies within the given bounds
        - "group_labels" is an optional array holding the label of each group
          id. When omitted, group ids are assumed to be string positions.

        Pairs with a score of exactly zero are never returned, because
        filtering works by zeroing out the scores of rejected pairs.

        returns: (pairs, pair_groups, pair_scores, pair_losses) numpy arrays
        """
        strings = self.strings
        V = self.V
        w = self.w

        # Create simple slice objects to avoid creating copies with advanced indexing
        i_slice = slice(batch_start, batch_start + batch_size)
        j_slice = slice(batch_start + 1, None)

        X = V[i_slice] @ V[j_slice].T
        Y = (group_ids[i_slice, None] == group_ids[None, j_slice]).float()
        if w is not None:
            W = w[i_slice, None] * w[None, j_slice]
        else:
            W = None

        scores = self.score_model(X)
        loss = self.score_model.loss(X, Y, weights=W)

        # Only keep the upper triangle so that every pair is visited once
        # (rejected entries are set to zero)
        scores = torch.triu(scores)

        # Filter by match status
        if is_match is not None:
            if is_match:
                scores *= Y
            else:
                scores *= (1 - Y)

        # Filter by score bounds
        if min_score is not None:
            scores *= (scores >= min_score)

        if max_score is not None:
            scores *= (scores <= max_score)

        # Filter by loss bounds
        if min_loss is not None:
            scores *= (loss >= min_loss)

        if max_loss is not None:
            scores *= (loss <= max_loss)

        # Collect the pairs that survived the filters
        i, j = torch.nonzero(scores, as_tuple=True)

        pairs = np.hstack([
            strings[i.cpu().numpy() + batch_start][:, None],
            strings[j.cpu().numpy() + (batch_start + 1)][:, None]
        ])

        if group_labels is None:
            # Without explicit labels a group id is taken to be the position
            # of the string that names the group
            group_labels = strings

        pair_groups = np.hstack([
            group_labels[group_ids[i + batch_start].cpu().numpy()][:, None],
            group_labels[group_ids[j + (batch_start + 1)].cpu().numpy()][:, None]
        ])

        pair_scores = scores[i, j].cpu().numpy()
        pair_losses = loss[i, j].cpu().numpy()

        return pairs, pair_groups, pair_scores, pair_losses

    def iter_scores(self, grouping=None, batch_size=64, progress_bar=True, **kwargs):
        """
        Iterate over scored string pairs.

        - "grouping" is an optional MatchGroups object that defines the
          group of every string (and the counts used for weighting). Without
          it, every string is its own group.
        - additional keyword arguments are passed on to _batch_scores
          (is_match, min_score, max_score, min_loss, max_loss)

        yields: dicts with keys 'string0', 'string1', 'group0', 'group1',
        'score' and 'loss'
        """
        if grouping is not None:
            embeddings = self.embed(grouping)
            group_ids = embeddings._group_to_ids(grouping)
            # Same ordering as the ids assigned by _group_to_ids
            group_labels = np.array(list(grouping.groups.keys()), dtype=object)
        else:
            embeddings = self
            group_ids = torch.arange(len(self)).to(self.device)
            group_labels = None

        for batch_start in tqdm(range(0, len(embeddings), batch_size), desc='Scoring pairs', disable=not progress_bar):
            pairs, pair_groups, scores, losses = embeddings._batch_scores(
                group_ids, batch_start, batch_size,
                group_labels=group_labels, **kwargs)
            for (s0, s1), (g0, g1), score, loss in zip(pairs, pair_groups, scores, losses):
                yield {
                    'string0': s0,
                    'string1': s1,
                    'group0': g0,
                    'group1': g1,
                    'score': score,
                    'loss': loss,
                }


def load_embeddings(f, device='cpu'):
    """
    Load embeddings from custom zipped archive format

    - "f" is a path or file-like object for an archive containing the members
      'score_model.pkl', 'weighting_function.pkl', 'strings.csv' and 'V.npy'
    - "device" is the device the loaded Embeddings are placed on

    SECURITY: the score model and weighting function are unpickled, which can
    execute arbitrary code. Only load archives from a trusted source.

    returns: Embeddings object
    """
    with ZipFile(f, 'r') as archive:
        score_model = pickle.loads(archive.read('score_model.pkl'))
        weighting_function = pickle.loads(archive.read('weighting_function.pkl'))

        # Read the string column as text and keep empty/"NA"-like values
        # literally; type inference would turn strings such as "123" into
        # numbers that no longer match the original strings.
        with archive.open('strings.csv') as csv_file:
            strings_df = pd.read_csv(csv_file, na_filter=False, dtype={'string': str})

        with archive.open('V.npy') as npy_file:
            V = np.load(npy_file)

    return Embeddings(
        strings=strings_df['string'].values,
        counts=torch.tensor(strings_df['count'].values),
        score_model=score_model,
        weighting_function=weighting_function,
        V=torch.tensor(V),
        device=device
    )