import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import time
import random
import re
from datasets import load_dataset
from parsers import Parser, is_equiv
import torch.distributed as dist
from torch.utils.data import DataLoader
# System prompt prepended to every GSM8K question. It instructs the model to
# wrap its chain of thought in <reasoning>...</reasoning> and the final result
# in <answer>\boxed{...}</answer>, so downstream parsing (see Parser usage in
# this file) can extract the boxed answer reliably.
GSM_SYSTEM_PROMPT = """You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}.
Respond in the following format:
<reasoning>
Your reasoning here
</reasoning>
<answer>
\\boxed{...}
</answer>"""
# -----------------------------------------------------------------------------
# helpers
# -----------------------------------------------------------------------------
def extract_rationale_and_answer(answer_field: str):
    """Split a raw GSM8K ``answer`` field into (reasoning_text, final_answer_str).

    GSM8K separates the worked rationale from the final numeric answer with a
    literal ``####`` marker; only the first marker is significant.
    """
    rationale, marker, final_ans = answer_field.partition("####")
    if marker:
        return rationale.strip(), final_ans.strip()
    # No "####" marker — dataset might already be split elsewhere; treat the
    # whole field as the final answer with an empty rationale.
    return "", answer_field.strip()
# -----------------------------------------------------------------------------
# data preprocessing code
# -----------------------------------------------------------------------------
def preprocess_gsm8k(
    split: str, tokenizer: AutoTokenizer, max_length: int
):
    """Tokenize a GSM8K split into supervised fine-tuning examples.

    Each example is the chat-formatted prompt (ending in "<reasoning>")
    concatenated with the target "{rationale}</reasoning><answer>\\boxed{ans}</answer>",
    tokenized to a fixed length.

    Args:
        split: GSM8K split name, e.g. "train" or "test".
        tokenizer: tokenizer used both for encoding and for chat templating.
        max_length: fixed sequence length (pad/truncate to this many tokens).

    Returns:
        A shuffled list of dicts with keys:
            "input_ids": LongTensor of shape (max_length,)
            "prompt_lengths": int token count of the prompt portion,
                clamped to max_length.
    """
    ds = load_dataset("gsm8k", "main", split=split)
    # Reuse the dataset class purely for its prompt formatting (zero-shot,
    # with the trailing "<reasoning>" opener).
    prompt_builder = GSM8KDataset(tokenizer, num_examples=0, add_reasoning=True)
    preprocessed = []
    for ex in tqdm(ds, desc="Preprocessing"):
        q = ex["question"].strip()
        rat, ans = extract_rationale_and_answer(ex["answer"])
        prompt_text = prompt_builder.create_prompt(q)
        # The prompt already opens "<reasoning>", so the target closes it.
        target_txt = f"{rat}</reasoning>\n<answer>\\boxed{{{ans}}}</answer>"
        full_txt = prompt_text + target_txt
        enc = tokenizer(full_txt, truncation=True, max_length=max_length,
                        padding="max_length", return_tensors="pt")
        input_ids = enc.input_ids.squeeze(0)
        # Clamp: an over-long prompt would otherwise yield a prompt length
        # pointing past the end of the truncated input_ids.
        prompt_len = min(len(tokenizer(prompt_text).input_ids), max_length)
        preprocessed.append({"input_ids": input_ids, "prompt_lengths": prompt_len})
    # Shuffle in place; any train/val splitting is left to the caller.
    random.shuffle(preprocessed)
    return preprocessed
