# DetectAnyLLM/core/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import time
import copy
from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer


def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load a model/tokenizer class, preferring a local copy in `cache_dir` over the Hub."""
    # use the local model if it exists
    if "/" in model_name:
        local_path = os.path.join(cache_dir, model_name.split("/")[-1])
    else:
        local_path = os.path.join(cache_dir, model_name)
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
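
# Usage sketch (the model name and cache directory are hypothetical): the helper mirrors the
# Hugging Face `from_pretrained` call, but prefers a local snapshot under `cache_dir` when
# one exists and only falls back to downloading from the Hub otherwise.
#
#   tokenizer = from_pretrained(AutoTokenizer, 'openai-community/gpt2',
#                               kwargs={'padding_side': 'right', 'use_fast': True},
#                               cache_dir='./cache')
#   model = from_pretrained(AutoModelForCausalLM, 'openai-community/gpt2',
#                           kwargs={}, cache_dir='./cache')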


class DiscrepancyEstimator(nn.Module):
    def __init__(self,
                 scoring_model_name: str = None,
                 reference_model_name: str = None,
                 scoring_model: AutoModelForCausalLM = None,
                 reference_model: AutoModelForCausalLM = None,
                 scoring_tokenizer: AutoTokenizer = None,
                 reference_tokenizer: AutoTokenizer = None,
                 cache_dir: str = None,
                 train_method: str = 'DDL',
                 pretrained_ckpt: str = None,
                 ):
        super().__init__()
        assert train_method in ['DDL', 'SPO'], 'train_method should be DDL or SPO.'
        self.train_method = train_method
        self.cache_dir = cache_dir
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)
        else:
            if scoring_model_name is not None:
                if 'gpt-j' in scoring_model_name or 'GPT-J' in scoring_model_name:
                    model_kwargs = dict(
                        torch_dtype=torch.float16,
                        revision='float16'
                    )
                else:
                    model_kwargs = {}
                self.scoring_model_name = scoring_model_name
                self.scoring_model = from_pretrained(AutoModelForCausalLM,
                                                     scoring_model_name,
                                                     cache_dir=cache_dir,
                                                     kwargs=model_kwargs)
                self.scoring_tokenizer = from_pretrained(AutoTokenizer,
                                                         scoring_model_name,
                                                         kwargs={'padding_side': 'right',
                                                                 'use_fast': 'facebook/opt-' not in scoring_model_name},
                                                         cache_dir=cache_dir)
            else:
                if scoring_model is None or scoring_tokenizer is None:
                    raise ValueError('You should provide scoring_model_name or scoring_model and scoring_tokenizer.')
                self.scoring_model = scoring_model
                self.scoring_tokenizer = scoring_tokenizer
                self.scoring_model_name = scoring_model.config._name_or_path
            if self.scoring_tokenizer.pad_token is None:
                self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
                self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id
            if reference_model_name is not None:
                if 'gpt-j' in reference_model_name or 'GPT-J' in reference_model_name:
                    model_kwargs = dict(
                        torch_dtype=torch.float16,
                        revision='float16'
                    )
                else:
                    model_kwargs = {}
                self.reference_model = from_pretrained(AutoModelForCausalLM,
                                                       reference_model_name,
                                                       cache_dir=cache_dir,
                                                       kwargs=model_kwargs)
                self.reference_tokenizer = from_pretrained(AutoTokenizer,
                                                           reference_model_name,
                                                           kwargs={'padding_side': 'right',
                                                                   'use_fast': 'facebook/opt-' not in reference_model_name},
                                                           cache_dir=cache_dir)
                self.reference_model_name = reference_model_name
            else:
                if reference_model is None and reference_tokenizer is None:
                    if train_method == 'DDL':
                        # DDL scores texts with the scoring model alone, so no separate reference model is kept.
                        self.reference_model = None
                        self.reference_tokenizer = None
                        self.reference_model_name = None
                    else:
                        # SPO uses a deep copy of the initial scoring model as the reference.
                        self.reference_model = copy.deepcopy(self.scoring_model)
                        self.reference_tokenizer = self.scoring_tokenizer
                        self.reference_model_name = self.reference_model.config._name_or_path
                elif reference_model is not None and reference_tokenizer is not None:
                    self.reference_model = reference_model
                    self.reference_tokenizer = reference_tokenizer
                    self.reference_model_name = reference_model.config._name_or_path
                else:
                    raise ValueError('You should provide reference_model and reference_tokenizer at the same time.')
            if self.reference_tokenizer is not None:
                if self.reference_tokenizer.pad_token is None:
                    self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                    self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id

    def add_lora_config(self, lora_config: LoraConfig):
        """Wrap the scoring model with LoRA adapters defined by `lora_config`."""
        self.lora_config = lora_config
        self.scoring_model = get_peft_model(self.scoring_model, self.lora_config)
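
    # Usage sketch (the model name and LoRA hyperparameters are hypothetical, not the
    # project's defaults): build an estimator from a model name and attach LoRA adapters
    # so that only the adapter weights are trainable during fine-tuning.
    #
    #   estimator = DiscrepancyEstimator(scoring_model_name='openai-community/gpt2',
    #                                    cache_dir='./cache', train_method='DDL')
    #   estimator.add_lora_config(LoraConfig(task_type=TaskType.CAUSAL_LM,
    #                                        r=8, lora_alpha=16, lora_dropout=0.05))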

    def load_pretrained(self, load_directory, load_directory_ref=None):
        """
        Load the scoring model (a PEFT/LoRA checkpoint) and, optionally, a reference model
        from the specified directories.
        """
        if 'gpt-j' in load_directory or 'GPT-J' in load_directory:
            model_kwargs = dict(
                torch_dtype=torch.float16,
                revision='float16'
            )
        else:
            model_kwargs = {}
        self.scoring_model = AutoPeftModelForCausalLM.from_pretrained(load_directory, **model_kwargs)
        self.scoring_tokenizer = AutoTokenizer.from_pretrained(load_directory)
        self.scoring_model_name = self.scoring_model.config._name_or_path
        if load_directory_ref:
            self.reference_model = AutoModelForCausalLM.from_pretrained(load_directory_ref, **model_kwargs)
            self.reference_tokenizer = AutoTokenizer.from_pretrained(load_directory_ref)
            self.reference_model_name = self.reference_model.config._name_or_path
        else:
            self.reference_model = None
            self.reference_tokenizer = None
            self.reference_model_name = None
        if self.scoring_tokenizer.pad_token is None:
            self.scoring_tokenizer.pad_token = self.scoring_tokenizer.eos_token
            self.scoring_tokenizer.pad_token_id = self.scoring_tokenizer.eos_token_id
        if self.reference_tokenizer is not None:
            if self.reference_tokenizer.pad_token is None:
                self.reference_tokenizer.pad_token = self.reference_tokenizer.eos_token
                self.reference_tokenizer.pad_token_id = self.reference_tokenizer.eos_token_id
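
    # Usage sketch (the checkpoint path is hypothetical): a LoRA-adapted scoring model saved
    # with PEFT can be restored straight through the constructor; to also restore a separate
    # (non-LoRA) reference model, call load_pretrained with `load_directory_ref` explicitly.
    #
    #   estimator = DiscrepancyEstimator(pretrained_ckpt='./ckpt/scoring_lora')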

    def get_sampling_discrepancy_analytic(self, reference_logits, scoring_logits, labels, attention_mask):
        if reference_logits.size(-1) != scoring_logits.size(-1):
            # truncate to the shared vocabulary when the two models use different vocab sizes
            vocab_size = min(reference_logits.size(-1), scoring_logits.size(-1))
            reference_logits = reference_logits[:, :, :vocab_size]
            scoring_logits = scoring_logits[:, :, :vocab_size]
        labels = labels.unsqueeze(-1) if labels.ndim == scoring_logits.ndim - 1 else labels
        lprobs_score = torch.log_softmax(scoring_logits, dim=-1)
        probs_ref = torch.softmax(reference_logits, dim=-1)
        log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
        mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
        var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
        mask = attention_mask[:, 1:].float()  # [bsz, seq_len-1], 1 for non-pad, 0 for pad
        log_likelihood_sum = (log_likelihood * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        mean_ref_sum = (mean_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        var_ref_sum = (var_ref * mask).sum(dim=-1)  # [bsz], sum over non-pad tokens
        discrepancy = (log_likelihood_sum - mean_ref_sum) / (var_ref_sum.sqrt() + 1e-8)  # [bsz], avoid division by zero
        return discrepancy, log_likelihood_sum
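
    # In equation form, the value returned above for labels y_t (padding masked out) is
    #
    #   d(x) = ( sum_t log p_score(y_t) - sum_t E_{y ~ p_ref}[ log p_score(y) ] )
    #          / sqrt( sum_t Var_{y ~ p_ref}[ log p_score(y) ] )
    #
    # i.e. a z-score of the observed scoring-model log-likelihood against its mean and variance
    # when tokens are resampled from the reference distribution (the analytic form of the
    # sampling discrepancy used in Fast-DetectGPT-style detectors).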

    def get_discrepancy_of_scoring_and_reference_models(self,
                                                        input_ids_for_scoring_model,
                                                        attention_mask_for_scoring_model,
                                                        input_ids_for_reference_model=None,
                                                        attention_mask_for_reference_model=None,
                                                        ) -> dict:
        labels = input_ids_for_scoring_model[:, 1:]  # shape: [bsz, sentence_len - 1]
        scoring_logits = self.scoring_model(input_ids_for_scoring_model,
                                            attention_mask=attention_mask_for_scoring_model).logits[:, :-1, :]
        if self.reference_model is not None:
            assert input_ids_for_reference_model is not None and attention_mask_for_reference_model is not None, \
                "If a reference_model is used, you should provide the reference_tokenizer to the dataset so that reference inputs are produced."
            with torch.no_grad():
                # check that the scoring and reference tokenizers produce identical token ids
                reference_labels = input_ids_for_reference_model[:, 1:]  # shape: [bsz, sentence_len - 1]
                assert torch.all(reference_labels == labels), \
                    "Tokenizers of the scoring and reference models do not match."
                reference_logits = self.reference_model(input_ids_for_reference_model,
                                                        attention_mask=attention_mask_for_reference_model).logits[:, :-1, :]
        else:
            reference_logits = scoring_logits
        if self.reference_model is not None:
            discrepancy_ref, logprob_ref = self.get_sampling_discrepancy_analytic(reference_logits, reference_logits,
                                                                                  labels, attention_mask=attention_mask_for_reference_model)
        else:
            discrepancy_ref, logprob_ref = None, None
        discrepancy_score, logprob_score = self.get_sampling_discrepancy_analytic(reference_logits, scoring_logits,
                                                                                  labels, attention_mask=attention_mask_for_scoring_model)
        return {
            'scoring_discrepancy': discrepancy_score,
            'scoring_logprob': logprob_score,
            'reference_discrepancy': discrepancy_ref,
            'reference_logprob': logprob_ref,
        }

    def forward(self,
                scoring_original_input_ids,
                scoring_original_attention_mask,
                scoring_rewritten_input_ids,
                scoring_rewritten_attention_mask,
                reference_original_input_ids=None,
                reference_original_attention_mask=None,
                reference_rewritten_input_ids=None,
                reference_rewritten_attention_mask=None,
                ) -> dict:
        if self.train_method == 'SPO':
            assert reference_original_input_ids is not None and reference_original_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is not None and reference_rewritten_attention_mask is not None, \
                "If train_method is SPO, you should provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        elif self.train_method == 'DDL':
            assert reference_original_input_ids is None and reference_original_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_original_input_ids and reference_original_attention_mask."
            assert reference_rewritten_input_ids is None and reference_rewritten_attention_mask is None, \
                "If train_method is DDL, you should not provide reference_rewritten_input_ids and reference_rewritten_attention_mask."
        else:
            raise ValueError('train_method should be DDL or SPO.')
        original_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_original_input_ids,
            attention_mask_for_scoring_model=scoring_original_attention_mask,
            input_ids_for_reference_model=reference_original_input_ids,
            attention_mask_for_reference_model=reference_original_attention_mask,
        )
        rewritten_output = self.get_discrepancy_of_scoring_and_reference_models(
            input_ids_for_scoring_model=scoring_rewritten_input_ids,
            attention_mask_for_scoring_model=scoring_rewritten_attention_mask,
            input_ids_for_reference_model=reference_rewritten_input_ids,
            attention_mask_for_reference_model=reference_rewritten_attention_mask,
        )
        return {
            'scoring_original_discrepancy': original_output['scoring_discrepancy'],
            'scoring_original_logprob': original_output['scoring_logprob'],
            'scoring_rewritten_discrepancy': rewritten_output['scoring_discrepancy'],
            'scoring_rewritten_logprob': rewritten_output['scoring_logprob'],
            'reference_original_discrepancy': original_output['reference_discrepancy'],
            'reference_original_logprob': original_output['reference_logprob'],
            'reference_rewritten_discrepancy': rewritten_output['reference_discrepancy'],
            'reference_rewritten_logprob': rewritten_output['reference_logprob'],
        }
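
    # Usage sketch (texts and batch construction are hypothetical; DDL setup, so no
    # reference inputs are passed): tokenize an original/rewritten text pair with the
    # scoring tokenizer and read the per-sample discrepancies from the returned dict.
    #
    #   enc_orig = estimator.scoring_tokenizer(original_texts, return_tensors='pt', padding=True)
    #   enc_rew = estimator.scoring_tokenizer(rewritten_texts, return_tensors='pt', padding=True)
    #   out = estimator(scoring_original_input_ids=enc_orig.input_ids,
    #                   scoring_original_attention_mask=enc_orig.attention_mask,
    #                   scoring_rewritten_input_ids=enc_rew.input_ids,
    #                   scoring_rewritten_attention_mask=enc_rew.attention_mask)
    #   # out['scoring_original_discrepancy'], out['scoring_rewritten_discrepancy']: [bsz] tensors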