# interactSpeech/swift/plugin/rm_plugin.py
# Uploaded by Student0809 (commit 7feac49, "Add files using upload-large-folder tool")
import re
import textwrap
from copy import deepcopy
from typing import Dict, List, Optional

import torch

from swift.llm import PtEngine, RequestConfig, Template, to_device
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.utils import get_logger
# Module-level logger shared by all plugins in this file.
logger = get_logger()
class DefaultRMPlugin:
    """
    Default Reward Model Plugin.

    This class implements the default processing logic for reward models.
    It assumes that `self.model` is a classification model with a value head
    (output dimension 1). The first logits value from the model's output is
    used as the reward score.
    """

    def __init__(self, model, template):
        """
        Args:
            model (torch.nn.Module): Classification reward model with a value head.
            template (Template): Template used to encode incoming infer requests.
        """
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        """
        Compute reward scores for a batch of infer requests.

        Args:
            inputs (List[Dict]): Infer requests to score.

        Returns:
            torch.Tensor: Tensor of shape (N,) holding one reward per request,
            taken from the first logit of the model output.
        """
        # deepcopy so template.encode cannot mutate the caller's requests
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # Labels are not needed for reward inference; drop them if present.
        # Using a default avoids a KeyError when the collator emits no labels.
        reward_inputs.pop('labels', None)
        with torch.inference_mode():
            # first logit of the value head is the scalar reward
            return self.model(**reward_inputs).logits[:, 0]
class GenRMPlugin(DefaultRMPlugin):
    """
    Generative Reward Model Plugin Example.

    Uses a generative LLM as the reward model: the dialogue history is
    flattened into a single query, sent to the model together with a scoring
    system prompt, and the reward is parsed from the generated text, which is
    expected to contain "Reward: {score}".
    """

    def __init__(self, model, template):
        """
        Set up the reward model plugin.

        Initializes the PtEngine for efficient inference, configures the
        request parameters, and defines the system prompt that guides the
        reward model in evaluating responses.

        Args:
            model (torch.nn.Module): The generative reward model.
            template (Template): The template used for encoding input data.
        """
        super().__init__(model, template)
        # initialize PtEngine for inference
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig()  # customise your request config here
        self.system = textwrap.dedent("""
        Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
        Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
        Before finishing your response, please assign a reward using the following format:
        Reward: {reward}
        For example:
        Reward: 0.85
        """)  # noqa

    def __call__(self, inputs):
        """
        Compute reward scores for the provided inputs.

        Each input's messages are converted into a query, the query is sent to
        the reward model for inference, and reward scores are extracted from
        the model's responses. The final reward for each input is the average
        of all extracted scores.

        Args:
            inputs (List[Dict]): A list of input requests. Each input request
                is a dictionary containing:
                - 'messages' (List[Dict]): messages from the training model,
                  each with 'role' (str) and 'content' (str).
                - Additional dataset columns as key-value pairs
                  (e.g. 'solutions', 'images').

        Returns:
            torch.Tensor: Float32 tensor of shape (N,) with the average reward
            score for each of the N input requests.
        """
        rm_inputs = self.prepare_rm_inputs(inputs)
        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        rewards = self.compute_rewards(results)
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
        """
        Prepare inputs for the reward model by converting messages into queries.

        Args:
            inputs (List[Dict]): A list of input requests.

        Returns:
            List[Dict]: Processed inputs whose 'messages' are replaced with the
            scoring system prompt plus the flattened dialogue as a user query.
        """
        rm_inputs = []
        for infer_request in inputs:
            # Deep copy to prevent modification of the original input
            rm_infer_request = deepcopy(infer_request)
            query = self.messages_to_query(rm_infer_request.get('messages'))
            # Construct new messages tailored for the reward model
            rm_infer_request['messages'] = [
                {'role': 'system', 'content': self.system},
                {'role': 'user', 'content': query},
            ]
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> Optional[float]:
        """
        Extract the reward score from the model's output.

        Args:
            model_output (str): The model's output string, expected to contain
                a line of the form "Reward: {reward}".

        Returns:
            Optional[float]: The extracted reward clamped to [0, 1], or None
            when no reward could be parsed (a warning is logged; the caller
            treats missing rewards as 0).
        """
        match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output)
        if match is None:
            logger.warning("Unable to extract reward score from the model's output, set reward to 0")
            return None
        # The pattern also matches out-of-range strings such as "1.5";
        # clamp so the returned score always lies in the documented [0, 1].
        return min(max(float(match.group(1)), 0.0), 1.0)

    @staticmethod
    def messages_to_query(messages):
        """
        Compress a list of message dictionaries into a single query string.

        Args:
            messages (list[dict]): A list of message dictionaries, each
                containing 'role' (str) and 'content' (str).

        Returns:
            str: A single string that concatenates all messages, one per line.

        Example:
            >>> messages = [
            ...     {'role': 'user', 'content': 'Hello, how are you?'},
            ...     {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
            ...     {'role': 'user', 'content': 'Can you help me with my homework?'}
            ... ]
            >>> print(GenRMPlugin.messages_to_query(messages))
            User: Hello, how are you?
            Assistant: I am fine, thank you! How can I assist you today?
            User: Can you help me with my homework?
        """
        # Canonical capitalization for the common roles; anything else falls
        # back to str.capitalize below.
        role_mapping = {
            'user': 'User',
            'assistant': 'Assistant',
            'system': 'System',
            # Add more roles here as needed
        }
        formatted_messages = []
        for idx, message in enumerate(messages):
            if not isinstance(message, dict):
                raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')
            # Guard against a missing 'role' key (None.lower() would raise);
            # messages without content are skipped entirely.
            role = message.get('role') or ''
            content = message.get('content')
            if not content:
                continue
            role_formatted = role_mapping.get(role.lower(), role.capitalize())
            formatted_messages.append(f'{role_formatted}: {content}')
        return '\n'.join(formatted_messages)

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): One response per input
                request, possibly with multiple generated choices each.

        Returns:
            List[float]: One average reward per response; 0.0 when no valid
            reward could be extracted or an unexpected error occurred.
        """
        rewards = []
        for output in results:
            try:
                cur_rewards = [self.extract_reward(choice.message.content) for choice in output.choices]
                # Drop choices from which no reward could be parsed
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')
                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards
# Registry of reward-model plugins, selectable by name from the training config.
rm_plugins = {
    'default': DefaultRMPlugin,
    'genrm': GenRMPlugin,
}