Upload ms-swift/examples/sampler/mcts/mcts.py with huggingface_hub
Browse files
ms-swift/examples/sampler/mcts/mcts.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import time
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from modelscope.msdatasets import MsDataset
|
| 8 |
+
|
| 9 |
+
conda_prefix = ''
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_count: int, output_dir: str):
    """Launch one ``swift sample`` MCTS worker per device and wait for all of them.

    Each worker samples its own ``train_XX.jsonl`` shard from *dataset_path* and
    writes ``iter_{iter}_proc_{XX}.jsonl`` into *output_dir*; stdout/stderr of
    every worker goes to its own log under ``mcts_logs/``.

    Args:
        model: Client model name passed as ``--model`` (e.g. ``qwen-max``).
        orm: Outcome reward model name passed as ``--orm_model``.
        dataset_path: Directory containing the per-device ``train_XX.jsonl`` splits.
        iter: Sampling iteration index used in output/cache/log file names.
            (NOTE(review): shadows the builtin ``iter``; name kept for caller
            compatibility.)
        device_count: Number of parallel worker processes to launch.
        output_dir: Directory that receives output and cache jsonl files.

    Returns:
        List of paths to the produced ``iter_{iter}_proc_{proc}.jsonl`` files.

    Raises:
        RuntimeError: If a worker exits with a nonzero status or its expected
            output file is missing.
    """
    handlers = []
    # Presumably consumed by the `client` sampler engine — TODO confirm the
    # swift CLI actually requires DASHSCOPE_API_KEY here.
    api_key = os.getenv('DASHSCOPE_API_KEY')

    for device in range(device_count):
        output_file = f'iter_{iter}_proc_{device}.jsonl'
        cache_file = f'iter_{iter}_proc_{device}_cache.jsonl'
        dataset = f'train_{device:02}.jsonl'

        cache_file_path = os.path.join(output_dir, cache_file)
        single_dataset_path = os.path.join(dataset_path, dataset)

        # The sampler expects its cache file to exist; create it empty if needed.
        if not os.path.exists(cache_file_path):
            with open(cache_file_path, 'w'):
                pass
        sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True '
                      f'swift sample '
                      f'--model {model} '
                      f'--orm_model {orm} '
                      f'--sampler_type mcts '
                      f'--process_reward_rate 0 '
                      f'--stop_words ки '
                      f'--seed 42 '
                      f'--api_key {api_key} '
                      f'--dataset {single_dataset_path} '
                      f'--max_length 2048 '
                      f'--system ./scripts/sampler/system_prompt.txt '
                      f'--load_args false '
                      f'--sampler_engine client '
                      f'--max_new_tokens 768 '
                      f'--override_exist_file true '
                      f'--num_sampling_per_gpu_batch_size 1 '
                      f'--num_return_sequences 8 '
                      f'--exploration_rate 0.2 '
                      f'--max_iterations 200 '
                      f'--output_dir {output_dir} '
                      f'--cache_files {cache_file} '
                      f'--output_file {output_file} '
                      f'--temperature 1.0 ')
        print(f'Sampling caches of iter {iter}, part {device}.', flush=True)
        # shell=True is required here: the command sets an inline env var and
        # uses shell redirection for per-worker logging.
        handler = subprocess.Popen(
            f'{sample_cmd}' + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1',
            env=os.environ.copy(),
            shell=True,
            executable='/bin/bash')
        handlers.append(handler)

    datasets = []
    for proc, handler in enumerate(handlers):
        returncode = handler.wait()
        result_path = os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl')
        # Explicit errors instead of `assert`, which is stripped under -O.
        if returncode != 0:
            raise RuntimeError(f'Sampling worker {proc} of iter {iter} exited with code {returncode}.')
        if not os.path.exists(result_path):
            raise RuntimeError(f'Missing expected sampling output: {result_path}')
        # Fix: record the actual output location; the original joined against a
        # hard-coded 'sample_output' directory that never matched output_dir.
        datasets.append(result_path)
    print(f'Sampling done, files:{datasets}', flush=True)
    return datasets
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def split_dataset(ds, split_size, out_path):
    """Split *ds* into ``split_size`` jsonl shards in chat-message format.

    Each shard ``train_XX.jsonl`` holds one JSON object per sample with a
    ``messages`` list containing the user problem and assistant solution.

    Args:
        ds: Dataset supporting ``len()`` and slice access that yields columns
            ``'problem'`` and ``'solution'`` (e.g. an ``MsDataset``).
        split_size: Number of shard files to produce.
        out_path: Existing directory where the shard files are written.
    """
    # Ceiling division. The previous `int(len(ds) / split_size) + 1` produced
    # uneven shards and an empty trailing shard whenever len(ds) was an exact
    # multiple of split_size.
    data_size = (len(ds) + split_size - 1) // split_size

    for i in range(split_size):
        file_name = f'train_{i:02}.jsonl'
        file_path = os.path.join(out_path, file_name)
        print(file_path)
        ds_split = ds[data_size * i:min(data_size * (i + 1), len(ds))]
        print(f"split_size: {len(ds_split['problem'])}")
        with open(file_path, 'w', encoding='utf-8') as file:
            for problem, solution in zip(ds_split['problem'], ds_split['solution']):
                message = {
                    'messages': [
                        {
                            'role': 'user',
                            'content': problem,
                        },
                        {
                            'role': 'assistant',
                            'content': solution,
                        },
                    ]
                }
                file.write(json.dumps(message, ensure_ascii=False) + '\n')
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def main():
    """Prepare dataset shards, then run one round of client-side MCTS sampling."""
    model_name = 'qwen-max'
    reward_model = 'math'
    num_workers = 20
    sample_out = 'output/sampler/client_mcts/'
    data_dir = 'datasets/competition_math/'
    logs_dir = 'mcts_logs/'

    # Make sure every directory the pipeline writes to exists.
    for directory in (sample_out, data_dir, logs_dir):
        os.makedirs(directory, exist_ok=True)

    train_ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train')
    split_dataset(train_ds, num_workers, data_dir)

    started = time.time()
    client_sample(model_name, reward_model, data_dir, 0, num_workers, sample_out)
    print(f'do sample cost: {(time.time() - started) / 60:.1f} minutes.', flush=True)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Run the end-to-end sampling pipeline only when executed as a script.
if __name__ == '__main__':
    main()
|