| import random |
| import pandas as pd |
| import numpy as np |
| from tqdm import tqdm |
| from copy import copy,deepcopy |
| from collections import Counter |
| import torch |
| from torch import nn |
| from torch.utils.data import DataLoader |
| from transformers import get_cosine_schedule_with_warmup,get_linear_schedule_with_warmup, logging |
| from transformers.modeling_utils import PreTrainedModel |
|
|
| from .match_groups import MatchGroups |
| from .scoring import score_predicted |
| from .scoring_model import SimilarityScore |
| from .embeddings import Embeddings |
| from .embedding_model import EmbeddingModel |
| from .configuration import SimilarityModelConfig |
| logging.set_verbosity_error() |
|
|
|
|
class ExponentWeights:
    """
    Callable weighting function that raises counts to a fixed power.

    The power is read from the "weighting_exponent" entry of ``config``
    (falling back to 0.5 when the entry is missing).
    """

    def __init__(self, config, **kwargs):
        """
        Parameters
        ----------
        config : mapping-like object exposing ``.get(key, default)``
        **kwargs : accepted for signature compatibility and ignored
        """
        exponent = config.get("weighting_exponent", 0.5)
        self.exponent = exponent

    def __call__(self, counts):
        """Return ``counts`` raised to the power ``self.exponent``."""
        power = self.exponent
        return counts ** power
