| |
| import dataclasses |
| import os |
| from dataclasses import dataclass |
| from datetime import datetime |
| from typing import List, Literal, Optional |
|
|
| import json |
|
|
| from swift.llm import BaseArguments |
| from swift.utils import get_logger |
|
|
| logger = get_logger() |
|
|
|
|
@dataclass
class SamplingArguments(BaseArguments):
    """Arguments for the sampling workflow, layered on top of ``BaseArguments``.

    Besides declaring the fields below, this class:

    * skips the parent's model-info initialization when ``sampler_engine`` is
      ``'client'`` (see ``_init_model_info``);
    * normalizes ``output_file``, ``padding_side``, ``engine_kwargs`` and
      builds ``system_message`` in ``__post_init__``.

    NOTE(review): the per-group comments on the fields are inferred from the
    field names only -- confirm them against the sampler implementations.
    """

    # Reward models (presumably process / outcome reward models -- TODO confirm).
    prm_model: Optional[str] = None
    orm_model: Optional[str] = None

    # Sampler selection and output settings.
    sampler_type: Literal['sample', 'mcts', 'distill'] = 'sample'
    sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
    output_dir: str = 'sample_output'
    # Bare file name (no directory part); defaults to '<timestamp>.jsonl'.
    output_file: Optional[str] = None
    resume: bool = False
    override_exist_file: bool = False
    num_return_sequences: int = 64
    num_sampling_per_gpu_batch_size: int = 1
    num_sampling_per_gpu_batches: Optional[int] = None
    n_best_to_keep: int = 5
    data_range: List[int] = dataclasses.field(default_factory=list)

    # Generation / filtering settings.
    temperature: float = 1.0
    prm_threshold: float = 0.0
    easy_query_threshold: Optional[float] = None

    # JSON string on input; replaced by the parsed object in __post_init__
    # (an empty dict when not given).
    engine_kwargs: Optional[str] = None

    cache_files: List[str] = dataclasses.field(default_factory=list)

    # Presumably used by the 'mcts' sampler and the 'client' engine -- TODO confirm.
    rollout_depth: int = 5
    rollout_start_depth: int = 3
    max_iterations: int = 100
    process_reward_rate: float = 0.0
    exploration_rate: float = 0.5
    api_key: str = 'EMPTY'
    base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1'

    def _init_model_info(self):
        """Initialize model info unless sampling goes through a client.

        For the ``'client'`` engine the parent implementation is not called;
        only ``task_type`` is set to ``'causal_lm'`` and ``None`` is returned.
        For every other engine the parent's result is returned unchanged.
        """
        if self.sampler_engine == 'client':
            self.task_type = 'causal_lm'
            return None
        return super()._init_model_info()

    def __post_init__(self):
        """Normalize sampling options and run the parent initialization.

        Steps, in order:

        1. Default ``output_file`` to ``'<YYYY-mm-dd-HH-MM-SS>.jsonl'`` or
           reject a value that contains a path separator.
        2. Force ``padding_side`` to ``'left'``.
        3. Parse ``engine_kwargs`` from JSON (``None`` becomes ``{}``).
        4. Call ``super().__post_init__()``.
        5. Build ``system_message`` from ``self.system``.

        Raises:
            ValueError: If ``output_file`` contains ``/`` or ``\\``.
            json.JSONDecodeError: If ``engine_kwargs`` is not valid JSON.
        """
        if self.output_file is None:
            # No name given: derive one from the current local time.
            timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            self.output_file = timestamp + '.jsonl'
            logger.info(f'Setting output_file to {self.output_file}')
        elif '/' in self.output_file or '\\' in self.output_file:
            raise ValueError(f'Please use a string prefix without directory to '
                             f'`--output_file` but now is: {self.output_file}')

        self.padding_side = 'left'

        if self.engine_kwargs is None:
            self.engine_kwargs = {}
        elif not isinstance(self.engine_kwargs, dict):
            # Report through the module logger instead of a bare print so the
            # message follows the configured log handlers.
            logger.info(f'engine_kwargs: {self.engine_kwargs}')
            self.engine_kwargs = json.loads(self.engine_kwargs)
        # An already-parsed dict is kept as-is, which keeps this method safe
        # to run again on an initialized instance.

        super().__post_init__()

        # `self.system` is not declared in this class; presumably it comes
        # from BaseArguments -- TODO confirm.
        if self.system is None:
            self.system_message = []
        else:
            self.system_message = [{
                'role': 'system',
                'content': self.system,
            }]
|
|