CDKA / eval_gsm8k.py
rainstonee's picture
Upload 33 files
c5d3e8d verified
Raw
History Blame Contribute Delete
1.75 kB
from data import load_gsm8k
from utils import model_inference, initialize_text_to_text_model
from fire import Fire
import re
import os
from tqdm import tqdm
from peft import PeftModel
def extract_num(text):
# Regex pattern to find the number following '####'
pattern = r'####\s*(\d+)'
# Using re.search to find the first match
match = re.search(pattern, text)
if match:
result = match.group(1)
print(text)
else:
print(text)
result = ""
try:
return int(result.replace(",", ""))
except:
print(f"'{result}' can't be converted")
return 0
def main(model_name = "llama/llama-2-7b-hf"):
_, _, test_set = load_gsm8k()
model_type = "CausalLM"
model, tokenizer = initialize_text_to_text_model(
model_name, model_type, True, flash_attention=True
)
model = PeftModel.from_pretrained(model, "checkpoint_dir")
model.generation_config.pad_token_id = tokenizer.pad_token_id
all = 0
correct = 0
t = tqdm(test_set)
for example in t:
# print(example['x'])
pred_text = model_inference(model, tokenizer, example['x'], model_type, max_target_length=512)
gt = extract_num(example["y"])
pred = extract_num(pred_text)
correct += int(gt == pred)
all += 1
t.set_description(f"Accuracy: {correct / all * 100:02f}%")
print("Acc:", correct / all)
# append to gsm8k_results.txt (create if not exists)
if not os.path.exists("gsm8k_results.txt"):
with open("gsm8k_results.txt", "w") as f:
f.write("Model Acc\n")
with open("gsm8k_results.txt", "a") as f:
f.write(f"{model_name} {correct / all}\n")
if __name__ == "__main__":
Fire(main)