import os

import numpy as np
import requests
import torch
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class RejectionSampler:
    """Best-of-n (rejection) sampling: draw candidate completions from an
    SFT model served over a vLLM HTTP API, score each candidate with a
    reward model, and return the highest-scoring response."""

    def __init__(
        self,
        sft_model_name,
        sft_model_vllm_api_url,
        reward_model_path,
        reward_model_name,
        template_path,
        max_responses,
        max_length,
        sft_batch_size,
        rm_batch_size,
    ):
        self.max_responses = max_responses
        self.sft_batch_size = sft_batch_size
        self.rm_batch_size = rm_batch_size
        self.sft_model_name = sft_model_name
        self.max_length = max_length
        self.sft_model_vllm_api_url = sft_model_vllm_api_url  # Store vLLM API URL

        self.reward_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
        if self.tokenizer.pad_token is None:
            # Some base-model tokenizers ship without a pad token; fall back
            # to EOS so the batched padding in inference_rm does not fail.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.reward_model = self.load_reward_model(reward_model_path)

        template_dir, template_file = os.path.split(template_path)
        env = Environment(loader=FileSystemLoader(template_dir))
        self.template = env.get_template(template_file)

    def load_reward_model(self, reward_model_path):
        """Load reward model with optional QLoRA quantization."""
        print(f"Loading reward model: {reward_model_path}")

        model_kwargs = {
            # Keep full precision: reward scores are compared across responses,
            # so dtype-induced noise matters here.
            "torch_dtype": torch.float32,
            "device_map": "auto",
            "num_labels": 1,  # Single scalar head: the reward is a regression target
        }

        base_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_path,
            **model_kwargs
        )
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Check and load the adapter if it exists
        adapter_path = os.path.join(reward_model_path, 'adapter_model')
        if os.path.exists(adapter_path + '.safetensors') or os.path.exists(adapter_path + '.bin'):
            print(f"Loading reward adapter from: {reward_model_path}")
            reward_model = PeftModelForSequenceClassification.from_pretrained(base_model, reward_model_path)
        else:
            print(f"No adapter found at {adapter_path}, using base model for reward")
            reward_model = base_model

        reward_model.eval()  # Inference only; disables dropout etc.
        # device_map="auto" has already placed the model, so avoid a redundant
        # .to() call, which can fail on accelerate-dispatched models.
        return reward_model

    def format_prompt(self, messages, add_generation_prompt=True):
        """Render messages with the chat template.

        With add_generation_prompt=True the prompt ends with the assistant
        header so the SFT model continues from it; otherwise trailing
        whitespace is stripped so reward-model inputs end cleanly.
        """
        rendered = self.template.render(
            messages=messages,
            add_generation_prompt=add_generation_prompt,
        )
        return rendered if add_generation_prompt else rendered.strip()
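
    # For reference, format_prompt assumes a Hugging Face-style Jinja2 chat
    # template. A minimal sketch of what such a template could look like (an
    # assumption; the real template is whatever file template_path names):
    #
    #   {% for message in messages %}
    #   <|{{ message['role'] }}|>
    #   {{ message['content'] }}
    #   {% endfor %}
    #   {% if add_generation_prompt %}<|assistant|>{% endif %}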

    def inference(self, messages, temperature, top_p, max_new_tokens):
        """Generate responses using vLLM API and select the best one based on reward model scores."""
        prompt = self.format_prompt(messages, add_generation_prompt=True)

        total_responses = []
        total_responses_generated = 0

        while total_responses_generated < self.max_responses:
            current_batch_size = min(self.sft_batch_size, self.max_responses - total_responses_generated)

            # Generate responses using vLLM API
            payload = {
                "prompt": prompt,
                "model": self.sft_model_name,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_new_tokens,
                "n": current_batch_size,  # Number of completions to generate
                "stop": [self.tokenizer.eos_token] if self.tokenizer.eos_token else None
            }
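            # Assumption: sft_model_vllm_api_url targets vLLM's OpenAI-compatible
            # /v1/completions endpoint; its JSON response is shaped like
            #   {"choices": [{"text": "...", "index": 0, ...}, ...]}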

            try:
                response = requests.post(self.sft_model_vllm_api_url, json=payload)
                response.raise_for_status()

                api_responses = response.json()
                new_responses = [
                    completion["text"]
                    for completion in api_responses.get("choices", [])
                    if "text" in completion
                ]
                if not new_responses:
                    print("vLLM API returned no completions; stopping early.")
                    break
                total_responses.extend(new_responses)

                for generated_text in new_responses:
                    print(generated_text)

                # Count what the API actually returned; it may be fewer than
                # the requested batch size.
                total_responses_generated += len(new_responses)

            except Exception as e:
                print(f"Error calling vLLM API: {e}")
                break

        if not total_responses:
            return "Failed to generate responses from vLLM API."

        messages_list = []
        for response in total_responses:
            messages_with_response = messages + [{'role': 'assistant', 'content': response}]
            messages_list.append(messages_with_response)

        rewards = self.inference_rm(messages_list)
        print(rewards)

        # Rejection sampling step: keep only the highest-scoring response.
        top_index = int(np.argmax(rewards))
        return total_responses[top_index]

    def inference_rm(self, messages_list):
        """Score responses using the reward model."""
        rewards = []
        for i in range(0, len(messages_list), self.rm_batch_size):
            batch_messages = messages_list[i:i + self.rm_batch_size]

            prompts = [
                self.format_prompt(msgs, add_generation_prompt=False) for msgs in batch_messages
            ]
            inputs = self.tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length
            ).to(self.reward_device)

            # Reward-model inputs come from format_prompt(..., add_generation_prompt=False),
            # which strips trailing whitespace so they do not end with "\n".

            with torch.no_grad():
                outputs = self.reward_model(**inputs, return_dict=True)

            # squeeze(-1) drops only the label dimension, so a batch of one
            # still yields a 1-D array instead of a 0-d scalar.
            batch_rewards = outputs.logits.squeeze(-1).detach().cpu().numpy()
            rewards.extend(batch_rewards.tolist())

        return rewards
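

# Minimal usage sketch. All paths, URLs, and hyperparameters below are
# illustrative assumptions, not values shipped with this module.
if __name__ == "__main__":
    sampler = RejectionSampler(
        sft_model_name="my-sft-model",                    # hypothetical served model name
        sft_model_vllm_api_url="http://localhost:8000/v1/completions",  # hypothetical
        reward_model_path="./reward_model",               # hypothetical local checkpoint
        reward_model_name="./reward_model",               # tokenizer source
        template_path="./templates/chat_template.jinja",  # hypothetical template file
        max_responses=8,
        max_length=2048,
        sft_batch_size=4,
        rm_batch_size=4,
    )
    best = sampler.inference(
        messages=[{"role": "user", "content": "Explain rejection sampling."}],
        temperature=0.8,
        top_p=0.95,
        max_new_tokens=256,
    )
    print(best)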