Spaces:
Build error
Build error
| import csv | |
| from dataclasses import dataclass | |
| import io | |
| import json | |
| import logging | |
| import random | |
| import sys | |
| from typing import Dict, List | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| import transformers | |
| from tqdm import tqdm | |
| from autoprompt import utils | |
| import autoprompt.create_trigger as ct | |
| # logging.getLogger("streamlit.caching").addHandler(logging.StreamHandler(sys.stdout)) | |
| # logging.getLogger("streamlit.caching").setLevel(logging.DEBUG) | |
logger = logging.getLogger(__name__)

# Default training instances, one JSON object per line. manual_dataset() reads
# the 'sentence' and 'label' keys of these records to pre-fill its inputs.
# The encoding is pinned so the result does not depend on the platform locale,
# and blank lines (e.g. a trailing newline) are skipped instead of crashing
# json.loads.
with open('assets/sst2_train.jsonl', 'r', encoding='utf-8') as f:
    DEFAULT_TRAIN = [json.loads(line) for line in f if line.strip()]
@dataclass
class CacheTest:
    """
    Stores whether the train button has been pressed for a given
    set of inputs to run_autoprompt.

    ``is_test`` is True when the caller only wants to probe for an existing
    result: run_autoprompt raises CacheMiss instead of training in that case.
    It is False when training has actually been requested.
    """
    # The dataclass decorator supplies the __init__ that call sites such as
    # CacheTest(False) rely on.
    is_test: bool
class CacheMiss(Exception):
    """Raised by run_autoprompt when it is called with ``cache_test.is_test`` set."""
def css_hack():
    """
    Restyle the app by rendering a raw <style> element. ヽ(⌐■_■)ノ
    """
    # Colours for inline code and for links inside the `.css-gtmd9c` container.
    stylesheet = """
<style>
code {
color: #eec66d;
}
.css-gtmd9c a {
color: #6f98af;
}
</style>
"""
    # unsafe_allow_html is required, otherwise the tag is shown as escaped text.
    st.markdown(stylesheet, unsafe_allow_html=True)
# Setting eq and frozen ensures that a __hash__ method is generated which is needed for caching to
# properly respond to changed args.
@dataclass(eq=True, frozen=True)
class Args:
    """
    Settings for one AutoPrompt run.

    The annotated attributes are dataclass fields and are set from the sidebar
    (see ``from_streamlit``). The un-annotated attributes below them are plain
    class attributes: they are shared by every instance, are not accepted by
    the constructor, and take no part in equality or hashing.
    """
    # Configurable
    template: str
    model_name: str
    iters: int
    num_cand: int
    accumulation_steps: int

    # Non-Configurable
    seed = 0
    sentence_size = 64
    tokenize_labels = True
    filter = False
    initial_trigger = None
    label_field = "label"
    bsz = 32
    eval_size = 1

    @classmethod
    def from_streamlit(cls):
        """
        Render the sidebar widgets and build an ``Args`` from their current values.

        Returns:
            A new instance whose configurable fields hold the widget values.
        """
        st.sidebar.image('assets/icon.png', width=150)
        st.sidebar.markdown('### Training Parameters')
        model_name = st.sidebar.selectbox(
            "Model",
            options=['roberta-large', 'bert-base-cased'],
            help="Language model used for training and evaluation."
        )
        # number_input results are passed through int() so the fields are
        # always integers.
        iters = int(st.sidebar.number_input(
            "Iterations",
            value=10,
            min_value=1,
            max_value=100,
            help="Number of trigger search iterations. Larger values may yield better results."
        ))
        num_cand = int(st.sidebar.number_input(
            "Number of Candidates",
            value=25,
            min_value=1,
            max_value=100,
            help="Number of candidate trigger token replacements to evaluate during each search "
                 "iteration. Larger values may yield better results."
        ))
        accumulation_steps = int(st.sidebar.number_input(
            "Gradient Accumulation Steps",
            value=1,
            min_value=1,
            max_value=10,
            help="Number of gradient accumulation steps used during training. Larger values may yield "
                 "better results. Cannot be larger than half the dataset size."
        ))
        st.sidebar.markdown(
            """
### Template
Templates define how task-specific inputs are combined with trigger tokens to create
the prompt. They should contain the following placeholders:
- `{sentence}`: Placeholders for the task-specific input fields contain the field name
between curly brackets. For manually entered data the field name is `{sentence}`. For
uploaded csv's, field names should correspond to columns in the csv.
- `[T]`: Placeholder for a trigger token. These are learned from the training data.
- `[P]`: Placeholder for where to insert the [MASK] token that the model will predict
on.
Templates can also include manually written text (such as the
period in the default example below).
"""
        )
        template = st.sidebar.text_input("Template", "{sentence} [T] [T] [T] [P].")
        return cls(
            template=template,
            model_name=model_name,
            iters=iters,
            num_cand=num_cand,
            accumulation_steps=accumulation_steps,
        )
