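"""Rejection (best-of-n) sampling: generate candidate responses from an SFT model
served over a vLLM completion API, score each candidate with a reward model, and
return the highest-scoring response."""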
import os
import numpy as np
import requests
import torch
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class RejectionSampler:
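    """Best-of-n sampler: requests candidate completions from a vLLM endpoint and
    ranks them with a reward model that may carry a LoRA adapter."""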
def __init__(
self,
sft_model_name,
sft_model_vllm_api_url,
reward_model_path,
reward_model_name,
template_path,
max_responses,
max_length,
sft_batch_size,
rm_batch_size,
):
self.max_responses = max_responses
self.sft_batch_size = sft_batch_size
self.rm_batch_size = rm_batch_size
self.sft_model_name = sft_model_name
self.max_length = max_length
self.sft_model_vllm_api_url = sft_model_vllm_api_url # Store vLLM API URL
self.reward_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
        if self.tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; fall back to EOS so batched padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
self.reward_model = self.load_reward_model(reward_model_path)
        # Split the template path into the loader directory and the template file name.
        template_dir, template_file = os.path.split(template_path)
        env = Environment(loader=FileSystemLoader(template_dir))
        self.template = env.get_template(template_file)
    def load_reward_model(self, reward_model_path):
        """Load the reward model, attaching a LoRA adapter if one is present."""
        print(f"Loading reward model: {reward_model_path}")
        model_kwargs = {
            "torch_dtype": torch.float32,  # full precision for numerically stable reward scores
            "device_map": "auto",
            "num_labels": 1,  # single-logit regression head: the reward is a scalar
        }
base_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path,
**model_kwargs
)
base_model.config.pad_token_id = self.tokenizer.pad_token_id
# Check and load the adapter if it exists
adapter_path = os.path.join(reward_model_path, 'adapter_model')
if os.path.exists(adapter_path + '.safetensors') or os.path.exists(adapter_path + '.bin'):
print(f"Loading reward adapter from: {reward_model_path}")
reward_model = PeftModelForSequenceClassification.from_pretrained(base_model, reward_model_path)
else:
print(f"No adapter found at {adapter_path}, using base model for reward")
reward_model = base_model
reward_model.eval() # Set to evaluation mode
return reward_model.to(self.reward_device)
    def format_prompt(self, messages, add_generation_prompt=True):
        """Render the chat template; strip trailing whitespace when formatting a
        completed conversation for reward scoring."""
        rendered = self.template.render(
            messages=messages,
            add_generation_prompt=add_generation_prompt,
        )
        return rendered if add_generation_prompt else rendered.strip()
def inference(self, messages, temperature, top_p, max_new_tokens):
"""Generate responses using vLLM API and select the best one based on reward model scores."""
prompt = self.format_prompt(messages, add_generation_prompt=True)
total_responses = []
total_responses_generated = 0
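        # Request candidates in batches of at most sft_batch_size until max_responses are collected.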
while total_responses_generated < self.max_responses:
current_batch_size = min(self.sft_batch_size, self.max_responses - total_responses_generated)
# Generate responses using vLLM API
payload = {
"prompt": prompt,
"model": self.sft_model_name,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_new_tokens,
"n": current_batch_size, # Number of completions to generate
"stop": [self.tokenizer.eos_token] if self.tokenizer.eos_token else None
}
            try:
                api_response = requests.post(self.sft_model_vllm_api_url, json=payload)
                api_response.raise_for_status()
                choices = api_response.json().get("choices", [])
                # Optional post-processing (e.g., stripping markdown code fences) could go here.
                new_texts = [c["text"] for c in choices if "text" in c]
                for text in new_texts:
                    print(text)
                total_responses.extend(new_texts)
                # Count what the API actually returned, not what was requested,
                # so a short batch cannot inflate the counter.
                total_responses_generated += len(new_texts)
                if not new_texts:
                    break  # empty batch; stop rather than loop forever
            except Exception as e:
                print(f"Error calling vLLM API: {e}")
                break
        if not total_responses:
            return "Failed to generate responses from vLLM API."
        # Append each candidate as an assistant turn so the reward model scores the full conversation.
        messages_list = [
            messages + [{"role": "assistant", "content": response}]
            for response in total_responses
        ]
        rewards = self.inference_rm(messages_list)
        print(rewards)
        # Best-of-n selection: keep the candidate with the highest reward.
        top_index = int(np.argmax(rewards))
        return total_responses[top_index]
def inference_rm(self, messages_list):
"""Score responses using the reward model."""
rewards = []
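        # Score candidates in chunks of rm_batch_size to bound peak GPU memory.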
for i in range(0, len(messages_list), self.rm_batch_size):
batch_messages = messages_list[i:i + self.rm_batch_size]
prompts = [
self.format_prompt(msgs, add_generation_prompt=False) for msgs in batch_messages
]
inputs = self.tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length
).to(self.reward_device)
            # Optional sanity check: the rendered conversation should end with the EOS
            # token (no trailing newline), so the reward head scores a complete sequence.
            # assert inputs["input_ids"][0][-1] == self.tokenizer.eos_token_id
            with torch.no_grad():
                outputs = self.reward_model(**inputs, return_dict=True)
            # squeeze() collapses a single-example batch to a 0-d array; re-wrap it so
            # extend() always receives an iterable of scalars.
            batch_rewards = outputs.logits.squeeze().detach().cpu().numpy()
            if batch_rewards.ndim == 0:
                batch_rewards = [batch_rewards.item()]
            rewards.extend(batch_rewards)
return rewards
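

if __name__ == "__main__":
    # Minimal usage sketch. Every value below is a placeholder assumption, not
    # something shipped with this file: swap in your served model name, endpoint,
    # reward checkpoint, and chat template.
    sampler = RejectionSampler(
        sft_model_name="my-sft-model",  # hypothetical model name registered with vLLM
        sft_model_vllm_api_url="http://localhost:8000/v1/completions",  # assumed vLLM endpoint
        reward_model_path="./reward_model",  # hypothetical local reward checkpoint
        reward_model_name="./reward_model",  # tokenizer source (often the same path)
        template_path="./templates/chat.jinja",  # hypothetical Jinja chat template
        max_responses=8,
        max_length=2048,
        sft_batch_size=4,
        rm_batch_size=4,
    )
    messages = [{"role": "user", "content": "Explain rejection sampling in one sentence."}]
    best = sampler.inference(messages, temperature=0.8, top_p=0.95, max_new_tokens=256)
    print(best)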