import argparse
import re

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_tokenizer
from openrlhf.utils.logging_utils import init_logger

logger = init_logger(__name__)


class RewardModelProxy:
    """Serves scalar rewards from a sequence-regression ("reward") model.

    Wraps model loading, tokenization, and batched scoring so the HTTP layer
    below only has to pass lists of query strings.
    """

    def __init__(self, args):
        # Load the reward model; device_map="auto" lets HF place/shard it
        # across the available devices.
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            # The parser exposes --attn_implementation (not --flash_attn);
            # derive the boolean the loader expects from it.
            use_flash_attention_2=args.attn_implementation == "flash_attention_2",
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()

        # Left padding: reward is read from the last token position.
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        """Score each query string with the reward model.

        :param queries: list of prompt+response strings
        :return: list of float rewards, one per query (order preserved)
        """
        # A batch_size of None means "score everything in one forward pass".
        batch_size = len(queries) if self.batch_size is None else self.batch_size

        # logger.info(f"queries[0]: {queries[0]}")
        scores = []
        with torch.no_grad():
            for i in range(0, len(queries), batch_size):
                inputs = self.tokenize_fn(
                    queries[i : min(len(queries), i + batch_size)], device=self.reward_model.device
                )
                r = self.reward_model(inputs["input_ids"], inputs["attention_mask"])
                scores.extend(r.tolist())
        return scores

    def compare_queries(self, chosen_query, reject_query):
        """Compare the reward scores for chosen_query and reject_query.

        :param chosen_query: the query with the 'chosen' answer
        :param reject_query: the query with the 'reject' answer
        :return: tuple (chosen_score, reject_score)
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)
            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]
        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        """Tokenize texts (padded/truncated to self.max_length) and move tensors to device."""
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reward Model
    parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation")
    parser.add_argument("--value_head_prefix", type=str, default="value_head")
    # NOTE: default was the string "2048"; argparse coerces string defaults
    # through type=int, but an int default is the correct form.
    parser.add_argument("--max_len", type=int, default=2048)

    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")

    # Performance
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)",
    )
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=None)

    args = parser.parse_args()

    # server
    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        data = await request.json()
        queries = data.get("query")
        rewards = reward_model.get_reward(queries)
        # "rewards"/"scores" mirror each other for client compatibility.
        result = {"rewards": rewards, "scores": rewards, "extra_logs": {"dummy_scores": rewards}}
        logger.info(f"Sent JSON: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")