c2cite / misc /mmlu_evaluate.py
loadingy's picture
first push
51be264
import csv
import json
import logging
from typing import List
import datasets
import fire
import torch
import moe_peft
# Answer letters, in option order. format_prompt pairs them with a question's
# "choices" and uses the integer "answer" field as an index into this list.
choices_map = ["A", "B", "C", "D"]
def format_subject(subject):
    """Convert an underscore-separated subject id into space-separated words.

    Every word, including the first one, is preceded by a single space, so
    ``"abstract_algebra"`` becomes ``" abstract algebra"``.
    """
    return "".join(f" {word}" for word in subject.split("_"))
def format_prompt(data_point, with_answer=True):
    """Render one multiple-choice record as prompt text.

    The result is the stripped question, one ``"<letter>. <choice>"`` line
    per option (letters taken from ``choices_map``), and a trailing
    ``"Answer:"``. When ``with_answer`` is true the letter selected by the
    record's ``"answer"`` index is appended, followed by a blank line, so the
    text can serve as a solved few-shot example.
    """
    parts = [data_point["question"].strip() + "\n"]
    for letter, option in zip(choices_map, data_point["choices"]):
        parts.append(f"{letter}. {option}\n")
    parts.append("Answer:")
    if with_answer:
        gold_letter = choices_map[data_point["answer"]]
        parts.append(f" {gold_letter}\n\n")
    return "".join(parts)
def prepare_data(
    tokenizer: moe_peft.Tokenizer,
    subject: str,
    dev_data: datasets.Dataset,
    test_data: datasets.Dataset,
    k_shots=5,
    max_seq_len=2048,
    batch_padding=True,
):
    """Tokenize a few-shot prompt for every test question of one subject.

    For each test question the prompt is a subject header, up to ``k_shots``
    solved examples taken in order from ``dev_data``, and the unanswered test
    question. Examples are added one at a time and the first one that would
    push the encoded prompt past ``max_seq_len`` ends the selection.

    Args:
        tokenizer: Provides ``encode``, ``pad_id_`` and ``mask_from``.
        subject: Subject id; rendered into the header via ``format_subject``.
        dev_data: Records used as solved few-shot examples.
        test_data: Records to evaluate; ``"answer"`` is collected as label.
        k_shots: Maximum number of solved examples per prompt. Zero or a
            negative value yields a zero-shot prompt.
        max_seq_len: Upper bound on the encoded prompt length.
        batch_padding: When true, every token list is right-padded in place
            with ``tokenizer.pad_id_`` to a common length (the longest
            prompt, capped at ``max_seq_len``).

    Returns:
        A tuple ``(sequence_lengths, batch_tokens, atten_masks, batch_labels)``.
        ``sequence_lengths`` holds the index of the last prompt token before
        padding, or ``-1`` for every entry when ``batch_padding`` is false.
        ``atten_masks`` holds ``tokenizer.mask_from(tokens)`` per prompt.
    """
    # The header does not depend on the test question, so build it once.
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )

    batch_tokens = []
    batch_labels = []
    max_tokens_len = 0
    for test_data_point in test_data:
        test_prompt = format_prompt(test_data_point, False)
        dev_prompt = header
        # Reset for every question: carrying the previous question's tokens
        # over would silently pair a stale prompt with this question's label.
        tokens = None
        shots_used = 0
        for dev_data_point in dev_data:
            if shots_used >= k_shots:
                break
            prompt = format_prompt(dev_data_point)
            input_ids = tokenizer.encode(dev_prompt + prompt + test_prompt)
            if len(input_ids) > max_seq_len:
                # This example does not fit; keep what has been accepted.
                break
            tokens = input_ids
            dev_prompt += prompt
            shots_used += 1

        if tokens is None:
            # No example was accepted (none fit, none available, or
            # k_shots <= 0): fall back to the zero-shot prompt.
            tokens = tokenizer.encode(dev_prompt + test_prompt)
            if len(tokens) > max_seq_len:
                logging.warning(
                    f"Zero-shot prompt of {len(tokens)} tokens exceeds "
                    f"max_seq_len={max_seq_len}; truncating from the left"
                )
                # Keep the tail so the prompt still ends with "Answer:".
                tokens = tokens[len(tokens) - max_seq_len :]

        max_tokens_len = max(len(tokens), max_tokens_len)
        batch_tokens.append(tokens)
        batch_labels.append(test_data_point["answer"])

    if batch_padding:
        # Pad only up to the longest prompt actually produced.
        max_seq_len = min(max_seq_len, max_tokens_len)
    logging.info(f"Max sequence length: {max_seq_len}")

    sequence_lengths = []
    atten_masks = []
    for tokens in batch_tokens:
        if batch_padding:
            # Record the last real token's position before padding is added.
            sequence_lengths.append(len(tokens) - 1)
            tokens.extend([tokenizer.pad_id_] * (max_seq_len - len(tokens)))
        else:
            # Without padding the last position is always the last token.
            sequence_lengths.append(-1)
        atten_masks.append(tokenizer.mask_from(tokens))

    return sequence_lengths, batch_tokens, atten_masks, batch_labels