|
|
|
|
class SimilarityModel(PreTrainedModel):
    """
    A combined embedding/scorer model that produces Embeddings objects
    as its primary output.

    - train() jointly optimizes the embedding_model and score_model using
      contrastive learning to learn from a training MatchGroups.
    """
    config_class = SimilarityModelConfig

    def __init__(self, config, **kwargs):
        """
        Build the embedding model, score model and weighting function from
        the corresponding sub-configs of ``config``, then move the model to
        ``config.device``. Extra keyword arguments are forwarded to each
        component's constructor.
        """
        super().__init__(config)

        self.embedding_model = EmbeddingModel(config.embedding_model_config, **kwargs)
        self.score_model = SimilarityScore(config.score_model_config, **kwargs)
        self.weighting_function = ExponentWeights(config.weighting_function_config, **kwargs)

        self.config = config
        self.to(config.device)

    def to(self, device):
        """
        Move the model and both sub-models to ``device``.

        Returns ``self`` (like ``torch.nn.Module.to``) so that the common
        ``model = model.to(device)`` idiom keeps working.
        """
        super().to(device)
        # Explicit calls are kept in case the sub-models customize ``to``
        self.embedding_model.to(device)
        self.score_model.to(device)

        return self

    def save(self, savefile):
        """
        Save the model config (under 'metadata') and the state dict
        (under 'state_dict') to ``savefile`` using ``torch.save``.
        """
        torch.save({'metadata': self.config, 'state_dict': self.state_dict()}, savefile)

    @torch.no_grad()
    def embed(self, input, to=None, batch_size=64, progress_bar=True, **kwargs):
        """
        Construct an Embeddings object from input strings or a MatchGroups

        Parameters
        ----------
        input : MatchGroups or iterable of strings
            For a MatchGroups the strings and their counts are taken from
            the object; otherwise every string gets a count of 1.
        to : device on which the returned embeddings are stored
            (defaults to the model's own device)
        batch_size : number of strings embedded per forward pass
        progress_bar : show a tqdm progress bar while embedding
        **kwargs : accepted for call compatibility and ignored here

        Returns
        -------
        Embeddings holding the strings, the embedding matrix, the counts,
        and independent copies of the score model and weighting function.

        Note: this leaves ``self.embedding_model`` in eval mode.
        """
        if to is None:
            to = self.device

        if isinstance(input, MatchGroups):
            strings = input.strings()
            counts = torch.tensor([input.counts[s] for s in strings], device=self.device).float().to(to)
        else:
            strings = list(input)
            counts = torch.ones(len(strings), device=self.device).float().to(to)

        input_loader = DataLoader(strings, batch_size=batch_size, num_workers=0)

        self.embedding_model.eval()

        V = None
        batch_start = 0
        with tqdm(total=len(strings), delay=1, desc='Embedding strings', disable=not progress_bar) as pbar:
            for batch_strings in input_loader:

                v = self.embedding_model(batch_strings).detach().to(to)

                if V is None:
                    # Allocate the output matrix lazily: the embedding width
                    # and dtype are only known after the first batch
                    V = torch.empty(len(strings), v.shape[1], device=to, dtype=v.dtype)

                V[batch_start:batch_start+len(batch_strings), :] = v

                pbar.update(len(batch_strings))
                batch_start += len(batch_strings)

        # Deep copy so the returned score model owns its parameters. A shallow
        # copy shares the Parameter objects with self.score_model, so moving
        # the copy with .to(to) would also move the live model's parameters.
        score_model = deepcopy(self.score_model)
        score_model.to(to)

        weighting_function = deepcopy(self.weighting_function)

        return Embeddings(strings=strings,
                          V=V.detach(),
                          counts=counts.detach(),
                          score_model=score_model,
                          weighting_function=weighting_function,
                          device=to)

    def train(self, training_groupings, max_epochs=1, batch_size=8,
              score_decay=0, regularization=0,
              transformer_lr=1e-5, projection_lr=1e-5, score_lr=10, warmup_frac=0.1,
              max_grad_norm=1, dropout=False,
              validation_groupings=None, target='F1', restore_best=True, val_seed=None,
              validation_interval=1000, early_stopping=True, early_stopping_patience=3,
              verbose=False, progress_bar=True,
              **kwargs):
        """
        Train the embedding_model and score_model to predict match probabilities
        using the training_groupings as a source of "correct" matches.
        Training algorithm uses contrastive learning with hard-positive
        and hard-negative mining to fine tune the embedding model to place
        matched strings near to each other in embedding space, while
        simultaneously calibrating the score_model to predict the match
        probabilities as a function of cosine distance

        Parameters
        ----------
        training_groupings : groupings used as training targets (embedded
            with ``self.embed`` and indexed by string to find each group)
        max_epochs : number of passes over the training strings
        batch_size : number of anchor strings per step
        score_decay : passed as ``decay`` to ``score_model.loss``
        regularization : weight of the penalty on similarity between batch
            embeddings that belong to different groups (0 disables it)
        transformer_lr, projection_lr : passed to
            ``embedding_model.config_optimizer``; if both are falsy the
            embedding model is not updated
        score_lr : passed to ``score_model.config_optimizer``; if falsy the
            score model is not updated
        warmup_frac : fraction of the training steps used for LR warmup
        max_grad_norm : gradient-norm clipping threshold
        dropout : put the embedding model in train mode (True) or eval
            mode (False) while training
        validation_groupings : optional groupings scored with ``self.test``
            every ``validation_interval`` steps
        target : key of the validation scores used to pick the best state
        restore_best : reload the best validated state when training ends
        val_seed : currently unused
        validation_interval : number of steps between validation checks
        early_stopping, early_stopping_patience : stop once more than
            ``early_stopping_patience`` validation checks have passed since
            the best one
        verbose : print validation results
        progress_bar : show tqdm progress bars
        **kwargs : currently unused

        Returns
        -------
        pandas.DataFrame with one row of statistics per training step.

        Note: this method shadows ``torch.nn.Module.train(mode)``. Calls that
        pass a bool (as ``nn.Module.eval()`` does internally) are forwarded
        to the standard mode switch instead of starting a training run.
        """
        if isinstance(training_groupings, bool):
            return super().train(training_groupings)

        if validation_groupings is None:
            early_stopping = False
            restore_best = False

        num_training_steps = max_epochs*len(training_groupings)//batch_size
        num_warmup_steps = int(warmup_frac*num_training_steps)

        if transformer_lr or projection_lr:
            embedding_optimizer = self.embedding_model.config_optimizer(transformer_lr, projection_lr)
            embedding_scheduler = get_cosine_schedule_with_warmup(
                                        embedding_optimizer,
                                        num_warmup_steps=num_warmup_steps,
                                        num_training_steps=num_training_steps)
        if score_lr:
            score_optimizer = self.score_model.config_optimizer(score_lr)
            score_scheduler = get_linear_schedule_with_warmup(
                                        score_optimizer,
                                        num_warmup_steps=num_warmup_steps,
                                        num_training_steps=num_training_steps)

        step = 0
        self.history = []
        self.validation_scores = []
        # Legacy attribute name, kept pointing at the same list
        self.val_scores = self.validation_scores

        best_state = None        # snapshot of the best weights (only if restore_best)
        best_val_scores = None   # scores of the best validation check so far
        stop_training = False

        for epoch in range(max_epochs):

            # Embed all training strings once per epoch; V acts as a cache of
            # embeddings that is refreshed row-by-row as batches are re-embedded
            global_embeddings = self.embed(training_groupings)

            strings = global_embeddings.strings
            V = global_embeddings.V
            w = global_embeddings.w

            # Integer group id for every string (via the embeddings' string_map)
            groups = torch.tensor([global_embeddings.string_map[training_groupings[s]] for s in strings], device=self.device)

            # Normalize weights so that their mean is 1
            if w is not None:
                w = w/w.mean()

            shuffled_ids = list(range(len(strings)))
            random.shuffle(shuffled_ids)

            if dropout:
                self.embedding_model.train()
            else:
                self.embedding_model.eval()

            for batch_start in tqdm(range(0, len(strings), batch_size), desc=f'training epoch {epoch}', disable=not progress_bar):

                h = {'epoch': epoch, 'step': step}

                batch_i = shuffled_ids[batch_start:batch_start+batch_size]

                # Pad a short final batch with ids from the start of the shuffle
                if len(batch_i) < batch_size:
                    batch_i = batch_i + shuffled_ids[:(batch_size-len(batch_i))]

                # Find highest loss match for each batch string (global search)
                #
                # Note: If we compute V_i with dropout enabled, it will add noise
                # to the embeddings and prevent the same pairs from being selected
                # every time.
                # (assumes `strings` supports indexing by a list of ints — TODO confirm)
                V_i = self.embedding_model(strings[batch_i])

                # Update the global embedding cache
                V[batch_i, :] = V_i.detach()

                with torch.no_grad():

                    global_X = V_i@V.T
                    global_Y = (groups[batch_i][:, None] == groups[None, :]).float()

                    if w is not None:
                        global_W = torch.outer(w[batch_i], w)
                    else:
                        global_W = None

                # Train the score model only
                if score_lr:
                    # Make sure gradients are enabled for the score model
                    self.score_model.requires_grad_(True)

                    global_loss = self.score_model.loss(global_X, global_Y, weights=global_W, decay=score_decay)

                    score_optimizer.zero_grad()
                    global_loss.nanmean().backward()
                    torch.nn.utils.clip_grad_norm_(self.score_model.parameters(), max_norm=max_grad_norm)

                    score_optimizer.step()
                    score_scheduler.step()

                    h['score_lr'] = score_optimizer.param_groups[0]['lr']
                    h['global_mean_cos'] = global_X.mean().item()
                    # Best-effort logging: not every score model is guaranteed
                    # to expose a single-element `alpha`
                    try:
                        h['score_alpha'] = self.score_model.alpha.item()
                    except (AttributeError, RuntimeError, ValueError):
                        pass

                else:
                    with torch.no_grad():
                        global_loss = self.score_model.loss(global_X, global_Y)

                h['global_loss'] = global_loss.detach().nanmean().item()

                # Train the embedding model
                if (transformer_lr or projection_lr) and step <= num_warmup_steps + num_training_steps:

                    # Turn off score model updating - only want to train embedding here
                    self.score_model.requires_grad_(False)

                    # Select the highest-loss partner for each batch string
                    with torch.no_grad():
                        batch_j = global_loss.argmax(dim=1).flatten()

                        if w is not None:
                            batch_W = torch.outer(w[batch_i], w[batch_j])
                        else:
                            batch_W = None

                    # Re-embed the selected partners with gradients enabled
                    V_j = self.embedding_model(strings[batch_j.tolist()])

                    # Update the global embedding cache
                    V[batch_j, :] = V_j.detach()

                    batch_X = V_i@V_j.T
                    batch_Y = (groups[batch_i][:, None] == groups[batch_j][None, :]).float()
                    h['batch_obs'] = len(batch_i)*len(batch_j)

                    batch_loss = self.score_model.loss(batch_X, batch_Y, weights=batch_W)

                    if regularization:
                        # Penalize first and second moments of the similarity
                        # between batch strings from different groups
                        gor_Y = (groups[batch_i][:, None] != groups[batch_i][None, :]).float()
                        gor_n = gor_Y.sum()
                        if gor_n > 1:
                            gor_X = (V_i@V_i.T)*gor_Y
                            gor_m1 = 0.5*gor_X.sum()/gor_n
                            gor_m2 = 0.5*(gor_X**2).sum()/gor_n
                            batch_loss += regularization*(gor_m1 + torch.clamp(gor_m2 - 1/self.embedding_model.d, min=0))

                    h['batch_nan'] = torch.isnan(batch_loss.detach()).sum().item()

                    embedding_optimizer.zero_grad()
                    batch_loss.nanmean().backward()

                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=max_grad_norm)

                    embedding_optimizer.step()
                    embedding_scheduler.step()

                    h['transformer_lr'] = embedding_optimizer.param_groups[1]['lr']
                    h['projection_lr'] = embedding_optimizer.param_groups[-1]['lr']

                    # Save stats (only defined when the embedding model was trained)
                    h['batch_loss'] = batch_loss.detach().mean().item()
                    h['batch_pos_target'] = batch_Y.detach().mean().item()

                self.history.append(h)
                step += 1

                if (validation_groupings is not None) and not (step % validation_interval):

                    validation = len(self.validation_scores)
                    val_scores = self.test(validation_groupings)
                    val_scores['step'] = step - 1
                    val_scores['epoch'] = epoch
                    val_scores['validation'] = validation

                    self.validation_scores.append(val_scores)

                    # self.test() embeds with the embedding model in eval mode;
                    # switch dropout back on for the remaining batches
                    if dropout:
                        self.embedding_model.train()

                    # Print validation stats
                    if verbose:
                        print(f'\nValidation results at step {step} (current epoch {epoch})')
                        for k, v in val_scores.items():
                            print(f' {k}: {v:.4f}')

                        print(list(self.score_model.named_parameters()))

                    # Track the best validation check; the weights are only
                    # snapshotted when they will be restored later
                    if val_scores[target] >= max(s[target] for s in self.validation_scores):
                        best_val_scores = val_scores
                        if restore_best:
                            best_state = deepcopy({
                                            'state_dict': self.state_dict(),
                                            'val_scores': val_scores
                                            })

                    if (early_stopping
                            and best_val_scores is not None
                            and validation - best_val_scores['validation'] > early_stopping_patience):
                        print(f'Stopping training ({early_stopping_patience} validation checks since best validation score)')
                        stop_training = True
                        break

            # Early stopping ends the whole run, not just the current epoch
            if stop_training:
                break

        if restore_best and best_state is not None:
            print(f"Restoring to best state (step {best_state['val_scores']['step']}):")
            for k, v in best_state['val_scores'].items():
                print(f' {k}: {v:.4f}')

            # Remember the device first: after self.to('cpu') self.device
            # reports the CPU, so it cannot be used to move the model back
            original_device = self.device
            self.to('cpu')
            self.load_state_dict(best_state['state_dict'])
            self.to(original_device)

        return pd.DataFrame(self.history)

    def unite_similar(self, input, **kwargs):
        """
        Embed ``input`` and return the result of the embeddings'
        ``unite_similar``. ``kwargs`` are passed to both calls.
        """
        embeddings = self.embed(input, **kwargs)
        return embeddings.unite_similar(**kwargs)

    def test(self, gold_groupings, threshold=0.5, **kwargs):
        """
        Score predicted groupings against ``gold_groupings``.

        Parameters
        ----------
        gold_groupings : groupings treated as ground truth; their strings
            are embedded and re-grouped with ``unite_similar``
        threshold : a single number, or an iterable of numbers
        **kwargs : passed to both ``embed`` and ``unite_similar``

        Returns
        -------
        The ``score_predicted`` result for a scalar threshold; for an
        iterable, a list of results, each with an added "threshold" entry.
        """
        embeddings = self.embed(gold_groupings, **kwargs)

        # Treat ints and numpy scalars as scalar thresholds too (not only float)
        if isinstance(threshold, (int, float, np.number)):
            predicted = embeddings.unite_similar(threshold=threshold, **kwargs)
            scores = score_predicted(predicted, gold_groupings, use_counts=True)

            return scores

        results = []
        for thres in threshold:
            predicted = embeddings.unite_similar(threshold=thres, **kwargs)

            scores = score_predicted(predicted, gold_groupings, use_counts=True)
            scores["threshold"] = thres
            results.append(scores)

        return results
|
|
|
|
|
|
def load_similarity_model(f, map_location='cpu', *args, **kwargs):
    """
    Rebuild a SimilarityModel from a checkpoint written by
    ``SimilarityModel.save``.

    Parameters
    ----------
    f : file path or file-like object handed to ``torch.load``
    map_location : forwarded to ``torch.load`` (defaults to 'cpu')
    *args : accepted but ignored
    **kwargs : forwarded to ``torch.load``

    Returns
    -------
    A SimilarityModel built from the checkpoint's 'metadata' config with
    the checkpoint's 'state_dict' loaded into it.
    """
    # NOTE(review): the checkpoint stores a pickled config object under
    # 'metadata'; recent PyTorch releases default torch.load to
    # weights_only=True, which presumably rejects it unless the caller
    # passes weights_only=False via kwargs — confirm against the torch
    # version in use.
    saved = torch.load(f, map_location=map_location, **kwargs)
    config, weights = saved['metadata'], saved['state_dict']

    restored = SimilarityModel(config=config)
    restored.load_state_dict(weights)
    return restored
| |
|
|
| |
|
|
|
|
|
|