dpe1/jules-tinyreasoner / src/compare_models.py
dpe1's picture
download
raw
2.69 kB
import torch
import os
import argparse
from src.tokenizer import CharTokenizer
from src.model import TinyReasonerModel
from src.sampler import Sampler
from src.prompts import get_random_prompt
from src.rewards import get_total_reward, reward_grounding
# Number of generations per model echoed to stdout for manual inspection.
_NUM_PREVIEW_SAMPLES = 3
# Maximum number of characters of each previewed generation that is printed.
_PREVIEW_CHARS = 150


def _resolve_model_names(model_paths):
    """Return a ``{display_name: checkpoint_path}`` mapping.

    ``None`` selects the default SFT and RL checkpoints. Otherwise each path is
    named by its basename, except when several paths share a basename (e.g.
    ``run1/model.pt`` and ``run2/model.pt``): those use their full path so no
    entry silently overwrites another.
    """
    if model_paths is None:
        return {
            "SFT": "models/sft_model.pt",
            "RL": "models/rl_model.pt",
        }
    basenames = [os.path.basename(p) for p in model_paths]
    return {
        (base if basenames.count(base) == 1 else path): path
        for base, path in zip(basenames, model_paths)
    }


def _evaluate_model(name, path, tokenizer, device, num_samples, level, max_len, temperature):
    """Load the checkpoint at *path* and score it on ``num_samples`` random prompts.

    Prints a header and the first few generations, then returns a dict with
    ``avg_reward``, ``avg_grounding`` and ``grounding_rate`` for this model.
    """
    model = TinyReasonerModel(tokenizer.vocab_size).to(device)
    # The loaded object is passed straight to load_state_dict, so a tensors-only
    # load is sufficient; weights_only=True refuses arbitrary pickled objects.
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model.eval()
    sampler = Sampler(model, tokenizer, device=device)

    total_reward = 0
    total_grounding = 0
    grounded_count = 0
    print(f"\n--- Evaluating {name} Model ({path}) ---")
    for i in range(num_samples):
        prompt_text, ref_answer, task_type = get_random_prompt(level=level)
        # Generation needs no gradients; no_grad skips autograd bookkeeping
        # (harmless if Sampler already disables it internally).
        with torch.no_grad():
            output = sampler.sample(
                f"[BOS]{prompt_text}", max_len=max_len, temperature=temperature
            )
        reward = get_total_reward(prompt_text, output, ref_answer, task_type)
        grounding = reward_grounding(prompt_text, output)
        total_reward += reward
        total_grounding += grounding
        # A sample counts as grounded when its grounding reward is positive.
        if grounding > 0:
            grounded_count += 1
        if i < _NUM_PREVIEW_SAMPLES:
            # Only mark the preview as truncated when it actually was.
            ellipsis = "..." if len(output) > _PREVIEW_CHARS else ""
            print(f"P: {prompt_text}")
            print(f"O: {output[:_PREVIEW_CHARS]}{ellipsis}")
            print(f"R: {reward:.2f}, G: {grounding:.2f}")

    return {
        "avg_reward": total_reward / num_samples,
        "avg_grounding": total_grounding / num_samples,
        "grounding_rate": grounded_count / num_samples,
    }


def compare_models(model_paths=None, num_samples=20, level=0, max_len=256, temperature=0.7):
    """Evaluate TinyReasoner checkpoints on random prompts and report reward statistics.

    Args:
        model_paths: Checkpoint paths to compare. ``None`` evaluates
            ``models/sft_model.pt`` (as "SFT") and ``models/rl_model.pt``
            (as "RL"). Each model is labelled by its file basename, or by its
            full path when several given paths share a basename.
        num_samples: Number of prompts sampled per model; must be >= 1.
        level: Forwarded as ``get_random_prompt(level=...)``.
        max_len: Forwarded to ``Sampler.sample``.
        temperature: Sampling temperature forwarded to ``Sampler.sample``.

    Returns:
        Dict mapping each evaluated model's name to a dict with keys
        ``avg_reward``, ``avg_grounding`` and ``grounding_rate`` (fraction of
        samples whose grounding reward is positive). Checkpoints whose path
        does not exist are reported with a warning and omitted.

    Raises:
        ValueError: If ``num_samples`` is less than 1 (the averages divide by it).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = CharTokenizer()

    results = {}
    for name, path in _resolve_model_names(model_paths).items():
        if not os.path.exists(path):
            print(f"Warning: {path} not found.")
            continue
        results[name] = _evaluate_model(
            name, path, tokenizer, device, num_samples, level, max_len, temperature
        )

    for name, stats in results.items():
        print(f"\nResults for {name}:")
        for key, value in stats.items():
            print(f" {key}: {value:.4f}")
    return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compare TinyReasoner checkpoints on randomly generated prompts."
    )
    parser.add_argument(
        "--models",
        nargs="+",
        help="Paths to model checkpoints to compare "
        "(default: models/sft_model.pt and models/rl_model.pt)",
    )
    parser.add_argument(
        "--num_samples", type=int, default=20, help="Prompts sampled per model (must be >= 1)"
    )
    parser.add_argument(
        "--level", type=int, default=0, help="Prompt level forwarded to get_random_prompt"
    )
    args = parser.parse_args()
    # Reject non-positive sample counts as a usage error up front, instead of
    # loading every checkpoint and then dividing by zero when averaging.
    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")
    compare_models(args.models, args.num_samples, args.level)

Xet Storage Details

Size:
2.69 kB
·
Xet hash:
7dddd76ab9e6fe7ff106a8c43eb97bb6024dc90e192a0ca3cd7869da4331893a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.