# Hugging Face Space app (status-page header removed during cleanup)
# force update
import argparse
import nltk
import torch
import numpy as np
import gradio as gr
from nltk import sent_tokenize
from transformers import (
    RobertaTokenizer,
    RobertaForMaskedLM,
    LogitsProcessorList,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    TypicalLogitsWarper,
)

# Sentence-boundary model required by sent_tokenize() below.
nltk.download('punkt')

# Use the larger checkpoint only when a GPU is available; the small one
# keeps CPU-only Spaces responsive.
device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained = "roberta-large" if device == "cuda" else "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(pretrained)
model = RobertaForMaskedLM.from_pretrained(pretrained)
model = model.to(device)

# Default generation hyperparameters; also used as UI slider defaults.
max_len = 20      # number of mask tokens inserted between sentences
top_k = 100       # top-k filtering applied after burn-in
temperature = 1   # softmax temperature for sampling
typical_p = 0     # 0 disables typical sampling
burnin = 250      # iterations that sample from the full distribution
max_iter = 500    # total MCMC iterations per inserted span
# adapted from https://github.com/nyu-dl/bert-gen
def generate_step(out: object,
                  gen_idx: int,
                  top_k: int = top_k,
                  temperature: float = temperature,
                  typical_p: float = typical_p,
                  sample: bool = False) -> list:
    """Generate one token id per batch element at position ``gen_idx``.

    Args:
        out: masked-LM model output exposing a ``logits`` attribute of
            shape (batch_size, seq_len, vocab_size)
        gen_idx: sequence position for which to generate
        top_k: if > 0, restrict to the k most probable tokens
        temperature: sampling temperature; a falsy value skips rescaling
        typical_p: if > 0, apply typical sampling (clamped below 1.0)
        sample: if True, sample from the (possibly warped) distribution;
            otherwise take the argmax

    Returns:
        list of batch_size token ids
    """
    logits = out.logits[:, gen_idx]
    warpers = LogitsProcessorList()
    if temperature:
        # Cast to float: the module default is the int 1, and some
        # transformers versions reject a non-float temperature.
        warpers.append(TemperatureLogitsWarper(float(temperature)))
    if top_k > 0:
        warpers.append(TopKLogitsWarper(top_k))
    if typical_p > 0:
        # TypicalLogitsWarper requires mass strictly below 1.
        if typical_p >= 1:
            typical_p = 0.999
        warpers.append(TypicalLogitsWarper(typical_p))
    # These warpers ignore input_ids, so passing None is safe here.
    logits = warpers(None, logits)
    if sample:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        next_tokens = torch.argmax(logits, dim=-1)
    return next_tokens.tolist()
# adapted from https://github.com/nyu-dl/bert-gen
def parallel_sequential_generation(seed_text: str,
                                   seed_end_text: str,
                                   max_len: int = max_len,
                                   top_k: int = top_k,
                                   temperature: float = temperature,
                                   typical_p: float = typical_p,
                                   max_iter: int = max_iter,
                                   burnin: int = burnin) -> str:
    """Generate text consistent with preceding and following text.

    MCMC over the masked span: repeatedly pick a random masked position
    and resample it conditioned on the current state of every other token.

    Args:
        seed_text (str): preceding text
        seed_end_text (str): following text
        max_len (int): number of mask tokens to fill in
        top_k (int): if > 0, only sample from the top k most probable words
        temperature (float): sampling temperature
        typical_p (float): if > 0 use typical sampling
        max_iter (int): number of iterations in MCMC
        burnin (int): during burn-in period, sample from full distribution;
            afterwards take argmax
    Returns:
        str: generated text to insert between seed_text and seed_end_text
    """
    # Layout: seed_text + max_len mask tokens + seed_end_text.
    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    # Indices of the masked span (contiguous by construction).
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    seed_len = masked_tokens[0]  # position of the first mask token
    inp = inp.to(device)
    for ii in range(max_iter):
        # Choose one position in the span uniformly at random to resample.
        kk = np.random.randint(0, max_len)
        idxs = generate_step(model(**inp),
                             gen_idx=seed_len + kk,
                             # Burn-in: full-distribution sampling;
                             # afterwards: greedy argmax over top-k.
                             top_k=top_k if (ii >= burnin) else 0,
                             temperature=temperature,
                             typical_p=typical_p,
                             sample=(ii < burnin))
        inp['input_ids'][0][seed_len + kk] = idxs[0]
    # Decode only the generated span, dropping any sequence-boundary
    # tokens the sampler may have produced.
    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    tokens = tokens[(np.where((tokens != tokenizer.eos_token_id)
                              & (tokens != tokenizer.bos_token_id)))]
    return tokenizer.decode(tokens)
