yuccaaa committed on
Commit
c790bff
·
verified ·
1 Parent(s): 6e1f525

Upload ms-swift/examples/sampler/mcts/mcts.py with huggingface_hub

Browse files
ms-swift/examples/sampler/mcts/mcts.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import time
4
+ from typing import List
5
+
6
+ import json
7
+ from modelscope.msdatasets import MsDataset
8
+
9
+ conda_prefix = ''
10
+
11
+
12
+ def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_count: int, output_dir: str):
13
+ handlers = []
14
+ # Sampling cache
15
+ api_key = os.getenv('DASHSCOPE_API_KEY')
16
+
17
+ for device in range(device_count):
18
+
19
+ output_file = f'iter_{iter}_proc_{device}.jsonl'
20
+ cache_file = f'iter_{iter}_proc_{device}_cache.jsonl'
21
+ dataset = f'train_{device:02}.jsonl'
22
+
23
+ # output_file_path = os.path.join(output_dir, output_file)
24
+ cache_file_path = os.path.join(output_dir, cache_file)
25
+ single_dataset_path = os.path.join(dataset_path, dataset)
26
+
27
+ if not os.path.exists(cache_file_path):
28
+ open(cache_file_path, 'w').close()
29
+ sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True '
30
+ f'swift sample '
31
+ f'--model {model} '
32
+ f'--orm_model {orm} '
33
+ f'--sampler_type mcts '
34
+ f'--process_reward_rate 0 '
35
+ f'--stop_words ки '
36
+ f'--seed 42 '
37
+ f'--api_key {api_key} '
38
+ f'--dataset {single_dataset_path} '
39
+ f'--max_length 2048 '
40
+ f'--system ./scripts/sampler/system_prompt.txt '
41
+ f'--load_args false '
42
+ f'--sampler_engine client '
43
+ f'--max_new_tokens 768 '
44
+ f'--override_exist_file true '
45
+ f'--num_sampling_per_gpu_batch_size 1 '
46
+ f'--num_return_sequences 8 '
47
+ f'--exploration_rate 0.2 '
48
+ f'--max_iterations 200 '
49
+ f'--output_dir {output_dir} '
50
+ f'--cache_files {cache_file} '
51
+ f'--output_file {output_file} '
52
+ f'--temperature 1.0 ')
53
+ print(f'Sampling caches of iter {iter}, part {device}.', flush=True)
54
+ # env['CUDA_VISIBLE_DEVICES'] = str(device)
55
+ handler = subprocess.Popen(
56
+ f'{sample_cmd}' + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1',
57
+ env=os.environ.copy(),
58
+ shell=True,
59
+ executable='/bin/bash')
60
+ handlers.append(handler)
61
+
62
+ datasets = []
63
+ for proc, handler in enumerate(handlers):
64
+ handler.wait()
65
+ assert os.path.exists(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl'))
66
+ datasets.append(os.path.join('sample_output', f'iter_{iter}_proc_{proc}.jsonl'))
67
+ print(f'Sampling done, files:{datasets}', flush=True)
68
+
69
+
70
+ def split_dataset(ds, split_size, out_path):
71
+ data_size = int(len(ds) / split_size) + 1
72
+
73
+ for i in range(split_size):
74
+ file_name = f'train_{i:02}.jsonl'
75
+ file_path = os.path.join(out_path, file_name)
76
+ print(file_path)
77
+ ds_split = ds[data_size * i:min(data_size * (i + 1), len(ds))]
78
+ print(f"split_size: {len(ds_split['problem'])}")
79
+ with open(file_path, 'w', encoding='utf-8') as file:
80
+ for problem, solution in zip(ds_split['problem'], ds_split['solution']):
81
+ message = {
82
+ 'messages': [
83
+ {
84
+ 'role': 'user',
85
+ 'content': problem,
86
+ },
87
+ {
88
+ 'role': 'assistant',
89
+ 'content': solution,
90
+ },
91
+ ]
92
+ }
93
+ file.write(json.dumps(message, ensure_ascii=False) + '\n')
94
+
95
+
96
+ def main():
97
+ server_model = 'qwen-max'
98
+ orm = 'math'
99
+ device_count = 20
100
+ output_dir = 'output/sampler/client_mcts/'
101
+ dataset_dir = 'datasets/competition_math/'
102
+ log_dir = 'mcts_logs/'
103
+
104
+ os.makedirs(output_dir, exist_ok=True)
105
+ os.makedirs(dataset_dir, exist_ok=True)
106
+ os.makedirs(log_dir, exist_ok=True)
107
+ ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train')
108
+ split_dataset(ds, device_count, dataset_dir)
109
+
110
+ ts = time.time()
111
+ client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir)
112
+ print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True)
113
+
114
+
115
+ if __name__ == '__main__':
116
+ main()