| from data import load_gsm8k |
| from utils import model_inference, initialize_text_to_text_model |
| from fire import Fire |
| import re |
| import os |
| from tqdm import tqdm |
| from peft import PeftModel |
|
|
def extract_num(text):
    """Return the integer that follows the ``####`` answer marker in *text*.

    The first occurrence of ``####`` followed by optional whitespace and an
    integer is used. The integer may carry a leading minus sign and
    thousands separators (``#### 1,234`` -> ``1234``, ``#### -5`` -> ``-5``).

    The input text is always echoed to stdout, as a debugging aid.

    Args:
        text: String to search for the answer marker.

    Returns:
        The parsed integer, or ``0`` when no marker/number is found. Callers
        therefore cannot tell a missing answer apart from a genuine ``0``.
    """
    # Accept an optional sign and embedded commas; the commas are stripped
    # before conversion. The previous digits-only pattern stopped at the
    # first comma (so "1,234" parsed as 1) and rejected negative answers.
    match = re.search(r'####\s*(-?\d[\d,]*)', text)
    print(text)
    result = match.group(1) if match else ""
    try:
        return int(result.replace(",", ""))
    except ValueError:
        # Only reachable when nothing matched (result == "").
        print(f"'{result}' can't be converted")
        return 0
|
|
def main(
    model_name="llama/llama-2-7b-hf",
    checkpoint_dir="checkpoint_dir",
    results_path="gsm8k_results.txt",
):
    """Evaluate a PEFT-adapted causal LM on the GSM8K test split.

    Loads the base model and tokenizer, wraps the model with the adapter
    stored in *checkpoint_dir*, generates an answer for every test example,
    and compares the integer extracted from the generation against the one
    extracted from the reference. The final accuracy is printed and appended
    to *results_path*.

    Args:
        model_name: Identifier of the base model passed to
            ``initialize_text_to_text_model``.
        checkpoint_dir: Directory handed to ``PeftModel.from_pretrained``.
            Defaults to the previously hard-coded ``"checkpoint_dir"``.
        results_path: Text file that accumulates one ``"<model> <acc>"`` line
            per run; a ``"Model Acc"`` header is written when the file is
            first created.
    """
    _, _, test_set = load_gsm8k()
    model_type = "CausalLM"
    # NOTE(review): the meaning of the positional ``True`` is defined by
    # initialize_text_to_text_model in utils — confirm against that helper.
    model, tokenizer = initialize_text_to_text_model(
        model_name, model_type, True, flash_attention=True
    )
    model = PeftModel.from_pretrained(model, checkpoint_dir)
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    total = 0  # renamed from ``all`` to stop shadowing the builtin
    correct = 0
    progress = tqdm(test_set)
    for example in progress:
        # example['x'] is fed to the model; example['y'] supplies the
        # reference answer.
        pred_text = model_inference(
            model, tokenizer, example['x'], model_type, max_target_length=512
        )
        gt = extract_num(example["y"])
        pred = extract_num(pred_text)
        correct += int(gt == pred)
        total += 1
        # ``.2f`` gives two decimals; the former ``02f`` spec was a
        # width-2 field with the default six decimals.
        progress.set_description(f"Accuracy: {correct / total * 100:.2f}%")

    # An empty test set would otherwise raise ZeroDivisionError here.
    accuracy = correct / total if total else 0.0
    print("Acc:", accuracy)

    # Decide on the header before opening: append mode creates the file.
    needs_header = not os.path.exists(results_path)
    with open(results_path, "a") as f:
        if needs_header:
            f.write("Model Acc\n")
        f.write(f"{model_name} {accuracy}\n")
|
|
# Script entry point: python-fire turns main()'s keyword arguments into
# command-line flags. Nothing runs when this module is merely imported.
if __name__ == "__main__":
    Fire(main)
|
|
|
|
|
|
|
|