import gc
from typing import Any, List, Optional, Tuple

import datasets
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import RepositoryNotFoundError
from sklearn import model_selection

import transformers


def pad_to_len(
    arr: torch.Tensor,
    target_len: int,
    left_pad: bool,
    eos_token: int,
    device: torch.device,
) -> torch.Tensor:
    """Pad or truncate array to given length."""
    if arr.shape[1] < target_len:
        shape_for_ones = list(arr.shape)
        shape_for_ones[1] = target_len - shape_for_ones[1]
        padded = (
            torch.ones(
                shape_for_ones,
                device=device,
                dtype=torch.long,
            )
            * eos_token
        )
        if left_pad:
            arr = torch.concatenate((padded, arr), dim=1)
        else:
            arr = torch.concatenate((arr, padded), dim=1)
    else:
        arr = arr[:, :target_len]
    return arr
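
# A quick sanity check of pad_to_len's behavior (illustrative values only):
#
#   >>> x = torch.tensor([[1, 2, 3]])
#   >>> pad_to_len(x, 5, left_pad=False, eos_token=0, device=torch.device("cpu"))
#   tensor([[1, 2, 3, 0, 0]])
#   >>> pad_to_len(x, 2, left_pad=False, eos_token=0, device=torch.device("cpu"))
#   tensor([[1, 2]])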


def filter_and_truncate(
    outputs: torch.Tensor,
    truncation_length: Optional[int],
    eos_token_mask: torch.Tensor,
) -> torch.Tensor:
    """Filter and truncate outputs to the given length.

    Args:
        outputs: output tensor of shape [batch_size, output_len]
        truncation_length: Length to truncate the final output. If None, the
            outputs are returned unchanged.
        eos_token_mask: EOS token mask of shape [batch_size, output_len]

    Returns:
        output tensor of shape [batch_size, truncation_length].
    """
    if truncation_length:
        outputs = outputs[:, :truncation_length]
        # Keep only sequences with at least truncation_length non-EOS tokens.
        truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length
        return outputs[truncation_mask, :]
    return outputs
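
# Illustrative example: with truncation_length=2, rows whose eos_token_mask sums
# to fewer than 2 tokens (i.e. they hit EOS early) are dropped, and surviving
# rows are clipped to their first 2 tokens:
#
#   >>> out = torch.tensor([[5, 6, 7], [8, 9, 10]])
#   >>> mask = torch.tensor([[1, 1, 1], [1, 0, 0]])
#   >>> filter_and_truncate(out, 2, mask)
#   tensor([[5, 6]])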


def process_outputs_for_training(
    all_outputs: List[torch.Tensor],
    logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor,
    tokenizer: Any,
    pos_truncation_length: Optional[int],
    neg_truncation_length: Optional[int],
    max_length: int,
    is_cv: bool,
    is_pos: bool,
    torch_device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Process raw model outputs into the format expected by the detector.

    Args:
        all_outputs: sequence of outputs of shape [batch_size, output_len].
        logits_processor: logits processor used for watermarking.
        tokenizer: tokenizer used for the model.
        pos_truncation_length: Length to truncate watermarked (positive) outputs.
        neg_truncation_length: Length to truncate unwatermarked (negative) outputs.
        max_length: Length to pad truncated outputs to so that all processed
            entries have the same shape.
        is_cv: Process given outputs for cross validation.
        is_pos: Process given outputs as positives (watermarked).
        torch_device: torch device to use.

    Returns:
        Tuple of
            all_masks: list of masks of shape [batch_size, max_length].
            all_g_values: list of g_values of shape [batch_size, max_length, depth].
    """
    all_masks = []
    all_g_values = []
    for outputs in tqdm.tqdm(all_outputs):
        # Find which tokens come before the first EOS in each sequence.
        eos_token_mask = logits_processor.compute_eos_token_mask(
            input_ids=outputs,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Positives (and all CV splits) are filtered with the positive
        # truncation length; training negatives use the negative one.
        if is_pos or is_cv:
            outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask)
        else:
            outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask)

        # Skip batches where no sequence survived the length filter.
        if outputs.shape[0] == 0:
            continue

        # Right-pad with EOS so every entry has the same length.
        outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device)

        # Recompute the EOS mask on the padded outputs.
        eos_token_mask = logits_processor.compute_eos_token_mask(
            input_ids=outputs,
            eos_token_id=tokenizer.eos_token_id,
        )

        context_repetition_mask = logits_processor.compute_context_repetition_mask(
            input_ids=outputs,
        )

        # The repetition mask is shorter by the watermarking context size;
        # left-pad with 0 so it lines up with the EOS mask.
        context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device)

        # A position contributes to detection only if it comes before EOS and
        # its watermarking context is not a repeat.
        combined_mask = context_repetition_mask * eos_token_mask

        g_values = logits_processor.compute_g_values(
            input_ids=outputs,
        )

        # g-values are also shorter by the context size; left-pad to align.
        g_values = pad_to_len(g_values, max_length, True, 0, torch_device)

        all_masks.append(combined_mask)
        all_g_values.append(g_values)
    return all_masks, all_g_values
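
# Sketch of a typical call (values are illustrative): process watermarked
# training batches into detector-ready masks and g-values.
#
#   masks, g_values = process_outputs_for_training(
#       wm_batches,  # list of [batch_size, output_len] long tensors
#       logits_processor=logits_processor,
#       tokenizer=tokenizer,
#       pos_truncation_length=200,
#       neg_truncation_length=100,
#       max_length=1000,
#       is_cv=False,
#       is_pos=True,
#       torch_device=torch.device("cpu"),
#   )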


def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> float:
    """Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
    positive_idxs = w_true == 1
    negative_idxs = w_true == 0
    num_samples = detector_inputs[0].size(0)

    # Score all samples in minibatches to bound memory usage.
    w_preds = []
    for start in range(0, num_samples, minibatch_size):
        end = start + minibatch_size
        detector_inputs_ = (
            detector_inputs[0][start:end],
            detector_inputs[1][start:end],
        )
        with torch.no_grad():
            w_pred = detector(*detector_inputs_)[0]
        w_preds.append(w_pred)

    w_pred = torch.cat(w_preds, dim=0)
    positive_scores = w_pred[positive_idxs]
    negative_scores = w_pred[negative_idxs]

    # Pick the score threshold that lets through target_fpr of the negatives,
    # then measure how many positives clear it.
    fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)

    return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item()
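
# Worked example (illustrative): with target_fpr=0.01, the threshold is the
# 99th-percentile negative score, so roughly 1% of negatives would score at or
# above it; the returned value is the fraction of positives that clear that
# same threshold.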


def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
    """Returns negative TPR@FPR=1% so it can be minimized as the validation loss."""
    tpr_ = tpr_at_fpr(
        detector=detector,
        detector_inputs=(g_values_val, mask_val),
        w_true=watermarked_val,
        minibatch_size=minibatch_size,
    )
    return -tpr_


def process_raw_model_outputs(
    logits_processor,
    tokenizer,
    pos_truncation_length,
    neg_truncation_length,
    max_padded_length,
    tokenized_wm_outputs,
    test_size,
    tokenized_uwm_outputs,
    torch_device,
):
    """Splits, processes, packs, and shuffles raw wm/uwm outputs into train and
    CV sets of g-values, masks, and labels for detector training."""
    # Split both the positives and the negatives into train and CV sets.
    train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
    train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)

    process_kwargs = {
        "logits_processor": logits_processor,
        "tokenizer": tokenizer,
        "pos_truncation_length": pos_truncation_length,
        "neg_truncation_length": neg_truncation_length,
        "max_length": max_padded_length,
        "torch_device": torch_device,
    }

    # Process both train and CV splits of the positives and negatives.
    wm_masks_train, wm_g_values_train = process_outputs_for_training(
        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
        is_pos=True,
        is_cv=False,
        **process_kwargs,
    )
    wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
        is_pos=True,
        is_cv=True,
        **process_kwargs,
    )
    uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
        is_pos=False,
        is_cv=False,
        **process_kwargs,
    )
    uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
        is_pos=False,
        is_cv=True,
        **process_kwargs,
    )

    def pack(mask, g_values):
        # Concatenate per-batch lists into single [num_samples, ...] tensors.
        mask = torch.cat(mask, dim=0)
        g = torch.cat(g_values, dim=0)
        return mask, g

    # Labels are 1 for watermarked and 0 for unwatermarked.
    wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
    wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)

    wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
    wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)

    uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
    uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)

    uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
    uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)

    # Concatenate positives and negatives into single train and CV sets.
    train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
    train_labels = torch.cat((wm_labels_train, uwm_labels_train), dim=0).squeeze()
    train_masks = torch.cat((wm_masks_train, uwm_masks_train), dim=0).squeeze()

    cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), dim=0).squeeze()
    cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), dim=0).squeeze()
    cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), dim=0).squeeze()

    # Shuffle so that positives and negatives are interleaved.
    shuffled_idx = torch.randperm(train_g_values.shape[0])
    train_g_values = train_g_values[shuffled_idx]
    train_labels = train_labels[shuffled_idx]
    train_masks = train_masks[shuffled_idx]

    shuffled_idx_cv = torch.randperm(cv_g_values.shape[0])
    cv_g_values = cv_g_values[shuffled_idx_cv]
    cv_labels = cv_labels[shuffled_idx_cv]
    cv_masks = cv_masks[shuffled_idx_cv]

    # Free the intermediate tensors before returning.
    del (
        wm_g_values_train,
        wm_labels_train,
        wm_masks_train,
        wm_g_values_cv,
        wm_labels_cv,
        wm_masks_cv,
        uwm_g_values_train,
        uwm_labels_train,
        uwm_masks_train,
        uwm_g_values_cv,
        uwm_labels_cv,
        uwm_masks_cv,
    )
    gc.collect()
    torch.cuda.empty_cache()

    return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
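
# Typical call (sizes are illustrative); returns six tensors ready for a
# detector training loop:
#
#   train_g, train_m, train_y, cv_g, cv_m, cv_y = process_raw_model_outputs(
#       logits_processor, tokenizer, 200, 100, 1000,
#       tokenized_wm_outputs, 0.3, tokenized_uwm_outputs, torch.device("cpu"),
#   )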


def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
    """Tokenizes Wikipedia text into unwatermarked (negative) examples."""
    dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
    dataset = dataset.take(num_negatives)

    # Convert the dataset so it can be shuffled and batched deterministically.
    df = tfds.as_dataframe(dataset, info)
    ds = tf.data.Dataset.from_tensor_slices(dict(df))
    tf.random.set_seed(0)
    ds = ds.shuffle(buffer_size=10_000)
    ds = ds.batch(batch_size=neg_batch_size)

    tokenized_uwm_outputs = []
    # Pad all examples to this length so they can be stacked later.
    padded_length = 1000
    for batch in tqdm.tqdm(ds):
        responses = [val.decode() for val in batch["text"].numpy()]
        inputs = tokenizer(
            responses,
            return_tensors="pt",
            padding=True,
        ).to(device)
        inputs = inputs["input_ids"].cpu().numpy()
        if inputs.shape[1] >= padded_length:
            inputs = inputs[:, :padded_length]
        else:
            # Right-pad with EOS; use the actual batch size here since the
            # final batch can be smaller than neg_batch_size.
            inputs = np.concatenate(
                [inputs, np.ones((inputs.shape[0], padded_length - inputs.shape[1])) * tokenizer.eos_token_id],
                axis=1,
            )
        tokenized_uwm_outputs.append(inputs)
        if len(tokenized_uwm_outputs) * neg_batch_size >= num_negatives:
            break
    return tokenized_uwm_outputs
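
# Example (illustrative sizes): collect ~10k unwatermarked examples in batches
# of 32. Note that this downloads the Wikipedia split on first use.
#
#   tokenized_uwm = get_tokenized_uwm_outputs(10_000, 32, tokenizer, torch.device("cuda"))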


def get_tokenized_wm_outputs(
    model,
    tokenizer,
    watermark_config,
    num_pos_batches,
    pos_batch_size,
    temperature,
    max_output_len,
    top_k,
    top_p,
    device,
):
    """Generates watermarked (positive) examples from ELI5 prompts."""
    eli5_prompts = datasets.load_dataset("Pavithree/eli5")

    wm_outputs = []

    for batch_id in tqdm.tqdm(range(num_pos_batches)):
        prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
        prompts = [prompt.strip('"') for prompt in prompts]
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).to(device)
        _, inputs_len = inputs["input_ids"].shape

        outputs = model.generate(
            **inputs,
            watermarking_config=watermark_config,
            do_sample=True,
            max_length=inputs_len + max_output_len,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        # Keep only the generated continuation, dropping the prompt tokens.
        wm_outputs.append(outputs[:, inputs_len:].cpu().detach())

        del outputs, inputs, prompts
        gc.collect()

    gc.collect()
    torch.cuda.empty_cache()
    return wm_outputs
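
# Sketch of a typical call. The watermarking keys below are made-up values for
# illustration; real deployments should use their own key material.
#
#   watermark_config = transformers.generation.SynthIDTextWatermarkingConfig(
#       keys=[654, 400, 836, 123, 340, 443, 597, 160, 57, 29],
#       ngram_len=5,
#   )
#   wm_outputs = get_tokenized_wm_outputs(
#       model, tokenizer, watermark_config,
#       num_pos_batches=10, pos_batch_size=32,
#       temperature=0.7, max_output_len=512, top_k=40, top_p=0.99,
#       device=torch.device("cuda"),
#   )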


def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
    """Creates the repo if needed and pushes the model to the Hugging Face Hub."""
    api = HfApi()

    # Check whether the repository already exists.
    try:
        api.repo_info(repo_id=hf_repo_name, token=True)
        print(f"Repository '{hf_repo_name}' already exists.")
    except RepositoryNotFoundError:
        # If the repository does not exist, create it.
        print(f"Repository '{hf_repo_name}' not found. Creating it...")
        create_repo(repo_id=hf_repo_name, private=private, token=True)
        print(f"Repository '{hf_repo_name}' created successfully.")

    # Push the model to the repository.
    print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
    model.push_to_hub(repo_id=hf_repo_name, token=True)
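
# Example (hypothetical repo name); requires being logged in via
# `huggingface-cli login` or an HF_TOKEN in the environment:
#
#   upload_model_to_hf(detector_model, "my-username/synthid-detector")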