# granite-abstract / eval_simple.py
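"""Multi-GPU evaluation harness for an AbstractModel checkpoint.

Randomly samples items from MMLU, GSM8K, and DROP, runs deterministic
generation on two GPUs in parallel, scores each prediction, and reports
per-benchmark accuracy and average thinking steps.
"""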
import json
import os
import random
import re
import string
import time

import torch
import torch.multiprocessing as mp
from tqdm import tqdm
from abstract_model import AbstractModel
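
# Configuration. The two model paths below are placeholders; point them at the
# trained continuous-head checkpoint and the base SFT model before running.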
RL_MODEL_PATH = "pathtocontinuoushead"
FALLBACK_SFT_PATH = "pathtobasemodel"
DATASET_FILES = [
"../bench/mmlu.jsonl",
"../bench/gsm8k.jsonl",
"../bench/drop.jsonl"
]
SAMPLES_PER_BENCHMARK = 1024
MAX_THINKING_STEPS = 256
MAX_TOTAL_LENGTH = 1536
LOG_FILE = "eval_results_random.jsonl"
def normalize_text(s):
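    """SQuAD-style normalization: lowercase, drop punctuation and articles, collapse whitespace."""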
if s is None: return ""
def remove_articles(text): return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text): return ' '.join(text.split())
def remove_punc(text): return ''.join(ch for ch in text if ch not in set(string.punctuation))
return white_space_fix(remove_articles(remove_punc(str(s).lower())))
def extract_answer_content(text):
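    """Return the text inside the first <ANSWER>...</ANSWER> block, or None if absent."""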
match = re.search(r"<ANSWER>(.*?)</ANSWER>", text, re.DOTALL)
if match: return match.group(1).strip()
return None
def load_and_sample_data(files, samples_per_file):
"""
Loads full datasets and randomly samples N items from each.
"""
final_data = []
for filename in files:
if not os.path.exists(filename):
print(f"Warning: File {filename} not found. Skipping.")
continue
# Detect benchmark type
fname_lower = filename.lower()
if "mmlu" in fname_lower: bench_type = "mmlu"
elif "gsm8k" in fname_lower: bench_type = "gsm8k"
elif "drop" in fname_lower: bench_type = "drop"
else: bench_type = "unknown"
print(f"Loading {filename} ({bench_type})...")
file_data = []
with open(filename, 'r', encoding='utf-8') as f:
for line in f:
try:
entry = json.loads(line)
if "benchmark" not in entry:
entry["benchmark"] = bench_type
file_data.append(entry)
                except json.JSONDecodeError:
                    # Skip malformed JSONL lines rather than aborting the file.
                    continue
total_lines = len(file_data)
if total_lines > samples_per_file:
random.shuffle(file_data)
selected_data = file_data[:samples_per_file]
print(f" -> Randomly sampled {samples_per_file} from {total_lines} samples.")
else:
selected_data = file_data
print(f" -> Took all {total_lines} samples (less than requested limit).")
final_data.extend(selected_data)
return final_data
def score_sample(pred, truth, benchmark):
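    """Score one prediction against ground truth using benchmark-specific rules:
    MMLU by the first A-D letter in the answer block, GSM8K by checking that the
    normalized final answer (after '####') appears in the prediction, and
    everything else by exact match after normalization.
    """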
if benchmark == 'mmlu':
p = extract_answer_content(pred)
if not p: return False
m = re.search(r'([A-D])', p.upper())
return m.group(1) == truth.strip().upper() if m else False
elif benchmark == 'gsm8k':
p = extract_answer_content(pred)
if not p: return False
t = truth.split("####")[-1].strip() if "####" in truth else truth.strip()
return normalize_text(t) in normalize_text(p)
else:
p = extract_answer_content(pred)
if not p: return False
return normalize_text(p) == normalize_text(truth)
def gpu_worker(gpu_id, head_path, sft_path, dataset_chunk, results_queue):
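    """Worker process: load the model onto one GPU, evaluate its chunk of the dataset, and push scored results onto results_queue."""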
torch.cuda.set_device(gpu_id)
device = f"cuda:{gpu_id}"
if not os.path.exists(os.path.join(head_path, "continuous_head.pt")):
print(f"[GPU {gpu_id}] Critical: continuous_head.pt not found in {head_path}")
return
print(f"[GPU {gpu_id}] Loading Model...")
try:
model = AbstractModel.load_from_directory(
head_path,
sft_model_path=sft_path,
device=device
)
except Exception as e:
print(f"[GPU {gpu_id}] Error loading model: {e}")
return
results = []
iterator = tqdm(dataset_chunk, desc=f"GPU {gpu_id}", position=gpu_id, leave=True)
for item in iterator:
try:
sys_prompt = "You are a reasoning assistant. Think step by step before answering."
messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": item['question']}]
formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device).squeeze(0)
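            # Deterministic decoding: temperature=0.0 and sample=False force greedy
            # token selection; sigma=0.0 presumably disables noise in the continuous head.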
out = model.forward(
input_ids,
max_length=MAX_TOTAL_LENGTH,
temperature=0.0,
sample=False,
no_grad=True,
sigma=0.0,
max_thinking_steps=MAX_THINKING_STEPS
)
full_text = ""
for token_id in out['generated_tokens'].tolist():
full_text += model.tokenizer.decode([token_id])
is_correct = score_sample(full_text, item['answer'], item['benchmark'])
results.append({
"benchmark": item['benchmark'],
"correct": is_correct,
"think_steps": out['mode_sequence'].count('A'),
"prediction": full_text
})
except Exception as e:
print(f"[GPU {gpu_id}] Error: {e}")
continue
results_queue.put(results)
def run_evaluation():
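    """Sample the benchmarks, fan evaluation out to two GPU workers, then log and summarize results."""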
all_data = load_and_sample_data(DATASET_FILES, SAMPLES_PER_BENCHMARK)
if not all_data:
print("No data loaded. Exiting.")
return
print(f"Total Evaluation Set: {len(all_data)} samples.")
mid = len(all_data) // 2
queue = mp.Queue()
    p1 = mp.Process(target=gpu_worker, args=(0, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[:mid], queue))
    p2 = mp.Process(target=gpu_worker, args=(1, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[mid:], queue))
start_time = time.time()
p1.start(); p2.start()
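    # Drain the queue before joining; joining first can deadlock if a child is
    # still flushing results through the queue.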
final_results = []
for _ in range(2): final_results.extend(queue.get())
p1.join(); p2.join()
print(f"Saving detailed logs to {LOG_FILE}...")
    with open(LOG_FILE, 'w', encoding='utf-8') as f:
        for r in final_results: f.write(json.dumps(r) + '\n')
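    # Aggregate per-benchmark accuracy and average thinking-step counts.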
metrics = {}
for res in final_results:
b = res['benchmark']
if b not in metrics: metrics[b] = {'correct': [], 'steps': []}
metrics[b]['correct'].append(res['correct'])
metrics[b]['steps'].append(res['think_steps'])
print("\n" + "="*50)
print(f"FINAL SCORES (Random Sample N={SAMPLES_PER_BENCHMARK})")
print("="*50)
for b, d in metrics.items():
acc = sum(d['correct']) / len(d['correct']) * 100
avg_steps = sum(d['steps']) / len(d['steps'])
print(f"{b.upper():<10} | Acc: {acc:.2f}% | Avg Steps: {avg_steps:.1f} | N: {len(d['correct'])}")
print(f"Total time: {time.time() - start_time:.2f}s")
if __name__ == "__main__":
mp.set_start_method('spawn', force=True)
run_evaluation()