def inbertolate(doc: str,
                max_len: int = max_len,
                top_k: int = top_k,
                temperature: float = temperature,
                typical_p: float = typical_p,
                max_iter: int = max_iter,
                burnin: int = burnin) -> str:
    """Pad out a document by generating text after every sentence.

    Args:
        doc (str): document text
        max_len (int): number of tokens to insert between sentences
        top_k (int): if > 0, only sample from the top k most probable words
        temperature (float): sampling temperature
        typical_p (float): if > 0 use typical sampling
        max_iter (int): number of iterations in MCMC
        burnin (int): during burn-in period, sample from full distribution;
            afterwards take argmax
    Returns:
        str: the expanded document
    """
    new_doc = ''
    for para in doc.split('\n'):
        sentences = sent_tokenize(para)
        # Bug fix: the original compared the token *list* to '' (always
        # False). Check for an empty paragraph via the empty list instead.
        if not sentences:
            new_doc += '\n'
            continue
        # Sentinel empty "sentence" so the last real sentence also gets a
        # generated continuation (with empty following context).
        sentences.append('')
        for i in range(len(sentences) - 1):
            new_doc += sentences[i] + ' '
            new_doc += parallel_sequential_generation(
                sentences[i],
                sentences[i + 1],
                max_len=max_len,
                top_k=top_k,
                temperature=float(temperature),
                typical_p=typical_p,
                burnin=burnin,
                max_iter=max_iter) + ' '
        new_doc += '\n'
    return new_doc
# Gradio UI: one control per inbertolate() argument, named individually
# and passed in the positional order the function expects.
text_box = gr.Textbox(label="Text", lines=10)
length_slider = gr.Slider(label="Maximum length to insert between sentences",
                          minimum=1,
                          maximum=40,
                          step=1,
                          value=max_len)
top_k_slider = gr.Slider(label="Top k", minimum=0, maximum=200, value=top_k)
temperature_slider = gr.Slider(label="Temperature",
                               minimum=0,
                               maximum=2,
                               value=temperature)
typical_p_slider = gr.Slider(label="Typical p",
                             minimum=0,
                             maximum=1,
                             value=typical_p)
max_iter_slider = gr.Slider(label="Maximum iterations",
                            minimum=0,
                            maximum=1000,
                            value=max_iter)
burnin_slider = gr.Slider(label="Burn-in",
                          minimum=0,
                          maximum=500,
                          value=burnin)
demo = gr.Interface(
    fn=inbertolate,
    title="inBERTolate",
    description=f"Hit your word count by using BERT ({pretrained}) to pad out your essays!",
    inputs=[
        text_box,
        length_slider,
        top_k_slider,
        temperature_slider,
        typical_p_slider,
        max_iter_slider,
        burnin_slider,
    ],
    outputs=gr.Textbox(label="Expanded text", lines=30))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int,
                        help='port to serve on (Gradio default if omitted)')
    # Bug fix: server_name is a hostname/IP *string* (e.g. "0.0.0.0");
    # the original type=int would reject any valid value.
    parser.add_argument('--server', type=str,
                        help='interface to bind, e.g. 0.0.0.0')
    args = parser.parse_args()
    demo.launch(server_name=args.server or '0.0.0.0', server_port=args.port)