import argparse import numpy as np import os import random import torch from data import get_dataset from extract_judge_answer import extract_answer, extract_true_answer, judge_answer from transformers import AutoModelForCausalLM, AutoTokenizer from ltpo import generate from reward import RewardModel from tqdm import tqdm huggingface_token = os.environ['HUGGING_FACE_TOKEN'] def parse_args(): parser = argparse.ArgumentParser(description="Evaluate the model") parser.add_argument("--dataset", type=str, default="openai/gsm8k", help="Dataset to evaluate") parser.add_argument("--model_name_or_path", type=str, help="Path to the model") parser.add_argument("--output_dir", type=str, help="Path to the output directory") parser.add_argument("--start_data_idx", type=int, default=0, help="Start index of the data to evaluate") parser.add_argument("--end_data_idx", type=int, default=1319, help="End index of the data to evaluate") parser.add_argument("--max_new_tokens", type=int, default=4096, help="Number of generated tokens") parser.add_argument("--device", type=str, default="cuda") # prompt parser.add_argument("--solver_prompt_idx", type=int, default=0, help="Index of the solver prompt") # seed parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization") # optimization args parser.add_argument('--num_thought_tokens', type=int, default=10) parser.add_argument('--sigma', type=float, default=0.1) parser.add_argument('--sigma_decay', type=float, default=0.99) parser.add_argument("--lr", type=float, default=0.03, help="Learning rate") parser.add_argument("--max_num_steps", type=int, default=10, help="Number of optimization iterations") #Test parser.add_argument("--optimize_prefix_tokens", type=int, default=-1, help="Only optimize the first N latent thought tokens; -1 means optimize all tokens") parser.add_argument("--reward_mode", type=str, default="all_tokens", help="Reward aggregation mode, e.g. all_tokens / first_token (for future use)") # reward model parser.add_argument("--reward_threshold", type=float, default=-1, help="Threshold for reward to stop optimization") parser.add_argument("--top_k", type=int, default=10, help="Use top-k most probable tokens to calculate token-level confidence") parser.add_argument("--disable_conf_reward", action="store_true", help="If set, disable using confidence reward") parser.add_argument("--disable_best_reward", action="store_true", help="If set, disable using best reward step as output") # misc parser.add_argument("--resume", action="store_true", help="Resume training from the last checkpoint") parser.add_argument("--ckpt_suffix", type=str, default="") parser.add_argument("--use_auto_grad", action="store_true", help="Use PyTorch's auto-grad") parser.add_argument("--eval_baseline", action="store_true", help="Evaluate baseline") parser.add_argument("--verbose", type=int, default=1, help="Print detailed information") parser.add_argument("--disable_save_logistics", action="store_true", help="Disable saving the logistics.pt") parser.add_argument("--delta_clip_norm", type=float, default=-1.0, help="Per-token L2 norm clip for each hidden-state update; -1 disables clipping") return parser.parse_args() def set_seed(seed): ''' Set random seed for reproducibility Args: seed: random seed ''' torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) random.seed(seed) # evaluate function def main(args): ''' Evaluate model Args: dataset: dataset to evaluate sample_num: number of samples to evaluate Returns: original_accuracy: original generation accuracy optimized_accuracy: optimized generation accuracy ''' if args.seed: set_seed(args.seed) # set device if args.device is None: device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device # load model and tokenizer model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, token=huggingface_token, ) model.eval() tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, token=huggingface_token, ) # load reward model reward_model = RewardModel( model=model, tokenizer=tokenizer, num_thought_tokens=args.num_thought_tokens, ) # load dataset dataset = get_dataset( args.dataset, tokenizer=tokenizer, prompt_idx=args.solver_prompt_idx, ) if args.verbose: print(f"Example: {dataset[0]}") total = 0 correct = 0 entries = [] model_name = args.model_name_or_path.split("/")[-1] data_name = args.dataset.split("/")[-1] conf_suffix = "" if args.disable_conf_reward else "-conf" if args.eval_baseline: model.eval() output_suffix = "-" + args.ckpt_suffix if args.ckpt_suffix else "" output_dir = f"{args.output_dir}/{model_name}-{data_name}-max_tokens{args.max_new_tokens}-prompt{args.solver_prompt_idx}" + output_suffix else: prefix_tag = "all" if args.optimize_prefix_tokens == -1 else f"prefix{args.optimize_prefix_tokens}" reward_tag = args.reward_mode output_dir = ( f"{args.output_dir}/{model_name}-{data_name}" f"-tokens{args.num_thought_tokens}" f"-lr{args.lr}" f"-sigma{args.sigma}" f"-sigdecay{args.sigma_decay}" f"-steps{args.max_num_steps}" f"-{prefix_tag}" f"-reward{reward_tag}" + conf_suffix ) start_data_idx = max(0, args.start_data_idx) end_data_idx = min(args.end_data_idx, len(dataset)) if args.resume and not args.disable_save_logistics: print(f"Resume from {output_dir}") # load logistics logistics = torch.load(f"{output_dir}/logistics.pt") start_data_idx = logistics["start_idx"] correct = logistics["correct"] total = logistics["total"] entries = logistics["entries"] print(f"Start to evaluate {args.dataset} from {start_data_idx} to {end_data_idx}...") data_idx_list = range(start_data_idx, end_data_idx) for i in tqdm(data_idx_list): example = dataset[i] question = example['question'] if not os.path.exists(output_dir): os.makedirs(output_dir) true_answer = extract_true_answer(example["answer"], name=args.dataset) if args.verbose: print(f"Index {i}, Question: {question}") print(f"Index {i}, True answer: {true_answer}") if true_answer is None: continue if args.eval_baseline: init_reward, best_reward, best_reward_step = None, None, None inputs = tokenizer.apply_chat_template( example["prompt"], add_generation_prompt=True, return_dict=True, return_tensors="pt", ).to(device) outputs = model.generate( **inputs, max_new_tokens=args.max_new_tokens, do_sample=False, num_beams=1, ) output = tokenizer.decode(outputs[0], skip_special_tokens=True) else: output, init_reward, best_reward, best_reward_step = generate( tokenizer=tokenizer, model=model, reward_model=reward_model, question=question, num_thought_tokens=args.num_thought_tokens, max_rl_steps=args.max_num_steps, max_new_tokens=args.max_new_tokens, reward_threshold=args.reward_threshold, lr=args.lr, sigma=args.sigma, sigma_decay=args.sigma_decay, use_auto_grad=args.use_auto_grad, disable_conf_reward=args.disable_conf_reward, disable_best_reward=args.disable_best_reward, data_name=args.dataset, model_name=args.model_name_or_path, verbose=args.verbose, top_k=args.top_k, optimize_prefix_tokens=args.optimize_prefix_tokens, reward_mode=args.reward_mode, delta_clip_norm=args.delta_clip_norm, ) # extract answer from model response answer = extract_answer( output, data_name=args.dataset, prompt_idx=args.solver_prompt_idx, model_name=args.model_name_or_path, ) if args.verbose: if args.verbose > 1: print(f"Index {i}, LLM response:\n{output}") print(f"Index {i}, LLM answer: {answer}") print(f"Index {i}, True answer: {true_answer}") print(f"Index {i}, Best reward: {best_reward}, Best reward step: {best_reward_step}") # judge answer is_correct = False if answer is not None: is_correct = judge_answer(output, true_answer, data_name=args.dataset, prompt_idx=args.solver_prompt_idx) correct += is_correct if not args.disable_save_logistics: entries.append(dict( data_idx=i, question=question, response=output, answer=answer, is_correct=is_correct, init_reward=init_reward, best_reward=best_reward, best_reward_step=best_reward_step, )) total += 1 # save logistics # save original correct, optimized correct, total, update count if not args.disable_save_logistics: torch.save({ "start_idx": i+1, "total": total, "correct": correct, "entries": entries, }, f"{output_dir}/logistics.pt") print(f"Current state: correct={correct}, total={total}, accuracy={correct / total:.4f}") if args.verbose: for i, entry in enumerate(entries): if not entry['is_correct']: continue print(f"====================== Entry {i} ======================") print(f">>> Question:\n{entry['question']}") print(f">>> Response:\n{entry['response']}") print(f">>> Answer:\n{entry['answer']}") print(f">>> Data Idx: {entry['data_idx']}") print(f">>> Best Reward: {entry['best_reward']}, Best Reward Step: {entry['best_reward_step']}") print(f">>> Final State: correct={correct}, total={total}, accuracy={correct / total:.4f}") print(f">>> Data Idx with Correct Answer: {[entry['data_idx'] for entry in entries if entry['is_correct']]}") if __name__ == "__main__": args = parse_args() for arg in vars(args): print(f"-- {arg}: {getattr(args, arg)}") main(args)