| # TODO(rloganiv): This probably could use a better name... | |
@dataclass
class GlobalData:
    """
    Bundle of the loaded model and the helper objects built around it.

    ``from_pretrained`` constructs instances positionally, so the field order
    below must match the argument order used there.
    """
    device: torch.device
    config: transformers.PretrainedConfig
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    embeddings: torch.nn.Module
    embedding_gradient: ct.GradientStorage
    predictor: ct.PredictWrapper

    @classmethod
    def from_pretrained(cls, model_name):
        """
        Load ``model_name`` and assemble a ``GlobalData`` around it.

        Args:
            model_name: Identifier passed to ``ct.load_pretrained``.

        Returns:
            A populated ``GlobalData`` whose model has been moved to CUDA when
            available, otherwise to the CPU.
        """
        logger.info(f'Loading pretrained model: {model_name}')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        config, model, tokenizer = ct.load_pretrained(model_name)
        model.to(device)
        embeddings = ct.get_embeddings(model, config)
        embedding_gradient = ct.GradientStorage(embeddings)
        predictor = ct.PredictWrapper(model)
        return cls(
            device,
            config,
            model,
            tokenizer,
            embeddings,
            embedding_gradient,
            predictor
        )
@dataclass
class Dataset:
    """
    Training data handed to run_autoprompt.
    """
    # One dict per training instance, mapping field name -> text. Both
    # dataset builders in this file produce records that include a 'label' key.
    train: List[Dict[str, str]]
    # Maps each label to itself in both dataset builders in this file.
    label_map: Dict[str, str]
def load_trigger_dataset(dataset, templatizer):
    """Return a list holding ``templatizer(record)`` for every record in ``dataset``, in order."""
    return [templatizer(record) for record in dataset]