@torch.inference_mode()
def evaluate(
    subject: str,
    tokenizer: moe_peft.Tokenizer,
    model: moe_peft.LLMModel,
    adapter_names: List[str],
    batch_size: int = 2,
    max_seq_len: int = 2048,
):
    """Score every adapter on the test split of one MMLU subject.

    The ``cais/mmlu`` configuration named ``subject`` is loaded, 5-shot
    prompts are built with ``prepare_data`` (padding is enabled only when
    ``batch_size > 1``), and the prompts are run through ``model`` in chunks
    of ``batch_size``. Each chunk is repeated once per adapter so a single
    forward call serves all adapters. For every prompt, the logits at the
    recorded last-token position are restricted to the answer-letter token
    ids and the highest-scoring letter is compared with the label.

    Returns:
        A dict mapping each adapter name to a list with one 0/1 entry per
        test question (1 when the prediction matches the label).
    """
    # prepare data
    mmlu = datasets.load_dataset("cais/mmlu", subject)
    sequence_lengths, batch_tokens, atten_masks, batch_labels = prepare_data(
        tokenizer, subject, mmlu["dev"], mmlu["test"], 5, max_seq_len, batch_size > 1
    )

    # one result list per adapter
    results = {adapter: [] for adapter in adapter_names}
    num_adapters = len(adapter_names)
    num_prompts = len(batch_tokens)

    # prepare for evaluate
    sequence_lengths = torch.tensor(
        sequence_lengths, dtype=torch.long, device=model.device_
    )
    # The last id of each encoded letter is taken as that letter's token id.
    label_indices = torch.tensor(
        [tokenizer.encode(letter)[-1] for letter in choices_map],
        dtype=torch.long,
        device=model.device_,
    )

    start_pos = 0
    while start_pos < num_prompts:
        end_pos = min(num_prompts, start_pos + batch_size)
        logging.info(f"evaluation step: {start_pos}/{len(batch_tokens)}")
        bsz = end_pos - start_pos
        window = slice(start_pos, end_pos)

        # Adapter number `slot` owns rows [slot * bsz, (slot + 1) * bsz) of
        # the repeated batch.
        batch_data_config = [
            moe_peft.LLMBatchConfig(
                adapter_name_=adapter,
                batch_start_idx_=slot * bsz,
                batch_end_idx_=(slot + 1) * bsz,
            )
            for slot, adapter in enumerate(adapter_names)
        ]

        input_args = moe_peft.LLMModelInput(
            batch_configs_=batch_data_config,
            batch_tokens_=batch_tokens[window] * num_adapters,
            batch_masks_=atten_masks[window] * num_adapters,
            inference_mode_=True,
        )
        outputs = model.forward(input_args)

        labels = torch.tensor(
            batch_labels[window], dtype=torch.long, device=model.device_
        )

        for output in outputs:
            rows = torch.arange(bsz, device=output.logits.device)
            # Pick the logits at each prompt's last token position ...
            last_token_logits = output.logits[rows, sequence_lengths[window]]
            # ... keep only the answer-letter columns and take the best one.
            letter_scores = last_token_logits[:, label_indices]
            predictions = letter_scores.softmax(-1).argmax(-1)
            hits = (predictions == labels).int().tolist()
            results[output.adapter_name].extend(hits)

        # Running accuracy over everything evaluated so far.
        for name, result in results.items():
            acc = sum(result) / len(result)
            logging.info(f" {name} accuracy: {acc}")

        start_pos = end_pos

    return results
