Delete hellaswag_eval.py
Browse files- hellaswag_eval.py +0 -197
hellaswag_eval.py
DELETED
|
@@ -1,197 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Downloads and evaluates HellaSwag in Python.
|
| 3 |
-
https://github.com/rowanz/hellaswag
|
| 4 |
-
|
| 5 |
-
Example HellaSwag json item:
|
| 6 |
-
|
| 7 |
-
{"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}
|
| 8 |
-
|
| 9 |
-
ind: dataset ID
|
| 10 |
-
activity_label: The ActivityNet or WikiHow label for this example
|
| 11 |
-
context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
|
| 12 |
-
endings: a list of 4 endings. The correct index is given by label (0,1,2, or 3)
|
| 13 |
-
split: train, val, or test.
|
| 14 |
-
split_type: indomain if the activity label is seen during training, else zeroshot
|
| 15 |
-
source_id: Which video or WikiHow article this example came from
|
| 16 |
-
|
| 17 |
-
gpt2 (124M)
|
| 18 |
-
- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
|
| 19 |
-
- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)
|
| 20 |
-
|
| 21 |
-
gpt2-xl (1558M)
|
| 22 |
-
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
|
| 23 |
-
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)
|
| 24 |
-
|
| 25 |
-
The validation set of HellaSwag has a total of 10,042 examples.
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
import os
|
| 29 |
-
import json
|
| 30 |
-
import requests
|
| 31 |
-
import tiktoken
|
| 32 |
-
from tqdm import tqdm
|
| 33 |
-
import torch
|
| 34 |
-
import torch.nn as nn
|
| 35 |
-
from torch.nn import functional as F
|
| 36 |
-
from transformers import GPT2LMHeadModel
|
| 37 |
-
|
| 38 |
-
# -----------------------------------------------------------------------------
|
| 39 |
-
# Local cache directory for the HellaSwag files: a "hellaswag" folder located
# next to this script.
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")
|
| 40 |
-
|
| 41 |
-
def download_file(url: str, fname: str, chunk_size=1024):
    """Download ``url`` to the local path ``fname``, showing a tqdm progress bar.

    Args:
        url: Address of the file to fetch.
        fname: Destination path. It is only created once the whole body has
            been received.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    # Stream into a sibling temp file and rename at the very end, so that a
    # failed or interrupted transfer never leaves a truncated file under the
    # final name (an existing file is treated as a finished download).
    tmp_fname = fname + ".tmp"
    try:
        # The timeout bounds connect/read stalls, not the total transfer time.
        with requests.get(url, stream=True, timeout=60) as resp:
            # Without this check an HTML error page would be saved as data.
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            with open(tmp_fname, "wb") as file, tqdm(
                desc=fname,
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    bar.update(size)
        os.replace(tmp_fname, fname)
    finally:
        # Only still present when something above failed.
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)
|
| 55 |
-
|
| 56 |
-
# Download URL of the JSONL file for each HellaSwag split, keyed by split name.
hellaswags = dict(
    train="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    val="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    test="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
)
|
| 61 |
-
|
| 62 |
-
# Module-wide tiktoken encoder for the "gpt2" encoding; render_example uses it
# to tokenize contexts and endings.
enc = tiktoken.get_encoding("gpt2")
|
| 63 |
-
|
| 64 |
-
def download(split):
    """Fetch the HellaSwag ``split`` file into DATA_CACHE_DIR unless already present.

    ``split`` must be a key of ``hellaswags``; the file is stored as
    ``hellaswag_<split>.jsonl`` inside the cache directory.
    """
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    source_url = hellaswags[split]
    target_path = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    if os.path.exists(target_path):
        # Already cached from an earlier run; nothing to do.
        return
    print(f"Downloading {source_url} to {target_path}...")
    download_file(source_url, target_path)
|
| 72 |
-
|
| 73 |
-
def render_example(example):
    """Render one HellaSwag example into tensors ready for scoring.

    Args:
        example: Dict with at least the keys ``"ctx"`` (context string),
            ``"label"`` (index of the correct ending) and ``"endings"``
            (list of candidate ending strings).

    Returns:
        A 4-tuple ``(data, tokens, mask, label)``:
        - data: dict holding the raw token lists (``"label"``,
          ``"ctx_tokens"``, ``"ending_tokens"``).
        - tokens: long tensor of shape (num_endings, max_len) with the tokens
          of context + ending per row, zero-padded on the right.
        - mask: long tensor of the same shape, 1 in the region of the
          candidate ending (where likelihoods are evaluated), 0 elsewhere.
        - label: the index of the correct ending, which we hope has the
          highest likelihood.
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]

    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    # note: prepending " " because GPT-2 tokenizer
    ending_tokens = [enc.encode(" " + end) for end in endings]

    # data needed to reproduce this eval on the C side
    data = {
        "label": label,
        "ctx_tokens": ctx_tokens,
        "ending_tokens": ending_tokens,
    }

    tok_rows = [ctx_tokens + end_tokens for end_tokens in ending_tokens]
    mask_rows = [
        [0] * len(ctx_tokens) + [1] * len(end_tokens)
        for end_tokens in ending_tokens
    ]

    # have to be careful during the collation because the number of tokens in
    # each row can differ; size the tensors from the actual number of endings
    # rather than assuming there are always exactly 4
    num_rows = len(tok_rows)
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((num_rows, max_len), dtype=torch.long)
    mask = torch.zeros((num_rows, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)
    return data, tokens, mask, label
|
| 108 |
-
|
| 109 |
-
def iterate_examples(split):
    """Yield the examples of the HellaSwag ``split`` one by one as dicts.

    The split file is downloaded into DATA_CACHE_DIR first if it is not there
    yet, then parsed line by line as JSON.
    """
    # there are 10,042 examples in total in val
    download(split)
    path = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    # Explicit UTF-8: the platform default encoding is not guaranteed to
    # decode the JSONL file correctly.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            yield json.loads(line)
|
| 116 |
-
|
| 117 |
-
@torch.no_grad()
def evaluate(model_type, device):
    """Evaluate a pretrained GPT-2 checkpoint on the HellaSwag validation split.

    Each candidate ending is scored by the model's cross-entropy loss over the
    ending tokens. Two predictions are tracked per example: the ending with the
    lowest summed loss (``acc``) and the one with the lowest per-token average
    loss (``acc_norm``).

    Args:
        model_type: Name passed to ``GPT2LMHeadModel.from_pretrained``.
        device: Device the model and the inputs are moved to.

    Returns:
        Tuple ``(acc, acc_norm)`` of accuracies in [0, 1]; ``(0.0, 0.0)`` if
        the split yielded no examples.
    """
    torch.set_float32_matmul_precision('high') # use tf32
    model = GPT2LMHeadModel.from_pretrained(model_type)
    model.to(device)
    # NOTE: the model can optionally be wrapped with torch.compile here.
    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in iterate_examples("val"):
        _, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)

        # get the logits
        logits = model(tokens).logits
        # evaluate the autoregressive loss at all positions
        shift_logits = (logits[..., :-1, :]).contiguous()
        shift_tokens = (tokens[..., 1:]).contiguous()
        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_shift_tokens = shift_tokens.view(-1)
        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
        shift_losses = shift_losses.view(tokens.size(0), -1)
        # now get the average loss just for the completion region (where mask == 1), in each row
        shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
        masked_shift_losses = shift_losses * shift_mask
        # sum and divide by the number of 1s in the mask
        sum_loss = masked_shift_losses.sum(dim=1)
        avg_loss = sum_loss / shift_mask.sum(dim=1)
        # now we have a loss for each of the completions
        # the one with the lowest loss should be the most likely
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == label)
        num_correct_norm += int(pred_norm == label)
        print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")

        # debug: pretty print a few examples, and the losses in each case
        if num_total < 10:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print("Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")

    if num_total == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        return 0.0, 0.0

    acc = num_correct / num_total
    acc_norm = num_correct_norm / num_total
    # Final summary: also surfaces the un-normalized accuracy, which was
    # accumulated above but never reported.
    print(f"{num_total} acc: {acc:.4f} acc_norm: {acc_norm:.4f}")
    return acc, acc_norm
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def get_most_likely_row(tokens, mask, logits):
    """Return the index of the row whose completion has the lowest average loss.

    Helper for the HellaSwag eval.

    Args:
        tokens: Token ids, one candidate (context + completion) per row.
        mask: Same shape as ``tokens``; 1 on completion tokens, 0 elsewhere.
        logits: Model outputs for ``tokens`` with the vocabulary on the last
            dimension.

    Returns:
        The row index (int) with the lowest mean cross-entropy over its masked
        completion tokens. Rows without any scored token are never preferred
        over rows that have one.
    """
    # evaluate the autoregressive loss at all positions: the logits at
    # position t are scored against the token at position t + 1
    shift_logits = logits[..., :-1, :]
    shift_tokens = tokens[..., 1:]
    shift_losses = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_tokens.reshape(-1),
        reduction='none',
    ).view(tokens.size(0), -1)
    # we must shift mask too, so we start at the last prompt token
    shift_mask = mask[..., 1:]
    # sum over the completion region and divide by the number of 1s in the mask
    sum_loss = (shift_losses * shift_mask).sum(dim=1)
    num_scored = shift_mask.sum(dim=1)
    avg_loss = sum_loss / num_scored
    # A row with no scored tokens evaluates to 0/0 = NaN, and argmin would
    # pick that NaN; rank such rows last instead.
    avg_loss = avg_loss.masked_fill(num_scored == 0, float("inf"))
    # the row with the lowest loss should be the most likely
    return avg_loss.argmin().item()
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
if __name__ == "__main__":
    # Command-line entry point: pick the model and device, then run the eval.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use")
    cli.add_argument("-d", "--device", type=str, default="cuda", help="the device to use")
    options = cli.parse_args()
    evaluate(options.model_type, options.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|