# NOTE(review): run() probes this function with CacheTest(True) and expects a
# previously computed result to come back instead of CacheMiss. That can only
# work if this function is wrapped by a caching decorator that ignores
# `cache_test`; no decorator is visible in this view of the file -- confirm it
# is present in the real source.
def run_autoprompt(args, dataset, cache_test):
    """
    Search for trigger tokens on ``dataset.train`` and evaluate the result.

    Args:
        args: Run settings. Reads ``seed``, ``model_name``, ``template``,
            ``label_field``, ``tokenize_labels``, ``bsz``, ``iters``,
            ``accumulation_steps`` and ``num_cand``.
        dataset: Object exposing ``train`` (training records) and ``label_map``.
        cache_test: If ``cache_test.is_test`` is true, nothing is computed and
            ``CacheMiss`` is raised.

    Returns:
        The tuple ``(best_trigger_tokens, train_metric, label_map, templatizer,
        best_trigger_ids, tokenizer, predictor, args, train_output)``.
        ``train_metric`` comes from the last search iteration and is measured
        with the trigger as it stood *before* that iteration's replacement.

    Raises:
        CacheMiss: When called with ``cache_test.is_test`` set.
        RuntimeError: If the template names a field missing from the dataset,
            or if the dataset contains no instances.
    """
    if cache_test.is_test:
        raise CacheMiss()

    ct.set_seed(args.seed)
    global_data = GlobalData.from_pretrained(args.model_name)

    templatizer = utils.TriggerTemplatizer(
        args.template,
        global_data.config,
        global_data.tokenizer,
        label_field=args.label_field,
        label_map=dataset.label_map,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=True,
    )
    evaluation_fn = ct.AccuracyFn(global_data.tokenizer, dataset.label_map, global_data.device,
                                  tokenize_labels=args.tokenize_labels)

    # Do not allow for initial trigger specification.
    trigger_ids = [global_data.tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=global_data.device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # Load datasets
    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
    try:
        train_dataset = load_trigger_dataset(dataset.train, templatizer)
    except KeyError as e:
        # Chain the original error so the traceback still shows where the
        # lookup failed.
        raise RuntimeError(
            'A field in your template is not present in the uploaded dataset. '
            f'Check that there is a column with the name: {e}'
        ) from e
    if not train_dataset:
        # Without at least one batch there is no gradient to search with and
        # no examples to score against.
        raise RuntimeError('The training dataset is empty.')
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    progress = st.progress(0.0)
    trigger_placeholder = st.empty()
    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        progress.progress(float(i)/args.iters)
        current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
        trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

        global_data.model.zero_grad()
        train_iter = iter(train_loader)
        averaged_grad = None

        # Compute gradient of loss
        for _ in range(args.accumulation_steps):
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                # Only loader exhaustion is expected here; any other error
                # should propagate rather than be reported as missing data.
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
            labels = labels.to(global_data.device)
            predict_logits = global_data.predictor(model_inputs, trigger_ids)
            loss = ct.get_loss(predict_logits, labels).mean()
            loss.backward()

            # Keep only the gradient entries selected by the trigger mask and
            # regroup them per example and per trigger token.
            grad = global_data.embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        # Cycle through the trigger positions, one per iteration.
        token_to_flip = i % templatizer.num_trigger_tokens
        candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
                                       global_data.embeddings.weight,
                                       increase_loss=False,
                                       num_candidates=args.num_cand)
        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=global_data.device)
        denom = 0
        for _ in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Instead of iterating over tokens to flip we randomly change just one each
            # time so the gradients don't get stale.
            # The index has its own name so it cannot clobber the outer
            # iteration counter `i`.
            for cand_idx, candidate in enumerate(candidates):
                # if candidate.item() in filter_candidates:
                #     candidate_scores[cand_idx] = -1e32
                #     continue
                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = global_data.predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)
                candidate_scores[cand_idx] += eval_metric.sum()

        # Ties count as an improvement, so the trigger may change even when
        # the score does not.
        if (candidate_scores >= current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')

        # Skip eval
        best_trigger_ids = trigger_ids.clone()

    progress.progress(1.0)
    current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
    trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

    best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))

    # NOTE(review): this reads the 'sentence' key of every training record, so
    # a dataset without that field raises KeyError here -- confirm whether
    # other field names are meant to be supported.
    train_output = predict_test((x['sentence'] for x in dataset.train), dataset.label_map,
                                templatizer, best_trigger_ids, global_data.tokenizer,
                                global_data.predictor, args)

    # Streamlit does not like accessing widgets across functions, which is
    # problematic for this "live updating" widget which we want to still
    # display even if the train output is cached. To get around this, we're
    # going to delete the widget and replace it with a very similar looking
    # widget outside the function...no one will ever notice ;)
    trigger_placeholder.empty()

    return (
        best_trigger_tokens,
        current_score/denom,
        dataset.label_map,
        templatizer,
        best_trigger_ids,
        global_data.tokenizer,
        global_data.predictor,
        args,
        train_output
    )
def predict_test(sentences, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args):
    """
    Score every sentence against every label using the given trigger.

    Args:
        sentences: Iterable of input sentences.
        label_map: Mapping whose values are the labels to score.
        templatizer: Callable turning a record dict into ``(model_inputs, label)``.
        best_trigger_ids: Trigger token ids; its device is used for all tensors.
        tokenizer: Tokenizer used to render the prompt and encode labels.
        predictor: Callable ``(model_inputs, trigger_ids) -> logits``.
        args: Reads ``label_field`` and ``tokenize_labels``.

    Returns:
        A dict of equal-length lists: ``'sentences'``, one entry per label
        holding that label's loss for each sentence, and ``'prompt'`` holding
        the prompt tokens joined with spaces.
    """
    device = best_trigger_ids.device

    # Evaluate clean
    output = {'sentences': []}
    any_label = None
    for label in label_map.values():
        output[label] = []
        any_label = label
    output['prompt'] = []

    for sentence in sentences:
        # The templatizer is given a label under the field name it was
        # configured with; the label it returns is discarded, so an arbitrary
        # one from label_map is passed.
        model_inputs, _ = templatizer({'sentence': sentence, args.label_field: any_label})
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

        prompt_ids = ct.replace_trigger_tokens(
            model_inputs, best_trigger_ids, model_inputs['trigger_mask'])
        prompt = ' '.join(tokenizer.convert_ids_to_tokens(prompt_ids['input_ids'][0]))
        output['prompt'].append(prompt)
        output['sentences'].append(sentence)

        # Pure inference: no gradients are needed, matching how candidates are
        # evaluated during the trigger search.
        with torch.no_grad():
            predict_logits = predictor(model_inputs, best_trigger_ids)
            for label in label_map.values():
                label_id = utils.encode_label(tokenizer=tokenizer, label=label, tokenize=args.tokenize_labels)
                label_id = label_id.to(device)
                label_loss = ct.get_loss(predict_logits, label_id)
                # st.write(sentence, label, label_loss)
                output[label].append(label_loss.item())
    return output