# Maps every MMLU subject to a list of subcategory labels. The keys are the
# subject names do_evaluate iterates over and passes to evaluate (which uses
# them as the dataset configuration name); the last label of each list is
# looked up in mmlu_categories to find the subject's top-level category.
mmlu_subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}
# Groups the subcategory labels used in mmlu_subcategories into the top-level
# categories that do_evaluate writes to the "mmlu_categories" CSV column.
mmlu_categories = {
    "STEM": [
        "physics",
        "chemistry",
        "biology",
        "computer science",
        "math",
        "engineering",
    ],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}
# Keyword arguments for each supported `model_dtype` setting; do_evaluate
# unpacks the selected entry into moe_peft.LLMModel.from_pretrained.
# The "16bit" entry passes no "bits" argument.
model_dtypes = {
    "4bit": {"bits": 4, "load_dtype": torch.float32},
    "8bit": {"bits": 8, "load_dtype": torch.float32},
    "16bit": {"load_dtype": torch.bfloat16},
}
def do_evaluate(
    model_name: str,
    model_dtype: str,
    adapter_names: List[str],
    batch_size: int = 2,
    device: str = moe_peft.executor.default_device_name(),
    output: str = "mmlu_scores.csv",
):
    """Run the MMLU benchmark for every subject and write accuracies to CSV.

    Args:
        model_name: Passed to both the tokenizer and the model loader.
        model_dtype: Key into ``model_dtypes`` ("4bit", "8bit" or "16bit").
        adapter_names: Adapters to evaluate. The name ``"default"`` creates a
            fresh adapter via ``init_adapter``; any other name is handed to
            ``load_adapter``.
        batch_size: Number of prompts per forward step.
        device: Device the model is loaded on.
        output: Path of the CSV file to write.

    The CSV has a header row followed by one row per (subject, adapter) with
    the columns category, subject, adapter name and accuracy.
    """
    tokenizer = moe_peft.Tokenizer(model_name)
    load_kwargs = model_dtypes[model_dtype]
    model = moe_peft.LLMModel.from_pretrained(model_name, device=device, **load_kwargs)

    for name in adapter_names:
        logging.info(f"Loading adapter {name}")
        if name != "default":
            model.load_adapter(name)
            continue
        model.init_adapter(moe_peft.AdapterConfig(adapter_name=name))

    rows = [["mmlu_categories", "mmlu_subcategories", "adapter_name", "acc_score"]]
    for subject, subcategory in mmlu_subcategories.items():
        logging.info(f"Performing MMLU/{subject} Benchmark")
        results = evaluate(
            subject,
            tokenizer,
            model,
            adapter_names,
            batch_size,
            model.config_.max_seq_len_,
        )

        # The subject's last subcategory label decides its category; when
        # several categories list that label, the last one wins.
        matching = [
            group
            for group, members in mmlu_categories.items()
            if subcategory[-1] in members
        ]
        category = matching[-1] if matching else None

        rows.extend(
            [category, subject, name, sum(result) / len(result)]
            for name, result in results.items()
        )

    with open(output, "w", newline="") as csvfile:
        csv.writer(csvfile).writerows(rows)
def main(config: str):
    """Command-line entry point: run the benchmark described by a JSON file.

    Args:
        config: Path to a UTF-8 JSON file whose top-level object supplies the
            keyword arguments of ``do_evaluate``.

    Raises:
        SystemExit: With status -1 when ``moe_peft.executor.check_available()``
            reports that the executor is not available.
    """
    moe_peft.executor.manual_seed(66)
    moe_peft.setup_logging("INFO")

    if not moe_peft.executor.check_available():
        # Raise SystemExit directly instead of calling the `exit` builtin:
        # `exit` is only installed by the `site` module and is missing when
        # the interpreter runs without it. The exit status is unchanged.
        raise SystemExit(-1)

    with open(config, "r", encoding="utf8") as fp:
        mmlu_config = json.load(fp)

    do_evaluate(**mmlu_config)
# Script entry point: when executed directly, hand `main` to Fire so its
# `config` parameter is supplied from the command line.
if __name__ == "__main__":
    fire.Fire(main)