# Student0809's picture
# Add files using upload-large-folder tool
# cb2428f verified
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
import json
import numpy as np
from swift.llm import InferRequest, SamplingArguments
from swift.llm.infer.protocol import UsageInfo
from swift.utils import get_logger
from .base import Sampler
from .utils import get_reward, perform_infer
# Module-level logger shared by everything in this file.
logger = get_logger()

# Text of the follow-up user turn that asks the model to keep going after a
# partial assistant answer.
NXT_PROMPT = """Continue.
"""

# Chat message wrapping NXT_PROMPT; it is appended after the assistant's
# partial answer when requesting the next step.
next_message = dict(role='user', content=NXT_PROMPT)
class LanguageNode:
    """One node of the search tree: a partial answer built from reasoning steps.

    A root node has no parent, an empty path and an empty answer. Every other
    node extends its parent's path with one ``step`` and its parent's answer
    with ``step + sep_token``.
    """

    def __init__(
        self,
        step: 'str | None' = None,
        sep_token: 'str | None' = None,
        parent: 'LanguageNode | None' = None,
    ):
        """Create a root node (no ``parent``) or a child of ``parent``.

        Args:
            step: Text of the step this node adds; required when ``parent`` is given.
            sep_token: Separator appended after each step. When falsy it is
                inherited from ``parent``.
            parent: Node this one extends, or ``None`` for a root.

        Raises:
            ValueError: If neither ``sep_token`` nor ``parent`` is supplied, so
                no separator can be determined.
        """
        self.parent = parent

        if sep_token:
            self.sep_token = sep_token
        elif parent is not None:
            self.sep_token = parent.sep_token
        else:
            # Previously this surfaced as an AttributeError on None.
            raise ValueError('A root LanguageNode requires a sep_token.')

        if parent is not None:
            self.path = parent.path[:] + [step]
            self.answer = parent.answer + step + self.sep_token
            self.depth = parent.depth + 1
        else:
            self.path = []
            self.answer = ''
            self.depth = 0

        # Children that can still be selected for expansion.
        self.active_children = []
        # Every child ever attached, terminated or not.
        self.children = []

        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0

        self.terminated = False
        self.correct = False

    def is_leaf(self):
        """Return True when the node has no children at all."""
        return len(self.children) == 0

    def is_root(self):
        """Return True when the node has no parent."""
        return self.parent is None

    def visit(self):
        """Increase the visit counter by one."""
        self.visit_count += 1

    def init_and_update_value(self, value):
        """Fold ``value`` into ``outcome_reward`` as a running mean.

        The mean is weighted by the current ``visit_count``; the counter itself
        is not changed here (see :meth:`visit`).
        """
        self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)

    def add_child(self, child: 'LanguageNode'):
        """Attach ``child``; it also becomes active unless already terminated."""
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self):
        """Return this subtree as a nested dict of plain Python values.

        ``terminated`` and ``correct`` are emitted as strings. Rewards are
        coerced to ``float`` so that scalar types other than the builtin float
        (assumed possible from the reward models -- TODO confirm) do not break
        JSON serialization of the result.
        """
        return {
            'path': self.path,
            'depth': self.depth,
            'visit_count': self.visit_count,
            'process_reward': float(self.process_reward),
            'outcome_reward': float(self.outcome_reward),
            'terminated': str(self.terminated),
            'correct': str(self.correct),
            'children': [child.collect() for child in self.children],
        }

    def __lt__(self, other):
        """Order nodes by ``outcome_reward``."""
        if not isinstance(other, LanguageNode):
            return NotImplemented
        return self.outcome_reward < other.outcome_reward
