# File size: 3,439 Bytes
# 0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# NOTE(review): removed a stray line-number gutter (bare integers 1-100) left
# behind by a copy/paste or extraction step; it was not part of the program.
import argparse
import json
import os

import torch
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def parse_args():
    """Build and parse the command-line arguments for the reward-model test script.

    All four flags are required string paths; see each entry's help text.
    """
    parser = argparse.ArgumentParser(
        description="Test a reward model with a template and example data"
    )
    # Register the four required path arguments from a small table to keep
    # flag/help pairs visually aligned.
    for flag, help_text in (
        ("--model_path", "Path to base model or HF model name"),
        ("--adapter_path", "Path to saved checkpoint directory"),
        ("--template_path", "Path to Jinja template file"),
        ("--example_path", "Path to example data JSON"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()

def load_model_and_tokenizer(args):
    """Load the tokenizer and base classification model, attaching a PEFT adapter when present.

    Args:
        args: Parsed CLI namespace providing ``model_path`` and ``adapter_path``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode. The model is
        the PEFT-wrapped model when adapter weights exist under ``adapter_path``,
        otherwise the bare base model.
    """
    print(f"Loading base model: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    print("Using full precision model")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path,
        device_map="auto",
        num_labels=1,  # single scalar head: regression-style reward score
        pad_token_id=tokenizer.pad_token_id # very important to add this
    )

    # Adapter weights may be serialized as either .safetensors or .bin;
    # presence of either file means a trained adapter checkpoint exists.
    adapter_path = os.path.join(args.adapter_path, 'adapter_model')
    has_adapter = any(
        os.path.exists(adapter_path + suffix) for suffix in ('.safetensors', '.bin')
    )
    if has_adapter:
        print(f"Loading adapter from: {args.adapter_path}")
        model = PeftModelForSequenceClassification.from_pretrained(base_model, args.adapter_path)
    else:
        print(f"No adapter found at {adapter_path}, using base model")
        model = base_model

    model.eval()
    return model, tokenizer

def load_template(template_path):
    """Load a Jinja template from an arbitrary filesystem path.

    The environment is rooted at the template's parent directory (the current
    directory when the path has no directory part), and a ``tojson`` filter
    backed by ``json.dumps`` is registered for use inside templates.
    """
    directory, filename = os.path.split(template_path)
    env = Environment(loader=FileSystemLoader(directory or "."))
    env.filters['tojson'] = lambda value: json.dumps(value)
    return env.get_template(filename)

def evaluate_prompt(model, tokenizer, prompt):
    """Score one rendered prompt with the reward model.

    Args:
        model: Sequence-classification model with a single-logit head.
        tokenizer: Tokenizer callable producing PyTorch tensors.
        prompt: Fully rendered conversation text to score.

    Returns:
        The scalar reward as a Python float.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inputs must live on the same device as the model's parameters.
    target_device = next(model.parameters()).device
    batch = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        result = model(**batch)

    # The single regression logit is the reward score.
    return result.logits.squeeze().cpu().item()

def main():
    """Render each example through the template and print model vs. ground-truth rewards."""
    args = parse_args()

    model, tokenizer = load_model_and_tokenizer(args)

    # Examples are a JSON list of records with "input", "output", and an
    # optional "value" (ground-truth reward).
    with open(args.example_path, 'r') as f:
        examples = json.load(f)

    template = load_template(args.template_path)
    total = len(examples)
    for index, example in enumerate(examples, start=1):
        print(f"\n===== EXAMPLE {index}/{total} =====")

        # Build a two-turn chat (user prompt + assistant reply) and render it
        # through the chat template; strip trailing template whitespace.
        chat = [
            {"role": "user", "content": example['input']},
            {"role": "assistant", "content": example['output']},
        ]
        rendered_prompt = template.render(
            messages=chat,
            add_generation_prompt=False
        ).strip()

        reward = evaluate_prompt(model, tokenizer, rendered_prompt)
        gth_reward = example.get('value')

        print(f"REWARD SCORE: {reward:.6f}")
        if gth_reward is None:
            print("GTH REWARD: Not available")
        else:
            print(f"GTH REWARD: {gth_reward:.6f}")

if __name__ == "__main__":
    main()