import argparse
import os
import random

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from data import get_dataset
from extract_judge_answer import extract_answer, extract_true_answer, judge_answer
from ltpo import generate
from reward import RewardModel

huggingface_token = os.environ['HUGGING_FACE_TOKEN']
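# The token must be exported before running, e.g. `export HUGGING_FACE_TOKEN=<token>`;
# it is required for gated checkpoints on the Hub.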

def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate the model")
    parser.add_argument("--dataset", type=str, default="openai/gsm8k", help="Dataset to evaluate")
    parser.add_argument("--model_name_or_path", type=str, help="Path to the model")
    parser.add_argument("--output_dir", type=str, help="Path to the output directory")
    parser.add_argument("--start_data_idx", type=int, default=0, help="Start index of the data to evaluate")
    parser.add_argument("--end_data_idx", type=int, default=1319, help="End index of the data to evaluate")
    parser.add_argument("--max_new_tokens", type=int, default=4096, help="Maximum number of generated tokens")
    parser.add_argument("--device", type=str, default="cuda")
    # prompt
    parser.add_argument("--solver_prompt_idx", type=int, default=0, help="Index of the solver prompt")
    # seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization")
    # optimization args
    parser.add_argument("--num_thought_tokens", type=int, default=10)
    parser.add_argument("--sigma", type=float, default=0.1)
    parser.add_argument("--sigma_decay", type=float, default=0.99)
    parser.add_argument("--lr", type=float, default=0.03, help="Learning rate")
    parser.add_argument("--max_num_steps", type=int, default=10, help="Number of optimization iterations")
    # experimental / ablation args
    parser.add_argument("--optimize_prefix_tokens", type=int, default=-1,
                        help="Only optimize the first N latent thought tokens; -1 means optimize all tokens")
    parser.add_argument("--reward_mode", type=str, default="all_tokens",
                        help="Reward aggregation mode, e.g. all_tokens / first_token (for future use)")
    # reward model
    parser.add_argument("--reward_threshold", type=float, default=-1, help="Threshold for reward to stop optimization")
    parser.add_argument("--top_k", type=int, default=10, help="Use top-k most probable tokens to calculate token-level confidence")
    parser.add_argument("--disable_conf_reward", action="store_true", help="If set, disable the confidence reward")
    parser.add_argument("--disable_best_reward", action="store_true", help="If set, disable using the best-reward step as output")
    # misc
    parser.add_argument("--resume", action="store_true", help="Resume evaluation from the last checkpoint")
    parser.add_argument("--ckpt_suffix", type=str, default="")
    parser.add_argument("--use_auto_grad", action="store_true", help="Use PyTorch's autograd")
    parser.add_argument("--eval_baseline", action="store_true", help="Evaluate the baseline (plain decoding)")
    parser.add_argument("--verbose", type=int, default=1, help="Print detailed information")
    parser.add_argument("--disable_save_logistics", action="store_true", help="Disable saving logistics.pt")
    parser.add_argument("--delta_clip_norm", type=float, default=-1.0,
                        help="Per-token L2 norm clip for each hidden-state update; -1 disables clipping")
    return parser.parse_args()
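
# Example invocation (the model path and output directory are illustrative,
# not taken from this repository):
#   python main.py \
#       --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
#       --dataset openai/gsm8k \
#       --output_dir outputs \
#       --num_thought_tokens 10 --lr 0.03 --sigma 0.1 --sigma_decay 0.99 \
#       --max_num_steps 10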

def set_seed(seed):
    '''
    Set random seeds for reproducibility.

    Args:
        seed: random seed
    '''
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
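
# Note: the seeds above make runs repeatable at the Python/NumPy/PyTorch level,
# but bit-exact GPU reproducibility may additionally require
# torch.use_deterministic_algorithms(True), at some performance cost.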

# evaluate function
def main(args):
    '''
    Evaluate the model on a slice of the dataset, optionally optimizing its
    latent thought tokens with LTPO, and report accuracy.

    Args:
        args: parsed command-line arguments (see parse_args)
    '''
    if args.seed is not None:
        set_seed(args.seed)
    # set device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    # load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        token=huggingface_token,
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        token=huggingface_token,
    )
    # load reward model (wraps the same solver model and tokenizer)
    reward_model = RewardModel(
        model=model,
        tokenizer=tokenizer,
        num_thought_tokens=args.num_thought_tokens,
    )
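    # Per the CLI flags above, the confidence term of the reward is computed
    # from the top-k most probable tokens (--top_k) and can be switched off
    # with --disable_conf_reward; the implementation lives in reward.RewardModel.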
    # load dataset
    dataset = get_dataset(
        args.dataset,
        tokenizer=tokenizer,
        prompt_idx=args.solver_prompt_idx,
    )
    if args.verbose:
        print(f"Example: {dataset[0]}")
    total = 0
    correct = 0
    entries = []
    model_name = args.model_name_or_path.split("/")[-1]
    data_name = args.dataset.split("/")[-1]
    conf_suffix = "" if args.disable_conf_reward else "-conf"
    if args.eval_baseline:
        output_suffix = "-" + args.ckpt_suffix if args.ckpt_suffix else ""
        output_dir = f"{args.output_dir}/{model_name}-{data_name}-max_tokens{args.max_new_tokens}-prompt{args.solver_prompt_idx}" + output_suffix
    else:
        prefix_tag = "all" if args.optimize_prefix_tokens == -1 else f"prefix{args.optimize_prefix_tokens}"
        reward_tag = args.reward_mode
        output_dir = (
            f"{args.output_dir}/{model_name}-{data_name}"
            f"-tokens{args.num_thought_tokens}"
            f"-lr{args.lr}"
            f"-sigma{args.sigma}"
            f"-sigdecay{args.sigma_decay}"
            f"-steps{args.max_num_steps}"
            f"-{prefix_tag}"
            f"-reward{reward_tag}"
            + conf_suffix
        )
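    # With the defaults this resolves to a name like (illustrative; the model
    # name is an assumption):
    #   <output_dir>/Llama-3.1-8B-Instruct-gsm8k-tokens10-lr0.03-sigma0.1-sigdecay0.99-steps10-all-rewardall_tokens-conf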
    start_data_idx = max(0, args.start_data_idx)
    end_data_idx = min(args.end_data_idx, len(dataset))
    if args.resume and not args.disable_save_logistics:
        print(f"Resume from {output_dir}")
        # load the logistics checkpoint: next index, running counters, entries
        logistics = torch.load(f"{output_dir}/logistics.pt")
        start_data_idx = logistics["start_idx"]
        correct = logistics["correct"]
        total = logistics["total"]
        entries = logistics["entries"]
    print(f"Start to evaluate {args.dataset} from {start_data_idx} to {end_data_idx}...")
    data_idx_list = range(start_data_idx, end_data_idx)
    for i in tqdm(data_idx_list):
        example = dataset[i]
        question = example['question']
        os.makedirs(output_dir, exist_ok=True)
        true_answer = extract_true_answer(example["answer"], name=args.dataset)
        if args.verbose:
            print(f"Index {i}, Question: {question}")
            print(f"Index {i}, True answer: {true_answer}")
        if true_answer is None:
            continue
        if args.eval_baseline:
            # baseline: plain greedy decoding, no latent-thought optimization
            init_reward, best_reward, best_reward_step = None, None, None
            inputs = tokenizer.apply_chat_template(
                example["prompt"],
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                num_beams=1,
            )
            output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            # LTPO: optimize the latent thought tokens against the reward model
            output, init_reward, best_reward, best_reward_step = generate(
                tokenizer=tokenizer,
                model=model,
                reward_model=reward_model,
                question=question,
                num_thought_tokens=args.num_thought_tokens,
                max_rl_steps=args.max_num_steps,
                max_new_tokens=args.max_new_tokens,
                reward_threshold=args.reward_threshold,
                lr=args.lr,
                sigma=args.sigma,
                sigma_decay=args.sigma_decay,
                use_auto_grad=args.use_auto_grad,
                disable_conf_reward=args.disable_conf_reward,
                disable_best_reward=args.disable_best_reward,
                data_name=args.dataset,
                model_name=args.model_name_or_path,
                verbose=args.verbose,
                top_k=args.top_k,
                optimize_prefix_tokens=args.optimize_prefix_tokens,
                reward_mode=args.reward_mode,
                delta_clip_norm=args.delta_clip_norm,
            )
        # extract answer from model response
        answer = extract_answer(
            output,
            data_name=args.dataset,
            prompt_idx=args.solver_prompt_idx,
            model_name=args.model_name_or_path,
        )
        if args.verbose:
            if args.verbose > 1:
                print(f"Index {i}, LLM response:\n{output}")
            print(f"Index {i}, LLM answer: {answer}")
            print(f"Index {i}, True answer: {true_answer}")
            print(f"Index {i}, Best reward: {best_reward}, Best reward step: {best_reward_step}")
        # judge answer
        is_correct = False
        if answer is not None:
            is_correct = judge_answer(output, true_answer, data_name=args.dataset, prompt_idx=args.solver_prompt_idx)
        correct += is_correct
        if not args.disable_save_logistics:
            entries.append(dict(
                data_idx=i,
                question=question,
                response=output,
                answer=answer,
                is_correct=is_correct,
                init_reward=init_reward,
                best_reward=best_reward,
                best_reward_step=best_reward_step,
            ))
        total += 1
        # checkpoint progress after every example: the next start index,
        # running counters, and the per-example entries
        if not args.disable_save_logistics:
            torch.save({
                "start_idx": i + 1,
                "total": total,
                "correct": correct,
                "entries": entries,
            }, f"{output_dir}/logistics.pt")
        print(f"Current state: correct={correct}, total={total}, accuracy={correct / total:.4f}")
    if args.verbose:
        # print only the correctly answered entries
        for entry_idx, entry in enumerate(entries):
            if not entry['is_correct']:
                continue
            print(f"====================== Entry {entry_idx} ======================")
            print(f">>> Question:\n{entry['question']}")
            print(f">>> Response:\n{entry['response']}")
            print(f">>> Answer:\n{entry['answer']}")
            print(f">>> Data Idx: {entry['data_idx']}")
            print(f">>> Best Reward: {entry['best_reward']}, Best Reward Step: {entry['best_reward_step']}")
    accuracy = correct / total if total > 0 else 0.0
    print(f">>> Final State: correct={correct}, total={total}, accuracy={accuracy:.4f}")
    print(f">>> Data Idx with Correct Answer: {[entry['data_idx'] for entry in entries if entry['is_correct']]}")

if __name__ == "__main__":
    args = parse_args()
    # echo the full configuration before running
    for arg in vars(args):
        print(f"-- {arg}: {getattr(args, arg)}")
    main(args)