class MctsSampler(Sampler):
    """Sampler that builds step-by-step answers with Monte Carlo tree search.

    For each query a tree of :class:`LanguageNode` objects is grown by
    repeatedly selecting a leaf, expanding it with model generations, optionally
    rolling the new children out, and back-propagating outcome rewards.
    ``self.orm_model`` and ``self.prm_model`` are read but not set in this
    class; presumably they are provided by the ``Sampler`` base -- verify there.
    """

    # Total number of inference attempts before an empty result is fatal.
    _MAX_INFER_ATTEMPTS = 6

    def __init__(self, input_args: SamplingArguments):
        """Initialise the base sampler and a zeroed token-usage accumulator."""
        super().__init__(input_args)
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
        """Create ``self.infer_engine`` and ``self.infer_kwargs`` from the arguments.

        With the ``client`` engine one ``InferClient`` per return sequence is
        created and the model name is forwarded through ``infer_kwargs``;
        otherwise a single local engine is instantiated.

        Raises:
            ValueError: If the configured engine is unknown, or resolves to no
                engine class at all (``sampler_engine == 'no'``).
        """
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == 'client':
            from swift.llm import InferClient
            self.infer_engine = [
                InferClient(base_url=args.base_url, api_key=args.api_key)
                for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs['model'] = args.model
            return

        engine_cls = self.get_infer_engine()
        if engine_cls is None:
            # Calling None used to fail with an opaque TypeError; every search
            # step below needs an engine, so fail early with a clear message.
            raise ValueError(f'sampler_engine {args.sampler_engine!r} provides no inference engine, '
                             'but MCTS sampling requires one')
        self.infer_engine = engine_cls(args.model, model_type=args.model_type, **args.engine_kwargs)

    def get_infer_engine(self):
        """Return the engine class named by ``args.sampler_engine``.

        Returns ``None`` for ``'no'``.

        Raises:
            ValueError: For an unrecognised engine name.
        """
        engine_name = self.args.sampler_engine
        if engine_name == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif engine_name == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif engine_name == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif engine_name == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        return _Engine

    def _prepare_template(self) -> None:
        # Hack from super(): this hook is used to build the request configs.
        self._prepare_request_configs()

    def _prepare_request_configs(self):
        """Build one expand config and one rollout config per return sequence.

        Expand configs differ only by seed (base seed + index) so parallel
        expansions can produce different steps. Rollout configs use
        temperature 0.0 and at most 500 new tokens.
        """
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed

        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            # An unset seed cannot be offset; leave it unset in that case.
            if expand_request_config.seed is not None:
                expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)

            rollout_request_config = deepcopy(request_config)
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        """Add the usage counters of ``response`` to ``self.usage_info``.

        Counters that the response does not report (missing or ``None``) are
        skipped instead of raising.
        """
        usage = getattr(response, 'usage', None)
        if usage is None:
            return
        for key, value in list(self.usage_info.__dict__.items()):
            delta = getattr(usage, key, None)
            if delta is None:
                continue
            setattr(self.usage_info, key, delta + value)

    @staticmethod
    def _first_step(content: str, sep_token: str) -> str:
        """Return the first step of ``content`` without trailing separators.

        Everything from the first full ``sep_token`` on is dropped, then
        trailing newlines and a trailing stop word (``sep_token`` without its
        trailing newlines) are removed as whole suffixes. ``str.rstrip`` with
        the separator as argument is deliberately avoided: it strips a
        character set and would also eat legitimate trailing characters.
        """
        stop_word = sep_token.rstrip('\n')
        step = content.split(sep_token)[0]
        while True:
            trimmed = step.rstrip('\n')
            if stop_word and trimmed.endswith(stop_word):
                trimmed = trimmed[:-len(stop_word)]
            if trimmed == step:
                return step
            step = trimmed

    def _infer_with_retry(self, infer_requests, request_configs, stage: str):
        """Run inference, retrying while the engine returns nothing.

        Raises:
            ValueError: If every attempt returned an empty result.
        """
        for _ in range(self._MAX_INFER_ATTEMPTS):
            responses = perform_infer(self.infer_engine, infer_requests, request_configs, **self.infer_kwargs)
            if len(responses) > 0:
                return responses
        raise ValueError(f'{stage} did not return any response')

    def search_single(self, query, ground_truth):
        """Run the tree search for one query and return the result as a JSON string.

        Args:
            query: Content of the user prompt.
            ground_truth: Reference answer handed to the reward models.

        Returns:
            JSON text with the query, the ground truth, the answers finished
            during rollouts (split into correct / incorrect) and the collected
            tree. The first entry of ``args.stop_words`` plus a newline is used
            as step separator, so at least one stop word must be configured.
        """

        def _uct(uct_curr_node: LanguageNode):
            # Blend of process and outcome reward plus an exploration bonus.
            alpha = _args.process_reward_rate
            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
            if uct_curr_node.is_root():
                return value
            exploration_score = (
                _args.exploration_rate
                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))
            return exploration_score + value

        def _select(select_curr_node: LanguageNode):
            # Descend through the best active child until a leaf is reached.
            while not select_curr_node.is_leaf():
                select_curr_node = max(select_curr_node.active_children, key=_uct)
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
            else:
                history_message = {
                    'role': 'assistant',
                    'content': expand_curr_node.answer,
                }
                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
                infer_requests = [infer_request for _ in range(n)]

            # The expansions run in parallel; order is irrelevant because the
            # prompt is the same for all of them.
            responses = self._infer_with_retry(infer_requests, self.expand_request_configs, 'Expand')

            # Outcome rewards come back in request order, so they line up with
            # the children created below.
            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                output = self._first_step(response.choices[0].message.content, sep_token)
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0)
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            if self.prm_model:
                prm_infer_requests = [
                    InferRequest([prompt_message, {
                        'role': 'assistant',
                        'content': child.answer
                    }]) for child in expand_curr_node.children
                ]
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0)
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score

        def _rollout(rollout_curr_node: LanguageNode):
            # One rollout track per active child, keyed by its position.
            rollout_nodes = {
                i: {
                    'node': child,
                    'history_messages': {
                        'role': 'assistant',
                        'content': child.answer,
                    },
                }
                for i, child in enumerate(rollout_curr_node.active_children)
            }
            active_rollout_nodes = list(rollout_nodes.keys())
            rollout_depth = 0
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                infer_requests = [
                    InferRequest(system_message
                                 + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
                    for index in active_rollout_nodes
                ]
                responses = self._infer_with_retry(infer_requests, self.rollout_request_configs, 'Rollout')

                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    output = self._first_step(response.choices[0].message.content, sep_token) + sep_token + '\n'
                    history = rollout_nodes[index]['history_messages']
                    history['content'] += output
                    end_paths.append(history['content'])
                    orm_infer_requests.append(InferRequest([history]))

                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(orm_infer_requests),
                    threshold=0.0)

                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
                    if not terminated:
                        continue
                    finished = rollout_nodes.pop(index)
                    finished['node'].init_and_update_value(score)
                    if score > 0.9:
                        rollout_correct_answers.append(finished['history_messages']['content'])
                    else:
                        rollout_incorrect_answers.append(finished['history_messages']['content'])
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
            while back_curr_node:
                if back_curr_node is curr_node:
                    # The expanded node takes the best outcome among its children.
                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
                    back_curr_node.init_and_update_value(best_child_value)
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                # A node without selectable children is exhausted: mark it and
                # take it out of its parent's candidate list.
                if len(back_curr_node.active_children) == 0:
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        system_message = [] + _args.system_message
        sep_token = _args.stop_words[0] + '\n'
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            'role': 'user',
            'content': query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
        iter_count = 0
        stop_reason = None
        while True:
            logger.info(f'iter_count: {iter_count}' + '.' * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _expand(curr_node)
            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')

            # Stop early once the rollouts are lopsided enough to call the
            # query too easy or too hard.
            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = 'too easy'
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = 'too hard'
                    break
            if _root.terminated:
                stop_reason = 'root terminated'
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = 'enough nodes'
                break
            if iter_count >= _args.max_iterations:
                stop_reason = 'max_iterations'
                break
            iter_count += 1
        logger.info(f'stop_reason: {stop_reason}')

        monte_carlo_tree = _root.collect()
        result = {
            'query': query,
            'ground_truth': ground_truth,
            'rollout_correct_answers': rollout_correct_answers,
            'rollout_incorrect_answers': rollout_incorrect_answers,
            'monte_carlo_tree': monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
        """Run :meth:`search_single` on every item and return the JSON lines.

        ``data`` may be a single item or a list. For each item the first entry
        of ``item['messages']`` is read; its first message supplies the query
        and its second the ground truth. Items that raise are logged with
        their traceback and left out of the result.
        """
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f'time: {time.ctime(time.time())}')
            try:
                messages = item['messages'][0]
                query = messages[0]['content']
                ground_truth = messages[1]['content']
                generated.append(self.search_single(query, ground_truth) + '\n')
            except Exception as e:
                logger.error(f'Error: {e}')
                logger.error(f'Traceback: {traceback.format_exc()}')
        return generated