import csv
import json
import logging
from typing import List

import datasets
import fire
import torch

import moe_peft

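# Answer letters for the four MMLU choices, indexed by the dataset's integer "answer" field.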
choices_map = ["A", "B", "C", "D"]


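# Turn an MMLU subject id such as "college_biology" into a spaced phrase
# for the few-shot prompt header.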
def format_subject(subject):
    lst = subject.split("_")
    sjt = ""
    for entry in lst:
        sjt += " " + entry
    return sjt


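# Render one MMLU example as question, lettered choices, and an "Answer:" line;
# the gold answer letter is appended when the example is used as a few-shot demonstration.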
def format_prompt(data_point, with_answer=True):
    question = data_point["question"].strip()
    choices = "".join(
        [
            f"{key}. {choice}\n"
            for key, choice in zip(choices_map, data_point["choices"])
        ]
    )
    prompt = f"{question}\n{choices}Answer:"
    if with_answer:
        prompt += " {}\n\n".format(choices_map[data_point["answer"]])
    return prompt


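# Build a tokenized k-shot prompt for every test example, keeping only as many
# dev-set shots as fit within max_seq_len, and optionally right-pad the batch
# to a common length.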
def prepare_data(
    tokenizer: moe_peft.Tokenizer,
    subject: str,
    dev_data: datasets.Dataset,
    test_data: datasets.Dataset,
    k_shots=5,
    max_seq_len=2048,
    batch_padding=True,
):
    sequence_lengths = []
    batch_tokens = []
    batch_labels = []
    atten_masks = []

    max_tokens_len = 0
    tokens = None
    for test_data_point in test_data:
        test_prompt = format_prompt(test_data_point, False)
        dev_prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
            format_subject(subject)
        )
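        # Greedily append dev-set shots until k_shots are used or the combined
        # prompt would exceed max_seq_len.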
        k = k_shots
        for dev_data_point in dev_data:
            k -= 1
            prompt = format_prompt(dev_data_point)
            input_ids = tokenizer.encode(dev_prompt + prompt + test_prompt)
            if len(input_ids) <= max_seq_len:
                tokens = input_ids
                dev_prompt += prompt
            else:
                k = 0

            if k <= 0:
                break

        max_tokens_len = max(len(tokens), max_tokens_len)
        batch_tokens.append(tokens)
        batch_labels.append(test_data_point["answer"])

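    # Shrink the padding length to the longest prompt actually produced.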
    if batch_padding:
        max_seq_len = min(max_seq_len, max_tokens_len)
        logging.info(f"Max sequence length: {max_seq_len}")

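    # Right-pad each prompt with the pad token; sequence_lengths holds the index
    # of the last real token (-1 when batches are left unpadded).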
    for tokens in batch_tokens:
        if batch_padding:
            sequence_lengths.append(len(tokens) - 1)
            while len(tokens) < max_seq_len:
                tokens.append(tokenizer.pad_id_)
        else:
            sequence_lengths.append(-1)
        atten_masks.append(tokenizer.mask_from(tokens))

    return sequence_lengths, batch_tokens, atten_masks, batch_labels


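# Score every loaded adapter on one MMLU subject; returns a dict mapping
# adapter name to a list of 0/1 correctness flags.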
@torch.inference_mode()
def evaluate(
    subject: str,
    tokenizer: moe_peft.Tokenizer,
    model: moe_peft.LLMModel,
    adapter_names: List[str],
    batch_size: int = 2,
    max_seq_len: int = 2048,
):
    mmlu = datasets.load_dataset("cais/mmlu", subject)

    sequence_lengths, batch_tokens, atten_masks, batch_labels = prepare_data(
        tokenizer, subject, mmlu["dev"], mmlu["test"], 5, max_seq_len, batch_size > 1
    )

    results = {}
    for name in adapter_names:
        results[name] = []

    sequence_lengths = torch.tensor(
        sequence_lengths, dtype=torch.long, device=model.device_
    )

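    # Token ids used to read off the answer letters: the last token of each
    # encoded choice letter.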
    label_indices = [0] * len(choices_map)
    for idx, text in enumerate(choices_map):
        ids = tokenizer.encode(text)
        label_indices[idx] = ids[-1]
    label_indices = torch.tensor(label_indices, dtype=torch.long, device=model.device_)

    start_pos = 0
    while start_pos < len(batch_tokens):
        end_pos = min(len(batch_tokens), start_pos + batch_size)
        logging.info(f"evaluation step: {start_pos}/{len(batch_tokens)}")
        bsz = end_pos - start_pos
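        # The batch is repeated once per adapter so a single forward pass
        # produces logits for every adapter.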
        batch_data_config = []
        batch_start_idx = 0
        for name in adapter_names:
            batch_data_config.append(
                moe_peft.LLMBatchConfig(
                    adapter_name_=name,
                    batch_start_idx_=batch_start_idx,
                    batch_end_idx_=batch_start_idx + bsz,
                )
            )
            batch_start_idx += bsz

        input_args = moe_peft.LLMModelInput(
            batch_configs_=batch_data_config,
            batch_tokens_=batch_tokens[start_pos:end_pos] * len(adapter_names),
            batch_masks_=atten_masks[start_pos:end_pos] * len(adapter_names),
            inference_mode_=True,
        )

        outputs = model.forward(input_args)

        labels = torch.tensor(
            batch_labels[start_pos:end_pos], dtype=torch.long, device=model.device_
        )

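        # For each adapter's output: take the logits at the final prompt token,
        # restrict them to the answer-letter tokens, and compare the argmax with
        # the gold label.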
        for output in outputs:
            logits = output.logits
            logits = logits[
                torch.arange(bsz, device=logits.device),
                sequence_lengths[start_pos:end_pos],
            ]
            logits = logits[:, label_indices]
            logits = logits.softmax(-1).argmax(-1)
            result = (logits == labels).int().tolist()
            results[output.adapter_name].extend(result)

        for name, result in results.items():
            acc = sum(result) / len(result)
            logging.info(f" {name} accuracy: {acc}")

        start_pos = end_pos

    return results


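# Map each MMLU subject to its subcategory (used to roll subjects up into categories).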
mmlu_subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

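# Group subcategories into the top-level categories used for reporting.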
mmlu_categories = {
    "STEM": [
        "physics",
        "chemistry",
        "biology",
        "computer science",
        "math",
        "engineering",
    ],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

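# Quantization / dtype presets selected by the "model_dtype" config field.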
model_dtypes = {
    "4bit": {"bits": 4, "load_dtype": torch.float32},
    "8bit": {"bits": 8, "load_dtype": torch.float32},
    "16bit": {"load_dtype": torch.bfloat16},
}

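# Load the base model and the requested adapters, run the MMLU benchmark over
# every subject, and write per-subject accuracy scores to a CSV file.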
def do_evaluate(
    model_name: str,
    model_dtype: str,
    adapter_names: List[str],
    batch_size: int = 2,
    device: str = moe_peft.executor.default_device_name(),
    output: str = "mmlu_scores.csv",
):
    tokenizer = moe_peft.Tokenizer(model_name)
    model = moe_peft.LLMModel.from_pretrained(
        model_name, device=device, **model_dtypes[model_dtype]
    )
    for name in adapter_names:
        logging.info(f"Loading adapter {name}")
        if name == "default":
            model.init_adapter(moe_peft.AdapterConfig(adapter_name=name))
        else:
            model.load_adapter(name)

    csv_data = [["mmlu_categories", "mmlu_subcategories", "adapter_name", "acc_score"]]
    for subject, subcategory in mmlu_subcategories.items():
        logging.info(f"Performing MMLU/{subject} Benchmark")
        results = evaluate(
            subject,
            tokenizer,
            model,
            adapter_names,
            batch_size,
            model.config_.max_seq_len_,
        )
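        # Resolve the subject's subcategory to its top-level category for the CSV.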
        category = None
        for category_name, subcategory_names in mmlu_categories.items():
            if subcategory[-1] in subcategory_names:
                category = category_name
        for name, result in results.items():
            acc = sum(result) / len(result)
            csv_data.append([category, subject, name, acc])
        with open(output, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerows(csv_data)


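# Entry point: load a JSON config whose fields are the arguments of do_evaluate.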
def main(config: str):
    moe_peft.executor.manual_seed(66)
    moe_peft.setup_logging("INFO")
    if not moe_peft.executor.check_available():
        exit(-1)
    with open(config, "r", encoding="utf8") as fp:
        mmlu_config = json.load(fp)
    do_evaluate(**mmlu_config)


if __name__ == "__main__":
    fire.Fire(main)