"""
Example command with bag of words:
python run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from pplm_classification_head import ClassificationHead
from torch import nn
from tqdm import trange

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.file_utils import cached_path


PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

BAG_OF_WORDS_ARCHIVE_MAP = {
    "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "openai-community/gpt2-medium",
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "openai-community/gpt2-medium",
    },
    # Needed for the --discrim toxicity CLI choice below; the head is hosted
    # alongside the other discriminators.
    "toxicity": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_toxic": 0, "toxic": 1},
        "default_class": 0,
        "pretrained_model": "openai-community/gpt2-medium",
    },
}


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (-1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
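
# Quick illustration (values are hypothetical, shown for intuition only):
#
#   >>> top_k_filter(torch.tensor([[1.0, 2.0, 3.0]]), k=2)
#   tensor([[-1.0000e+10,  2.0000e+00,  3.0000e+00]])
#
# everything below the k-th largest logit is pushed to -BIG_CONST, so it
# contributes nothing after a softmax.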


def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device="cuda",
):
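    """
    Perturbs the model's cached key/value `past` in the direction that
    increases the attribute likelihood (bag-of-words and/or discriminator
    losses), while a KL term keeps the next-token distribution close to the
    unperturbed model. Gradients over `num_iterations` steps are accumulated
    in `grad_accumulator` and added to `past` to produce `pert_past`.
    """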
    # accumulate perturbations over num_iterations gradient steps
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # Build a mask that restricts the perturbation to the most recent
    # `window_length` positions of the past (all positions if window_length == 0)
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])

        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
        for p_ in curr_perturbation:
            p_.retain_grad()

        # Compute hidden states using the perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        lm_output = model(last, past_key_values=perturbed_past)
        all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()

        logits = all_logits[:, -1, :]
        probs = nn.functional.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = nn.CrossEntropyLoss()
            # Unroll `horizon_length` future steps with the *expected* token
            # embedding (probs @ wte) so the classifier loss can look ahead.
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()  # with no argument, returns the input embeddings unchanged
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
                curr_all_logits, curr_unpert_past, curr_all_hidden = (
                    lm_output["logits"],
                    lm_output["past_key_values"],
                    lm_output["hidden_states"],
                )
                curr_logits = curr_all_logits[:, -1, :]
                curr_probs = nn.functional.softmax(curr_logits, dim=-1)
                curr_probs = torch.unsqueeze(curr_probs, dim=1)
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)

            prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))

            label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        # KL term: keep the perturbed next-token distribution close to the
        # unperturbed one (SMALL_CONST guards against log(0))
        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
            print(" kl_loss", kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients and take a step of size `stepsize`
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # remove past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter


def get_classifier(
    name: Optional[str], class_label: Union[str, int], device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path has to be specified in the discriminator model parameters")
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    else:
        label_id = params["default_class"]

    return classifier, label_id
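
# Illustrative use (values follow DISCRIMINATOR_MODELS_PARAMS above):
#
#   classifier, label_id = get_classifier("sentiment", "very_positive", "cpu")
#   # label_id -> 2, from the "sentiment" class_vocab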


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
            filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices
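
# Each bag-of-words file is expected to hold one word per line; entries that
# tokenize to more than one GPT-2 token are dropped later, in
# build_bows_one_hot_vectors.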


def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        # keep only words that encode to a single token
        single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
        single_bow = torch.tensor(single_bow).to(device)
        num_words = single_bow.shape[0]
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors
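
# Each tensor in the returned list has shape (num_words, vocab_size); row j is
# the one-hot vector of bag word j, so `torch.mm(probs, torch.t(one_hot_bow))`
# in perturb_past reads off the probability assigned to each bag word.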


def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    bag_of_words=None,
    discrim=None,
    class_label=None,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
    **kwargs,
):
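    """
    Generates one unperturbed reference text and `num_samples` perturbed
    samples from the same context, and returns the token tensors together
    with the per-sample discriminator losses and per-step loss traces.
    """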
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)

    if bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW
        print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")

    else:
        raise ValueError("Specify either a bag of words or a discriminator")

    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        repetition_penalty=repetition_penalty,
    )
    if device == "cuda":
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
            tokenizer=tokenizer,
            context=context,
            device=device,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            class_label=class_id,
            loss_type=loss_type,
            length=length,
            stepsize=stepsize,
            temperature=temperature,
            top_k=top_k,
            sample=sample,
            num_iterations=num_iterations,
            grad_length=grad_length,
            horizon_length=horizon_length,
            window_length=window_length,
            decay=decay,
            gamma=gamma,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            repetition_penalty=repetition_penalty,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    if device == "cuda":
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
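    """
    Core decoding loop. At every step the unperturbed forward pass supplies
    logits and hidden states; if `perturb` is set, perturb_past shifts the
    cached key/values before the next token is drawn, and the perturbed and
    unperturbed distributions are fused via a weighted geometric mean
    (`gm_scale`) before top-k filtering and sampling or greedy selection.
    """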
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one-hot vectors for the bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):
        # Get past/probs for current output, except for the last token:
        # GPT-2 takes two inputs, the past and the current token

        # run model forward to obtain the unperturbed past
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                past = model(output_so_far[:, :-1])["past_key_values"]

        lm_output = model(output_so_far)
        unpert_logits, unpert_past, unpert_all_hidden = (
            lm_output["logits"],
            lm_output["past_key_values"],
            lm_output["hidden_states"],
        )
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above the max gradient length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        lm_output = model(last, past_key_values=pert_past)
        pert_logits, past = (
            lm_output["logits"],
            lm_output["past_key_values"],
        )
        pert_logits = pert_logits[:, -1, :] / temperature

        # penalize tokens that already appear in the output
        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = nn.functional.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # fuse the modified model and the original model
        if perturb:
            unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)

            # weighted geometric mean: p ~ p_pert^gm_scale * p_unpert^(1 - gm_scale)
            pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale))
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)

            # rescale so the filtered vector is a proper distribution again
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)
            pert_probs = nn.functional.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update output_so_far, appending the new token
        output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time


def set_generic_model_params(discrim_weights, discrim_meta):
    if discrim_weights is None:
        raise ValueError("When using a generic discriminator, discrim_weights needs to be specified")
    if discrim_meta is None:
        raise ValueError("When using a generic discriminator, discrim_meta needs to be specified")

    with open(discrim_meta, "r") as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta["path"] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
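
# A minimal sketch of a discrim_meta JSON file, assuming the keys that
# get_classifier and run_pplm_example read (the values are illustrative):
#
#   {
#       "class_size": 2,
#       "embed_size": 1024,
#       "class_vocab": {"negative": 0, "positive": 1},
#       "default_class": 0,
#       "pretrained_model": "openai-community/gpt2-medium"
#   }
#
# The "path" key is filled in automatically from --discrim_weights.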


def run_pplm_example(
    pretrained_model="openai-community/gpt2-medium",
    cond_text="",
    uncond=False,
    num_samples=1,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    class_label=-1,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    seed=0,
    no_cuda=False,
    colorama=False,
    repetition_penalty=1.0,
):
    # set random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == "generic":
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts; full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )

    # untokenize the unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    bow_word_ids = set()
    if bag_of_words and colorama:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
        for single_bow_list in bow_indices:
            # filter out words composed of more than one token
            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
            # each remaining entry is a single-token encoding, so take its only id
            bow_word_ids.update(w[0] for w in filtered)

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # decode (and optionally colorize) the perturbed text
            if colorama:
                import colorama

                pert_gen_text = ""
                for word_id in pert_gen_tok_text.tolist()[0]:
                    if word_id in bow_word_ids:
                        pert_gen_text += "{}{}{}".format(
                            colorama.Fore.RED,
                            tokenizer.decode([word_id]),
                            colorama.Style.RESET_ALL,
                        )
                    else:
                        pert_gen_text += tokenizer.decode([word_id])
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])

            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except Exception as exc:
            print("Ignoring error while generating perturbed text:", exc)

        # keep the prefix, perturbed sequence, and original sequence for each sample
        generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))

    return
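
# Programmatic use, equivalent to the CLI (argument values here are just
# illustrative; all parameters mirror the argparse flags below):
#
#   run_pplm_example(
#       cond_text="The potato",
#       bag_of_words="military",
#       length=50,
#       stepsize=0.03,
#       sample=True,
#       num_iterations=3,
#       window_length=5,
#       gamma=1.5,
#       gm_scale=0.95,
#       kl_scale=0.01,
#   )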


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        "-M",
        type=str,
        default="openai-community/gpt2-medium",
        help="pretrained model name or path to local checkpoint",
    )
    parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
    parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of samples to generate from the modified latents",
    )
    parser.add_argument(
        "--bag_of_words",
        "-B",
        type=str,
        default=None,
        help=(
            "Bags of words used for PPLM-BoW. "
            "Either a BOW id (see list in code) or a filepath. "
            "Multiple BoWs separated by ;"
        ),
    )
    parser.add_argument(
        "--discrim",
        "-D",
        type=str,
        default=None,
        choices=("clickbait", "sentiment", "toxicity", "generic"),
        help="Discriminator to use",
    )
    parser.add_argument(
        "--discrim_weights",
        type=str,
        default=None,
        help="Weights for the generic discriminator",
    )
    parser.add_argument(
        "--discrim_meta",
        type=str,
        default=None,
        help="Meta information for the generic discriminator",
    )
    parser.add_argument(
        "--class_label",
        type=int,
        default=-1,
        help="Class label used for the discriminator",
    )
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument(
        "--sample",
        action="store_true",
        help="Sample from the distribution instead of taking the argmax (greedy decoding)",
    )
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument(
        "--window_length",
        type=int,
        default=0,
        help="Length of past which is being optimized; 0 corresponds to infinite window length",
    )
    parser.add_argument(
        "--horizon_length",
        type=int,
        default=1,
        help="Length of future to optimize over",
    )
    parser.add_argument("--decay", action="store_true", help="whether to decay or not")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument("--colorama", action="store_true", help="colors keywords")
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Penalize repetition. More than 1.0 -> less repetition",
    )

    args = parser.parse_args()
    run_pplm_example(**vars(args))