import argparse
import pandas as pd
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_tokenizer
from openrlhf.utils.logging_utils import init_logger
logger = init_logger(__name__)
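
# Serves a reward model behind an HTTP endpoint for remote scoring. This variant
# evaluates pairwise preference accuracy: each request carries chosen/rejected
# response pairs, and the server reports how often the model prefers the chosen one.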
class RewardModelProxy:
    def __init__(self, args):
        # Modify the reward_model to your remote model
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            attn_implementation=args.attn_implementation,
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()

        # Left padding keeps each sequence's final token aligned at the end of the batch.
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        # if self.batch_size is None:
        #     batch_size = len(queries)
        # else:
        #     batch_size = self.batch_size

        # logger.info(f"queries[0]: {queries[0]}")

        # Assumed request format: each query is a dict with "chosen_prompt",
        # "chosen", and "reject" keys; build a DataFrame to iterate the pairs.
        df = pd.DataFrame(queries)

        correct_count = 0  # For accuracy calculation
        total_count = len(df)
        all_scores = []
        for _, row in df.iterrows():
            chosen_query = row["chosen_prompt"] + " " + row["chosen"]
            reject_query = row["chosen_prompt"] + " " + row["reject"]

            # Get reward scores for the chosen and rejected responses
            scores = self.compare_queries(chosen_query, reject_query)
            all_scores.append(scores)

            # The pair counts as correct when the chosen response scores higher
            chosen_score, reject_score = scores
            if chosen_score > reject_score:
                correct_count += 1

        accuracy = correct_count / total_count if total_count > 0 else 0
        print(f"Current Accuracy: {accuracy * 100:.2f}%")
        return all_scores, accuracy

        # Original batched scoring path, kept for reference:
        # with torch.no_grad():
        #     for i in range(0, len(queries), batch_size):
        #         inputs = self.tokenize_fn(
        #             queries[i : min(len(queries), i + batch_size)], device=self.reward_model.device
        #         )
        #         r = self.reward_model(inputs["input_ids"], inputs["attention_mask"])
        #         r = r.tolist()
        #         scores.extend(r)
        # return scores

    def compare_queries(self, chosen_query, reject_query):
        """
        Compare the reward scores for chosen_query and reject_query.

        :param chosen_query: The query with the 'chosen' answer
        :param reject_query: The query with the 'reject' answer
        :return: Tuple (chosen_score, reject_score)
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)
            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]
        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reward Model
    parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable reward normalization")
    parser.add_argument("--value_head_prefix", type=str, default="value_head")
    parser.add_argument("--max_len", type=int, default=2048)
    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")

    # Performance
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)",
    )
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=None)

    args = parser.parse_args()
    # server
    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        data = await request.json()
        queries = data.get("query")
        # get_reward returns a list of (chosen_score, reject_score) pairs plus accuracy
        scores, accuracy = reward_model.get_reward(queries)
        result = {"rewards": scores, "scores": scores, "extra_logs": {"dummy_scores": scores, "accuracy": accuracy}}
        logger.info(f"Sent JSON: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
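
# Example client call (a minimal sketch; the prompt/response strings below are
# hypothetical, while the payload keys match the format consumed by get_reward):
#
#   import requests
#
#   payload = {
#       "query": [
#           {"chosen_prompt": "Human: What is 2+2? Assistant:", "chosen": "4", "reject": "5"},
#       ]
#   }
#   response = requests.post("http://localhost:5000/get_reward", json=payload)
#   print(response.json())  # {"rewards": [[...]], "scores": [[...]], "extra_logs": {...}}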