import torch
import torch.nn as nn
import os
import copy
from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
def from_pretrained(cls, model_name, kwargs, cache_dir):
    # Prefer a local copy under cache_dir if one exists; otherwise download.
    if "/" in model_name:
        local_path = os.path.join(cache_dir, model_name.split("/")[-1])
    else:
        local_path = os.path.join(cache_dir, model_name)
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
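# Usage sketch (hypothetical arguments, for illustration only): prefer a cached
# copy of GPT-2 under ./cache if present, otherwise download it:
#   tokenizer = from_pretrained(AutoTokenizer, 'gpt2',
#                               kwargs={'use_fast': True}, cache_dir='./cache')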
class DiscrepancyEstimator(nn.Module):
    def __init__(self,
                 scoring_model_name: str = None,
                 reference_model_name: str = None,
                 scoring_model: AutoModelForCausalLM = None,
                 reference_model: AutoModelForCausalLM = None,
                 scoring_tokenizer: AutoTokenizer = None,
                 reference_tokenizer: AutoTokenizer = None,
                 cache_dir: str = None,
                 train_method: str = 'DDL',
                 pretrained_ckpt: str = None,
                 ):
        super().__init__()
        assert train_method in ['DDL', 'SPO'], 'train_method should be DDL or SPO.'
        self.train_method = train_method
        self.cache_dir = cache_dir
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)
        else:
            if scoring_model_name is not None:
                # GPT-J ships a dedicated float16 revision; load it in half precision.
                if 'gpt-j' in scoring_model_name or 'GPT-J' in scoring_model_name:
                    model_kwargs = dict(torch_dtype=torch.float16, revision='float16')
                else:
                    model_kwargs = {}
                self.scoring_model_name = scoring_model_name
                self.scoring_model = from_pretrained(AutoModelForCausalLM,
                                                     scoring_model_name,
                                                     kwargs=model_kwargs,
                                                     cache_dir=cache_dir)
                # OPT checkpoints require the slow tokenizer.
                self.scoring_tokenizer = from_pretrained(AutoTokenizer,
                                                         scoring_model_name,
                                                         kwargs={'padding_side': 'right',
                                                                 'use_fast': 'facebook/opt-' not in scoring_model_name},
                                                         cache_dir=cache_dir)
            else:
                if scoring_model is None or scoring_tokenizer is None:
                    raise ValueError('Provide either scoring_model_name, or both scoring_model and scoring_tokenizer.')
                self.scoring_model = scoring_model
                self.scoring_tokenizer = scoring_tokenizer
                self.scoring_model_name = scoring_model.config._name_or_path
            if self.scoring_tokenizer.pad_token is None:
                self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
                self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id
            if reference_model_name is not None:
                if 'gpt-j' in reference_model_name or 'GPT-J' in reference_model_name:
                    model_kwargs = dict(torch_dtype=torch.float16, revision='float16')
                else:
                    model_kwargs = {}
                self.reference_model = from_pretrained(AutoModelForCausalLM,
                                                       reference_model_name,
                                                       kwargs=model_kwargs,
                                                       cache_dir=cache_dir)
                self.reference_tokenizer = from_pretrained(AutoTokenizer,
                                                           reference_model_name,
                                                           kwargs={'padding_side': 'right',
                                                                   'use_fast': 'facebook/opt-' not in reference_model_name},
                                                           cache_dir=cache_dir)
                self.reference_model_name = reference_model_name
            else:
                if reference_model is None and reference_tokenizer is None:
                    if train_method == 'DDL':
                        # DDL needs no separate reference model; the scoring model
                        # later serves as its own reference.
                        self.reference_model = None
                        self.reference_tokenizer = None
                        self.reference_model_name = None
                    else:
                        # SPO keeps a frozen copy of the scoring model as the reference.
                        self.reference_model = copy.deepcopy(self.scoring_model)
                        self.reference_tokenizer = self.scoring_tokenizer
                        self.reference_model_name = self.reference_model.config._name_or_path
                elif reference_model is not None and reference_tokenizer is not None:
                    self.reference_model = reference_model
                    self.reference_tokenizer = reference_tokenizer
                    self.reference_model_name = reference_model.config._name_or_path
                else:
                    raise ValueError('Provide reference_model and reference_tokenizer together.')
            if self.reference_tokenizer is not None:
                if self.reference_tokenizer.pad_token is None:
                    self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                    self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id
    def add_lora_config(self, lora_config: LoraConfig):
        self.lora_config = lora_config
        self.scoring_model = get_peft_model(self.scoring_model, self.lora_config)
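    # Usage sketch (assumed LoRA hyperparameters, not prescribed by this module):
    # wrap the scoring model with LoRA adapters before fine-tuning:
    #   estimator.add_lora_config(LoraConfig(task_type=TaskType.CAUSAL_LM,
    #                                        r=8, lora_alpha=16, lora_dropout=0.05))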
    def load_pretrained(self, load_directory, load_directory_ref=None):
        """
        Load a fine-tuned PEFT checkpoint (and, optionally, a reference model)
        from the specified directories.
        """
        if 'gpt-j' in load_directory or 'GPT-J' in load_directory:
            model_kwargs = dict(torch_dtype=torch.float16, revision='float16')
        else:
            model_kwargs = {}
        self.scoring_model = AutoPeftModelForCausalLM.from_pretrained(load_directory, **model_kwargs)
        self.scoring_tokenizer = AutoTokenizer.from_pretrained(load_directory)
        self.scoring_model_name = self.scoring_model.config._name_or_path
        if load_directory_ref:
            self.reference_model = AutoModelForCausalLM.from_pretrained(load_directory_ref, **model_kwargs)
            self.reference_tokenizer = AutoTokenizer.from_pretrained(load_directory_ref)
            self.reference_model_name = self.reference_model.config._name_or_path
        else:
            self.reference_model = None
            self.reference_tokenizer = None
            self.reference_model_name = None
        if self.scoring_tokenizer.pad_token is None:
            self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
            self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id
        if self.reference_tokenizer is not None:
            if self.reference_tokenizer.pad_token is None:
                self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id
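    # The analytic discrepancy below is, per sequence x, the standardized gap
    # between the observed log-likelihood under the scoring model p_s and its
    # expectation under the reference distribution p_r, summed over tokens:
    #   d(x) = (log p_s(x) - E_{t~p_r}[log p_s(t)]) / sqrt(Var_{t~p_r}[log p_s(t)])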
    def get_sampling_discrepancy_analytic(self, reference_logits, scoring_logits, labels, attention_mask):
        # If the two vocabularies differ, truncate both to the shared prefix.
        if reference_logits.size(-1) != scoring_logits.size(-1):
            vocab_size = min(reference_logits.size(-1), scoring_logits.size(-1))
            reference_logits = reference_logits[:, :, :vocab_size]
            scoring_logits = scoring_logits[:, :, :vocab_size]
        labels = labels.unsqueeze(-1) if labels.ndim == scoring_logits.ndim - 1 else labels
        lprobs_score = torch.log_softmax(scoring_logits, dim=-1)
        probs_ref = torch.softmax(reference_logits, dim=-1)
        log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
        mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
        var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
        mask = attention_mask[:, 1:].float()  # [bsz, seq_len-1], 1 for non-pad, 0 for pad
        log_likelihood_sum = (log_likelihood * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        mean_ref_sum = (mean_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        var_ref_sum = (var_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        discrepancy = (log_likelihood_sum - mean_ref_sum) / (var_ref_sum.sqrt() + 1e-8)  # [bsz], avoid division by zero
        return discrepancy, log_likelihood_sum
    def get_discrepancy_of_scoring_and_reference_models(self,
                                                        input_ids_for_scoring_model,
                                                        attention_mask_for_scoring_model,
                                                        input_ids_for_reference_model=None,
                                                        attention_mask_for_reference_model=None,
                                                        ) -> dict:
        labels = input_ids_for_scoring_model[:, 1:]  # shape: [bsz, seq_len - 1]
        scoring_logits = self.scoring_model(input_ids_for_scoring_model,
                                            attention_mask=attention_mask_for_scoring_model).logits[:, :-1, :]
        if self.reference_model is not None:
            assert input_ids_for_reference_model is not None and attention_mask_for_reference_model is not None, \
                "If a reference model is used, you must provide reference-model input_ids and attention_mask."
            with torch.no_grad():
                # The two tokenizations must agree, otherwise the labels are incomparable.
                reference_labels = input_ids_for_reference_model[:, 1:]  # shape: [bsz, seq_len - 1]
                assert torch.all(reference_labels == labels), \
                    "Scoring and reference tokenizations do not match."
                reference_logits = self.reference_model(input_ids_for_reference_model,
                                                        attention_mask=attention_mask_for_reference_model).logits[:, :-1, :]
        else:
            # Without a reference model, the scoring model serves as its own reference.
            reference_logits = scoring_logits
        if self.reference_model is not None:
            # Discrepancy of the reference model scored against itself.
            discrepancy_ref, logprob_ref = self.get_sampling_discrepancy_analytic(
                reference_logits, reference_logits, labels,
                attention_mask=attention_mask_for_reference_model)
        else:
            discrepancy_ref, logprob_ref = None, None
        discrepancy_score, logprob_score = self.get_sampling_discrepancy_analytic(
            reference_logits, scoring_logits, labels,
            attention_mask=attention_mask_for_scoring_model)
        return {
            'scoring_discrepancy': discrepancy_score,
            'scoring_logprob': logprob_score,
            'reference_discrepancy': discrepancy_ref,
            'reference_logprob': logprob_ref,
        }
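    # DDL uses only the scoring model (all reference_* inputs must stay None),
    # while SPO additionally scores both texts under the frozen reference
    # model, so forward() validates its inputs against the configured
    # train_method before computing discrepancies.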
    def forward(self,
                scoring_original_input_ids,
                scoring_original_attention_mask,
                scoring_rewritten_input_ids,
                scoring_rewritten_attention_mask,
                reference_original_input_ids=None,
                reference_original_attention_mask=None,
                reference_rewritten_input_ids=None,
                reference_rewritten_attention_mask=None,
                ) -> dict:
        if self.train_method == 'SPO':
            assert reference_original_input_ids is not None and reference_original_attention_mask is not None, \
                "If train_method is SPO, you must provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is not None and reference_rewritten_attention_mask is not None, \
                "If train_method is SPO, you must provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        elif self.train_method == 'DDL':
            assert reference_original_input_ids is None and reference_original_attention_mask is None, \
                "If train_method is DDL, you must not provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is None and reference_rewritten_attention_mask is None, \
                "If train_method is DDL, you must not provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        else:
            raise ValueError('train_method should be DDL or SPO.')
        original_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_original_input_ids,
            attention_mask_for_scoring_model=scoring_original_attention_mask,
            input_ids_for_reference_model=reference_original_input_ids,
            attention_mask_for_reference_model=reference_original_attention_mask,
        )
        rewritten_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_rewritten_input_ids,
            attention_mask_for_scoring_model=scoring_rewritten_attention_mask,
            input_ids_for_reference_model=reference_rewritten_input_ids,
            attention_mask_for_reference_model=reference_rewritten_attention_mask,
        )
        return {
            'scoring_original_discrepancy': original_output['scoring_discrepancy'],
            'scoring_original_logprob': original_output['scoring_logprob'],
            'scoring_rewritten_discrepancy': rewritten_output['scoring_discrepancy'],
            'scoring_rewritten_logprob': rewritten_output['scoring_logprob'],
            'reference_original_discrepancy': original_output['reference_discrepancy'],
            'reference_original_logprob': original_output['reference_logprob'],
            'reference_rewritten_discrepancy': rewritten_output['reference_discrepancy'],
            'reference_rewritten_logprob': rewritten_output['reference_logprob'],
        }
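# A minimal end-to-end sketch (hypothetical inputs: 'gpt2' and the ./cache
# path are assumptions, and DDL mode needs no reference inputs). Illustrative
# only, not part of the original module:
if __name__ == '__main__':
    estimator = DiscrepancyEstimator(scoring_model_name='gpt2',
                                     cache_dir='./cache', train_method='DDL')
    texts = ['An original sentence.', 'Its machine-rewritten counterpart.']
    batch = estimator.scoring_tokenizer(texts, return_tensors='pt', padding=True)
    # Move the toy batch to wherever device_map placed the model.
    device = next(estimator.scoring_model.parameters()).device
    batch = {k: v.to(device) for k, v in batch.items()}
    out = estimator(
        scoring_original_input_ids=batch['input_ids'][:1],
        scoring_original_attention_mask=batch['attention_mask'][:1],
        scoring_rewritten_input_ids=batch['input_ids'][1:],
        scoring_rewritten_attention_mask=batch['attention_mask'][1:],
    )
    print(out['scoring_original_discrepancy'], out['scoring_rewritten_discrepancy'])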