class GSM8KDataset(torch.utils.data.Dataset):
    """GSM8K test-split dataset yielding (prompt, question, gold_answer) triples.

    Builds chat-formatted prompts (optionally few-shot, optionally opening a
    "<reasoning>" tag) and exposes a `collate_fn` producing left-padded
    `input_ids` batches for generation.
    """

    def __init__(
        self,
        tokenizer,
        num_examples=0,
        add_reasoning=True,
        system_prompt=GSM_SYSTEM_PROMPT,
        subsample=-1,
    ):
        """
        Args:
            tokenizer: tokenizer providing `apply_chat_template` and padding.
            num_examples: number of few-shot train examples to prepend (0 = zero-shot).
            add_reasoning: if True, append "<reasoning>" to the prompt so the
                model continues inside the reasoning tag.
            system_prompt: instruction text prepended to every question.
            subsample: if != -1, evaluate a random subset of this many examples.
        """
        self.tokenizer = tokenizer
        self.num_examples = num_examples
        self.add_reasoning = add_reasoning
        self.system_prompt = system_prompt
        self.load_test_dataset()
        self.create_few_shot_prompt()
        # Validate before sampling so the failure is our message rather than
        # numpy's ValueError from np.random.choice.
        assert subsample <= len(self.dataset), "Subsample size is greater than dataset size"
        self.subsample = (
            np.random.choice(len(self.dataset), subsample, replace=False)
            if subsample != -1
            else np.arange(len(self.dataset))
        )
        print(f"evaluating {len(self.subsample)} examples")

    def __len__(self):
        return len(self.subsample)

    def load_test_dataset(self):
        # Evaluation always runs against the GSM8K test split.
        self.dataset = load_dataset("gsm8k", "main", split="test")

    def create_prompt(self, input_text):
        """Build the full chat-templated prompt for one question."""
        if self.num_examples > 0:
            prompt = f"{self.few_shot_prompt}\n\nQuestion: {input_text}\nAnswer:\n"
        else:
            prompt = input_text
        messages = [{"role": "user", "content": self.system_prompt + "\n\n" + prompt}]
        user_input = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        if self.add_reasoning:
            # Open the reasoning tag so generation continues inside it.
            return user_input + "<reasoning>"
        else:
            return user_input

    def load_few_shot_examples(self):
        """Sample `num_examples` few-shot exemplars from the GSM8K train split.

        BUG FIX: the original gated on `isinstance(self.dataset, GSM8KDataset)`,
        which is never true (`self.dataset` is a HuggingFace Dataset), so
        few-shot examples were never loaded even with num_examples > 0.
        """
        if self.num_examples > 0:
            train_data = load_dataset("gsm8k", "main", split="train")
            indices = random.sample(range(len(train_data)), self.num_examples)
            return [train_data[i] for i in indices]
        return []

    def create_few_shot_prompt(self):
        """Create few-shot prompt from dataset examples (empty string if zero-shot)."""
        few_shot_examples = self.load_few_shot_examples()
        formatted_examples = []
        for example in few_shot_examples:
            input_text = example["question"]
            answer = example["answer"]
            formatted_examples.append(f"Question: {input_text}\nAnswer:\n{answer}")
        self.few_shot_prompt = "\n\n".join(formatted_examples)

    def __getitem__(self, idx):
        # Map through the subsample indirection to the underlying dataset row.
        row = self.dataset[self.subsample[idx].item()]
        question = row["question"]
        answer = Parser.extract_answer_gsm8k(row["answer"])
        prompt = self.create_prompt(question)
        return prompt, question, answer

    def collate_fn(self, batch):
        """Collate (prompt, question, answer) triples; left-pad for generation."""
        prompts = [item[0] for item in batch]
        questions = [item[1] for item in batch]
        answers = [item[2] for item in batch]
        input_ids = self.tokenizer(
            prompts, padding_side="left", return_tensors="pt", padding="longest"
        ).input_ids
        return {"input_ids": input_ids, "questions": questions, "answers": answers, "prompts": prompts}
# test code
# if __name__ == "__main__":
#     tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Base")
#     train_data = preprocess_gsm8k(split="train", tokenizer=tokenizer, max_length=4096)
#     def collate_fn(batch):
#         ids = torch.stack([item["input_ids"] for item in batch])
#         plen = torch.tensor([item["prompt_lengths"] for item in batch])
#         return {"input_ids": ids, "prompt_lengths": plen}
#     loader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)
#     for batch in loader:
#         print(batch["input_ids"].shape)
#         print(batch["prompt_lengths"])
#         break