def manual_dataset(use_defaults):
    """
    Collect training instances from text inputs rendered in the page.

    Args:
        use_defaults: When true, inputs are pre-filled from ``DEFAULT_TRAIN``;
            otherwise they start empty.

    Returns:
        A ``Dataset`` whose records have ``'sentence'`` and ``'label'`` keys and
        whose ``label_map`` maps every entered label to itself.

    The script run is stopped with a warning (``st.stop``) while any input is
    empty or while fewer than two distinct labels have been entered.
    """
    num_train_instances = st.slider("Number of Train Instances", 4, 32, 8)
    any_empty = False
    dataset = []
    data_col, label_col = st.columns([3,1])
    for i in range(num_train_instances):
        # Fall back to empty inputs if there are fewer defaults than requested
        # rows, rather than failing with IndexError.
        if use_defaults and i < len(DEFAULT_TRAIN):
            default_data = DEFAULT_TRAIN[i]['sentence']
            default_label = DEFAULT_TRAIN[i]['label']
        else:
            default_data = ''
            default_label = ''
        with data_col:
            data = st.text_input("Train Instance " + str(i+1), default_data)
        with label_col:
            label = st.text_input("Train Label " + str(i+1), default_label, max_chars=20)
        if data == "" or label == "":
            any_empty = True
        dataset.append({'sentence': data, 'label': label})

    # Built in order of first appearance so the label order is the same on
    # every run (iterating a set of strings is not).
    label_map = {example['label']: example['label'] for example in dataset}

    if any_empty:
        st.warning('Waiting for data to be added')
        st.stop()

    if len(label_map) < 2:
        st.warning('Not enough labels')
        st.stop()

    return Dataset(
        train=dataset,
        label_map=label_map
    )
def csv_dataset():
    """
    Build a ``Dataset`` from an uploaded csv file.

    Returns:
        A ``Dataset`` holding one dict per csv row (keyed by the header names)
        and a ``label_map`` mapping every value of the ``label`` column to
        itself.

    Raises:
        ValueError: If the file has more than 64 rows or no ``label`` column.

    The script run is stopped (``st.stop``) while no file has been uploaded,
    or with a warning when fewer than two distinct labels are present.
    """
    st.markdown("""
Please upload your training and evaluation csv files.
Format restrictions:
- The file is required to have a header
- The column name of the output field should be `label`.
- Each file should contain no more than 64 rows.
""")
    train_csv = st.file_uploader('Train', accept_multiple_files=False)
    if train_csv is None:
        st.stop()

    # 'utf-8-sig' drops a leading byte-order mark, which would otherwise become
    # part of the first column name. newline='' leaves line-ending handling to
    # the csv module, as its documentation recommends.
    with io.StringIO(train_csv.getvalue().decode('utf-8-sig'), newline='') as f:
        reader = csv.DictReader(f)
        train_dataset = list(reader)
        fieldnames = reader.fieldnames or []

    if len(train_dataset) > 64:
        raise ValueError('Train dataset is too large. Please limit the number '
                         'of examples to 64 or less.')
    if 'label' not in fieldnames:
        # Report the problem in terms of the file format instead of letting a
        # bare KeyError escape from the label lookup below.
        raise ValueError('The uploaded csv must have a header containing a '
                         'column named "label".')

    # Built in order of first appearance so the label order is the same on
    # every run (iterating a set of strings is not).
    label_map = {row['label']: row['label'] for row in train_dataset}

    # Same minimum as the manual input path.
    if len(label_map) < 2:
        st.warning('Not enough labels')
        st.stop()

    return Dataset(
        train=train_dataset,
        label_map=label_map
    )
