File size: 5,467 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import re

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_tokenizer
from openrlhf.utils.logging_utils import init_logger

# Module-level logger named after this module (configured by OpenRLHF's logging utils).
logger = init_logger(__name__)

class RewardModelProxy:
    """Wrapper around an OpenRLHF reward model that scores text queries.

    Used by the FastAPI ``/get_reward`` endpoint below: ``get_reward`` takes a
    list of prompt+response strings and returns one scalar score per string.
    """

    def __init__(self, args):
        # Load the sequence-regression ("reward") model.
        # BUG FIX: the original read ``args.flash_attn``, which the argument
        # parser in this file never defines (it defines --attn_implementation),
        # so the server crashed with AttributeError on startup. Derive the
        # boolean from --attn_implementation instead.
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            use_flash_attention_2=args.attn_implementation == "flash_attention_2",
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()

        # Left padding keeps the final (scored) token position aligned across a batch.
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        """Score each query string with the reward model.

        BUG FIX: the original body iterated an undefined DataFrame ``df`` and
        appended to an undefined ``all_scores`` list (both NameErrors, and a
        return shape the HTTP endpoint does not expect). This restores the
        batched scoring path that returns one float per query.

        :param queries: list of prompt+response strings to score
        :return: list of float scores, same order and length as ``queries``
        """
        if not queries:
            return []
        # No --batch_size given -> score everything in a single batch.
        batch_size = len(queries) if self.batch_size is None else self.batch_size

        scores = []
        with torch.no_grad():
            for i in range(0, len(queries), batch_size):
                inputs = self.tokenize_fn(queries[i : i + batch_size], device=self.reward_model.device)
                r = self.reward_model(inputs["input_ids"], inputs["attention_mask"])
                scores.extend(r.tolist())
        return scores

    def compare_queries(self, chosen_query, reject_query):
        """Score a chosen/rejected pair (each in its own forward pass).

        :param chosen_query: prompt concatenated with the 'chosen' answer
        :param reject_query: prompt concatenated with the 'reject' answer
        :return: tuple ``(chosen_score, reject_score)``
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)

            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]

        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        """Tokenize ``texts`` (padded/truncated to ``max_length``) and move tensors to ``device``."""
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reward Model
    parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normalization")
    parser.add_argument("--value_head_prefix", type=str, default="value_head")
    # BUG FIX: default was the string "2048"; use an int literal so the default
    # matches the declared type without relying on argparse coercing string defaults.
    parser.add_argument("--max_len", type=int, default=2048)

    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")

    # Performance
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)",
    )
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=None)

    args = parser.parse_args()

    # server: one proxy instance shared across requests.
    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        # NOTE(review): expects a JSON body like {"query": [str, ...]} — confirm
        # against the client; a missing "query" key yields None and will fail
        # inside get_reward.
        data = await request.json()
        queries = data.get("query")
        rewards = reward_model.get_reward(queries)
        result = {"rewards": rewards, "scores": rewards, "extra_logs": {"dummy_scores": rewards}}
        logger.info(f"Sent JSON: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")