"""Client-side MCTS sampling over the competition_math dataset.

Splits the dataset into one JSONL shard per device, then launches one
`swift sample` client process per shard against a served model.
"""
import json
import os
import subprocess
import time

from modelscope.msdatasets import MsDataset


def client_sample(model: str, orm: str, dataset_path: str, iter_idx: int, device_count: int, output_dir: str):
    """Launch one `swift sample` client process per dataset shard and wait for all of them."""
    handlers = []

    api_key = os.getenv('DASHSCOPE_API_KEY')
    assert api_key, 'DASHSCOPE_API_KEY must be set in the environment.'
    for device in range(device_count):
        output_file = f'iter_{iter_idx}_proc_{device}.jsonl'
        cache_file = f'iter_{iter_idx}_proc_{device}_cache.jsonl'
        dataset = f'train_{device:02}.jsonl'
        cache_file_path = os.path.join(output_dir, cache_file)
        single_dataset_path = os.path.join(dataset_path, dataset)
        # Create an empty cache file so the sampler can resume from it.
        if not os.path.exists(cache_file_path):
            open(cache_file_path, 'w').close()
        sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True '
                      f'swift sample '
                      f'--model {model} '
                      f'--orm_model {orm} '
                      f'--sampler_type mcts '
                      f'--process_reward_rate 0 '
                      f'--stop_words ки '  # 'ки' is the Math-Shepherd-style step-delimiter token.
                      f'--seed 42 '
                      f'--api_key {api_key} '
                      f'--dataset {single_dataset_path} '
                      f'--max_length 2048 '
                      f'--system ./scripts/sampler/system_prompt.txt '
                      f'--load_args false '
                      f'--sampler_engine client '
                      f'--max_new_tokens 768 '
                      f'--override_exist_file true '
                      f'--num_sampling_per_gpu_batch_size 1 '
                      f'--num_return_sequences 8 '
                      f'--exploration_rate 0.2 '
                      f'--max_iterations 200 '
                      f'--output_dir {output_dir} '
                      f'--cache_files {cache_file} '
                      f'--output_file {output_file} '
                      f'--temperature 1.0 ')
        print(f'Sampling caches of iter {iter_idx}, part {device}.', flush=True)
        # env['CUDA_VISIBLE_DEVICES'] = str(device)
        handler = subprocess.Popen(
            sample_cmd + f' > mcts_logs/sample_iter_{iter_idx}_proc_{device}_cache.log 2>&1',
            env=os.environ.copy(),
            shell=True,
            executable='/bin/bash')
        handlers.append(handler)

    # Wait for every sampler process and collect the output files it produced.
    datasets = []
    for proc, handler in enumerate(handlers):
        handler.wait()
        output_path = os.path.join(output_dir, f'iter_{iter_idx}_proc_{proc}.jsonl')
        assert os.path.exists(output_path)
        datasets.append(output_path)
    print(f'Sampling done, files: {datasets}', flush=True)


def split_dataset(ds, split_size, out_path):
    """Split the dataset into `split_size` JSONL shards in messages format, one per device."""
    # Ceiling division so every sample lands in exactly one shard.
    shard_size = -(-len(ds) // split_size)
    for i in range(split_size):
        file_name = f'train_{i:02}.jsonl'
        file_path = os.path.join(out_path, file_name)
        print(file_path)
        ds_split = ds[shard_size * i:min(shard_size * (i + 1), len(ds))]
        print(f"shard size: {len(ds_split['problem'])}")
        with open(file_path, 'w', encoding='utf-8') as file:
            for problem, solution in zip(ds_split['problem'], ds_split['solution']):
                message = {
                    'messages': [
                        {
                            'role': 'user',
                            'content': problem,
                        },
                        {
                            'role': 'assistant',
                            'content': solution,
                        },
                    ]
                }
                file.write(json.dumps(message, ensure_ascii=False) + '\n')


def main():
    server_model = 'qwen-max'
    orm = 'math'
    device_count = 20
    output_dir = 'output/sampler/client_mcts/'
    dataset_dir = 'datasets/competition_math/'
    log_dir = 'mcts_logs/'
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train')
    split_dataset(ds, device_count, dataset_dir)

    ts = time.time()
    client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir)
    print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True)


if __name__ == '__main__':
    main()