def run():
    """
    Render the whole demo page: introduction, data entry, training and a
    try-it-yourself prediction box.
    """
    def loss_table(columns):
        # Tabulate per-label values and highlight the smallest one in each row.
        return pd.DataFrame(columns).style.highlight_min(axis=1, color='#94666b')

    css_hack()
    st.title('AutoPrompt Demo')
    st.markdown('''
For many years, the predominant approach for training machine learning
models to solve NLP tasks has been to use supervised training data to
estimate model parameters using maximum likelihood estimation or some
similar paradigm. Whether fitting a logistic regression model over a
bag-of-words, an LSTM over a sequence of GloVe embeddings, or finetuning a
language model such as ELMo or BERT, the approach is essentially the same.
However, as language models have become more and more capable of accurately
generating plausible text a new possibility for solving classification
tasks has emerged...
## Prompting
Prompting is the method of converting classification tasks into
*fill-in-the-blanks* problems that can be solved by a language model **without
modifying the model's internals**. For example, to perform sentiment analysis,
we may take the sentence we wish to classify and append the text "Overall, this
movie was ____." and feed it into a language model like so:
''')
    st.image('assets/bert-mouth.png', use_column_width=True)
    st.markdown('''
By measuring whether the language model assigns a higher probability to
words that are associated with a **positive** sentiment ("good", "great",
and "fantastic") vs. words that are associated with a **negative**
sentiment ("bad", "terrible", or "awful") we can infer the
predicted label for the given input. So in this example, because the word "good"
has a higher probability than "bad", the predicted label is **positive**.
## AutoPrompt
One issue that arises when using prompts is that it is not usually clear
how to best pose a task as a fill-in-the-blanks problem in a way that gets
the most performance from the language model. Even for a simple problem
like sentiment analysis, we don't know whether it is better to ask whether
a movie is good/bad, or whether you feel great/terrible about it, and for
more abstract problems like natural language inference it is difficult to
even know where to start.
To cure this writer's block we introduce **AutoPrompt**, a data-driven
approach for automatic prompt construction. The basic idea is
straightfoward: instead of writing a prompt, a user need only write a
**template** that specfies where the *task inputs* go along with placeholders for
a number of *trigger tokens* that will automatically be learned by the
model and the *predict token* that the model will fill in:
''')
    st.image('assets/template.png', use_column_width=True)
    st.markdown(
        '''
In each iteration of the search process:
1. The template is instantiated using a batch of training inputs.
2. The loss of the model on each input is measured and used to identify a
number of candidate replacements for the current trigger tokens.
3. The performance of each candidate is measured on another batch of
training data, and the best performing candidate is used in the next
iteration.
### Demo
To give a better sense of how AutoPrompt works, we have provided a simple
interactive demo. You can generate a prompt using the training data we have
pre-populated for you, or alternatively write your own training/evaluation
instances or upload them using a csv below. In addition, you can vary
some of the training parameters, as well as the template using the sidebar
on the left.
'''
    )

    args = Args.from_streamlit()

    dataset_mode = st.radio('How would you like to input your training data?',
                            options=['Example Data', 'Manual Input', 'From CSV'])
    # Both hand-entered modes share one builder and differ only in whether
    # the inputs are pre-filled; anything else means a csv upload.
    prefill_by_mode = {'Example Data': True, 'Manual Input': False}
    if dataset_mode in prefill_by_mode:
        dataset = manual_dataset(use_defaults=prefill_by_mode[dataset_mode])
    else:
        dataset = csv_dataset()

    button = st.empty()
    clicked = button.button('Train')

    if clicked:
        results = run_autoprompt(args, dataset, cache_test=CacheTest(False))
    else:
        # Probe only: CacheMiss means there is nothing to show yet, so the
        # page ends here with the Train button still visible.
        try:
            results = run_autoprompt(args, dataset, cache_test=CacheTest(True))
        except CacheMiss:
            st.stop()
        else:
            button.empty()
    (trigger_tokens, eval_metric, label_map, templatizer, best_trigger_ids,
     tokenizer, predictor, args, train_output) = results

    final_trigger = ", ".join(trigger_tokens)
    st.markdown(f'**Final trigger**: {final_trigger}')
    st.dataframe(loss_table(train_output))
    logger.debug('Dev metric')
    accuracy_pct = round(eval_metric.item()*100, 1)
    st.write('Accuracy: ' + str(accuracy_pct))
    st.write("""
Et voila, you've now effectively finetuned a classifier using just a few
kilobytes of parameters (the tokens in the prompt). If you like you can
write down your "model" on the back of a napkin and take it with you.
### Try it out yourself!
""")
    sentence = st.text_input("Sentence", 'Enter a test input here')
    pred_output = predict_test([sentence], label_map, templatizer, best_trigger_ids,
                               tokenizer, predictor, args)
    st.dataframe(loss_table(pred_output))
    st.markdown('''
## Where can I learn more?
If you are interested in learning more about AutoPrompt we recommend
[reading our paper](https://arxiv.org/abs/2010.15980) and [checking out our
code](https://github.com/ucinlp/autoprompt), or if you'd like you can also
watch our presentation at EMNLP 2020:
''')
    st.components.v1.iframe(
        src="https://www.youtube.com/embed/IBMT_oOCBbc",
        height=400,
    )
    st.markdown('Thanks!')
if __name__ == '__main__':
    # Configure root logging at INFO level on stdout, then render the app.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    run()