import torch
import numpy as np
from transformers import AutoTokenizer
from tqdm import tqdm
import random
from datasets import load_dataset
from parsers import Parser
from torch.utils.data import DataLoader
GSM_SYSTEM_PROMPT = """You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}.
Respond in the following format:
Your reasoning here
\\boxed{...}
"""
# -----------------------------------------------------------------------------
# helpers
# -----------------------------------------------------------------------------
def extract_rationale_and_answer(answer_field: str):
"""Split raw GSM8K `answer` into (reasoning_text, final_answer_str)."""
if "####" in answer_field:
rationale, final_ans = answer_field.split("####", 1)
return rationale.strip(), final_ans.strip()
# fallback – dataset might already be split elsewhere
return "", answer_field.strip()
# -----------------------------------------------------------------------------
# data preprocessing code
# -----------------------------------------------------------------------------
def preprocess_gsm8k(
split: str, tokenizer: AutoTokenizer, max_length: int
):
ds = load_dataset("gsm8k", "main", split=split)
    # GSM8KDataset doubles as the prompt builder (num_examples=0 -> zero-shot prompts).
    prompt_builder = GSM8KDataset(tokenizer, num_examples=0, add_reasoning=True)
preprocessed = []
    for ex in tqdm(ds, desc="Preprocessing"):
q = ex["question"].strip()
rat, ans = extract_rationale_and_answer(ex["answer"])
prompt_text = prompt_builder.create_prompt(q)
target_txt = f"{rat}\n\\boxed{{{ans}}}"
full_txt = prompt_text + target_txt
enc = tokenizer(full_txt, truncation=True, max_length=max_length,
padding="max_length", return_tensors="pt")
input_ids = enc.input_ids.squeeze(0)
        # The prompt is re-tokenized on its own; for tokenizers where
        # tokenize(a + b) != tokenize(a) + tokenize(b) this length is approximate.
        prompt_len = len(tokenizer(prompt_text).input_ids)
        # prompt_lengths lets downstream code mask prompt tokens out of the loss.
        preprocessed.append({"input_ids": input_ids, "prompt_lengths": prompt_len})
    # shuffle before returning so any later train/val split is random
    random.shuffle(preprocessed)
    return preprocessed
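# Usage sketch (model name is illustrative; any HF tokenizer with a chat template works):
#   tok = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Base", trust_remote_code=True)
#   records = preprocess_gsm8k("train", tok, max_length=1024)
#   records[0]["input_ids"].shape  # -> torch.Size([1024])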
class GSM8KDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer,
num_examples=0,
add_reasoning=True,
system_prompt=GSM_SYSTEM_PROMPT,
subsample=-1,
):
self.tokenizer = tokenizer
self.num_examples = num_examples
self.add_reasoning = add_reasoning
self.system_prompt = system_prompt
self.load_test_dataset()
self.create_few_shot_prompt()
        assert subsample == -1 or subsample <= len(self.dataset), (
            "Subsample size is greater than dataset size"
        )
        self.subsample = (
            np.random.choice(len(self.dataset), subsample, replace=False)
            if subsample != -1
            else np.arange(len(self.dataset))
        )
        print(f"evaluating {len(self.subsample)} examples")
def __len__(self):
return len(self.subsample)
def load_test_dataset(self):
self.dataset = load_dataset("gsm8k", "main", split="test")
def create_prompt(self, input_text):
        # Build the chat-formatted prompt, optionally prefixed with few-shot examples.
if self.num_examples > 0:
prompt = f"{self.few_shot_prompt}\n\nQuestion: {input_text}\nAnswer:\n"
else:
prompt = input_text
messages = [{"role": "user", "content": self.system_prompt + "\n\n" + prompt}]
user_input = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        # `add_reasoning` is kept as a hook for appending a reasoning prefix after
        # the generation prompt; it is currently a no-op.
        return user_input
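    # Example output (depends entirely on the tokenizer's chat template): a
    # ChatML-style template yields roughly
    #   <|im_start|>user
    #   <system prompt>\n\n<question><|im_end|>
    #   <|im_start|>assistant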
    def load_few_shot_examples(self):
        # Sample few-shot exemplars from the train split so the test set stays unseen.
        # (The previous isinstance(self.dataset, GSM8KDataset) check was always False,
        # since load_test_dataset stores a HuggingFace Dataset.)
        if self.num_examples > 0:
            train_data = load_dataset("gsm8k", "main", split="train")
            examples = random.sample(range(len(train_data)), self.num_examples)
            return [train_data[example] for example in examples]
        return []
def create_few_shot_prompt(self):
"""Create few-shot prompt from dataset examples"""
few_shot_examples = self.load_few_shot_examples()
formatted_examples = []
for example in few_shot_examples:
input_text = example["question"]
answer = example["answer"]
formatted_examples.append(f"Question: {input_text}\nAnswer:\n{answer}")
self.few_shot_prompt = "\n\n".join(formatted_examples)
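    # Resulting few-shot block (shape illustrative, for num_examples=2):
    #   Question: <question 1>
    #   Answer:
    #   <rationale 1> #### <answer 1>
    #
    #   Question: <question 2>
    #   Answer:
    #   <rationale 2> #### <answer 2>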
    def __getitem__(self, idx):
        row = self.dataset[self.subsample[idx].item()]
        question = row["question"]
        answer = Parser.extract_answer_gsm8k(row["answer"])
        prompt = self.create_prompt(question)
        return prompt, question, answer
def collate_fn(self, batch):
prompts = [item[0] for item in batch]
questions = [item[1] for item in batch]
answers = [item[2] for item in batch]
        # Left-pad so every prompt ends at the generation boundary for batched decoding
        # (passing padding_side here requires a recent transformers version).
        input_ids = self.tokenizer(
prompts, padding_side="left", return_tensors="pt", padding="longest"
).input_ids
return {"input_ids": input_ids, "questions": questions, "answers": answers, "prompts": prompts}
# smoke test: run this file directly to exercise the preprocessing pipeline end to end
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Base", trust_remote_code=True)
    train_data = preprocess_gsm8k(split="train", tokenizer=tokenizer, max_length=4096)

    def collate_fn(batch):
        ids = torch.stack([item["input_ids"] for item in batch])
        plen = torch.tensor([item["prompt_lengths"] for item in batch])
        return {"input_ids": ids, "prompt_lengths": plen}

    loader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)
    for batch in loader:
        print(batch["input_ids"].shape)  # expected: torch.Size([2, 4096])
        print(batch["prompt_lengths"])
        break