import json
import os
import random
import re
import string
import time

import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from abstract_model import AbstractModel

RL_MODEL_PATH = "pathtocontinuoushead"
FALLBACK_SFT_PATH = "pathtobasemodel"
DATASET_FILES = [
    "../bench/mmlu.jsonl",
    "../bench/gsm8k.jsonl",
    "../bench/drop.jsonl",
]
SAMPLES_PER_BENCHMARK = 1024
MAX_THINKING_STEPS = 256
MAX_TOTAL_LENGTH = 1536
LOG_FILE = "eval_results_random.jsonl"


def normalize_text(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    if s is None:
        return ""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))

    return white_space_fix(remove_articles(remove_punc(str(s).lower())))


def extract_answer_content(text):
    # The final answer is expected inside <answer>...</answer> tags; the exact
    # tag names are an assumption (the pattern's literal delimiters were lost,
    # and without them the search always matched the empty string).
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def load_and_sample_data(files, samples_per_file):
    """Loads full datasets and randomly samples N items from each."""
    final_data = []
    for filename in files:
        if not os.path.exists(filename):
            print(f"Warning: File {filename} not found. Skipping.")
            continue

        # Detect benchmark type from the filename.
        fname_lower = filename.lower()
        if "mmlu" in fname_lower:
            bench_type = "mmlu"
        elif "gsm8k" in fname_lower:
            bench_type = "gsm8k"
        elif "drop" in fname_lower:
            bench_type = "drop"
        else:
            bench_type = "unknown"

        print(f"Loading {filename} ({bench_type})...")
        file_data = []
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    entry = json.loads(line)
                    if "benchmark" not in entry:
                        entry["benchmark"] = bench_type
                    file_data.append(entry)
                except json.JSONDecodeError:
                    continue

        total_lines = len(file_data)
        if total_lines > samples_per_file:
            random.shuffle(file_data)
            selected_data = file_data[:samples_per_file]
            print(f" -> Randomly sampled {samples_per_file} from {total_lines} samples.")
        else:
            selected_data = file_data
            print(f" -> Took all {total_lines} samples (less than requested limit).")

        final_data.extend(selected_data)
    return final_data


def score_sample(pred, truth, benchmark):
    if benchmark == 'mmlu':
        p = extract_answer_content(pred)
        if not p:
            return False
        # Word boundaries avoid matching the 'A' in words like "ANSWER".
        m = re.search(r'\b([A-D])\b', p.upper())
        return m.group(1) == truth.strip().upper() if m else False
    elif benchmark == 'gsm8k':
        p = extract_answer_content(pred)
        if not p:
            return False
        # GSM8K ground truth puts the final number after "####".
        t = truth.split("####")[-1].strip() if "####" in truth else truth.strip()
        return normalize_text(t) in normalize_text(p)
    else:
        p = extract_answer_content(pred)
        if not p:
            return False
        return normalize_text(p) == normalize_text(truth)


def gpu_worker(gpu_id, head_path, sft_path, dataset_chunk, results_queue):
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    if not os.path.exists(os.path.join(head_path, "continuous_head.pt")):
        print(f"[GPU {gpu_id}] Critical: continuous_head.pt not found in {head_path}")
        return

    print(f"[GPU {gpu_id}] Loading Model...")
    try:
        model = AbstractModel.load_from_directory(
            head_path,
            sft_model_path=sft_path,
            device=device
        )
    except Exception as e:
        print(f"[GPU {gpu_id}] Error loading model: {e}")
        return

    results = []
    iterator = tqdm(dataset_chunk, desc=f"GPU {gpu_id}", position=gpu_id, leave=True)
    for item in iterator:
        try:
            sys_prompt = "You are a reasoning assistant. Think step by step before answering."
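            # Build the chat-formatted prompt and run one deterministic pass:
            # temperature=0.0 / sample=False force greedy decoding, and
            # sigma=0.0 is assumed to disable the continuous head's sampling
            # noise (sigma's meaning is not defined in this file).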
messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": item['question']}] formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device).squeeze(0) out = model.forward( input_ids, max_length=MAX_TOTAL_LENGTH, temperature=0.0, sample=False, no_grad=True, sigma=0.0, max_thinking_steps=MAX_THINKING_STEPS ) full_text = "" for token_id in out['generated_tokens'].tolist(): full_text += model.tokenizer.decode([token_id]) is_correct = score_sample(full_text, item['answer'], item['benchmark']) results.append({ "benchmark": item['benchmark'], "correct": is_correct, "think_steps": out['mode_sequence'].count('A'), "prediction": full_text }) except Exception as e: print(f"[GPU {gpu_id}] Error: {e}") continue results_queue.put(results) def run_evaluation(): all_data = load_and_sample_data(DATASET_FILES, SAMPLES_PER_BENCHMARK) if not all_data: print("No data loaded. Exiting.") return print(f"Total Evaluation Set: {len(all_data)} samples.") mid = len(all_data) // 2 queue = mp.Queue() p1 = mp.Process(target=gpu, args=(0, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[:mid], queue)) p2 = mp.Process(target=gpu, args=(1, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[mid:], queue)) start_time = time.time() p1.start(); p2.start() final_results = [] for _ in range(2): final_results.extend(queue.get()) p1.join(); p2.join() print(f"Saving detailed logs to {LOG_FILE}...") with open(LOG_FILE, 'w') as f: for r in final_results: f.write(json.dumps(r) + '\n') metrics = {} for res in final_results: b = res['benchmark'] if b not in metrics: metrics[b] = {'correct': [], 'steps': []} metrics[b]['correct'].append(res['correct']) metrics[b]['steps'].append(res['think_steps']) print("\n" + "="*50) print(f"FINAL SCORES (Random Sample N={SAMPLES_PER_BENCHMARK})") print("="*50) for b, d in metrics.items(): acc = sum(d['correct']) / len(d['correct']) * 100 avg_steps = sum(d['steps']) / len(d['steps']) print(f"{b.upper():<10} | Acc: {acc:.2f}% | Avg Steps: {avg_steps:.1f} | N: {len(d['correct'])}") print(f"Total time: {time.time() - start_time:.2f}s") if __name__ == "__main__": mp.set_start_method('spawn', force=True) run_evaluation()