|
|
|
|
|
import os |
|
|
import shutil |
|
|
import time |
|
|
from typing import List, Union |
|
|
|
|
|
import json |
|
|
|
|
|
from swift.llm import SamplingArguments, SwiftPipeline, load_dataset |
|
|
from swift.utils import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
class SwiftSampling(SwiftPipeline):
    """Pipeline that runs a sampler over a dataset and writes the samples to a file.

    The sampler is selected by ``args.sampler_type``. When ``args.data_range``
    is set it is unpacked into ``(cur_piece, total_piece)`` and only that
    contiguous piece of the dataset is sampled.
    """
    args_class = SamplingArguments
    args: args_class

    def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
        """Initialise the pipeline, create the output directory and build the sampler.

        Args:
            args: Passed unchanged to ``SwiftPipeline.__init__``.

        Raises:
            ValueError: If ``args.sampler_type`` is not one of ``'sample'``,
                ``'mcts'`` or ``'distill'``.
        """
        super().__init__(args)
        self.args.save_args()
        os.makedirs(self.args.output_dir, exist_ok=True)

        # Default to a single piece that covers the whole dataset.
        self.cur_piece = 0
        self.total_piece = 1
        if self.args.data_range:
            self.cur_piece, self.total_piece = self.args.data_range

        # Sampler modules are imported inside the branches so that only the
        # selected implementation is loaded.
        if self.args.sampler_type == 'sample':
            from swift.llm.sampling.vanilla_sampler import VanillaSampler
            self.sampler = VanillaSampler(self.args)
        elif self.args.sampler_type == 'mcts':
            from swift.llm.sampling.mcts import MctsSampler
            self.sampler = MctsSampler(self.args)
        elif self.args.sampler_type == 'distill':
            from swift.llm.sampling.distill_sampler import DistillSampler
            self.sampler = DistillSampler(self.args)
        else:
            raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')

    def _get_dataset(self):
        """Load the sampling dataset and return the piece assigned to this run.

        The dataset is divided into ``self.total_piece`` contiguous pieces of
        ``len(dataset) // self.total_piece`` rows each. The last piece also
        takes the rows left over by the integer division, so that no row is
        silently excluded from every piece.

        Returns:
            The result of ``select`` on the loaded dataset for this piece.
        """
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        sampling_dataset, _ = load_dataset(
            args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
        logger.info(f'Sampling_dataset: {sampling_dataset}')

        dataset_len = len(sampling_dataset)
        piece_len = dataset_len // self.total_piece
        start = piece_len * self.cur_piece
        if self.cur_piece == self.total_piece - 1:
            # Last piece: extend to the end so the remainder rows are kept.
            stop = dataset_len
        else:
            stop = start + piece_len
        return sampling_dataset.select(range(start, stop))

    def run(self):
        """Sample the dataset batch by batch and publish the result file.

        Files used inside ``args.output_dir``:

        * ``<output_file>``: the final result.
        * ``<output_file>.tmp``: the file the samples are written to.
        * ``<output_file>.resume``: a copy of the tmp file taken after every
          batch.
        * ``ckpt_state.json``: ``{"index": i}``, the last batch index that was
          copied to the resume file.

        Behaviour:

        * Returns immediately if the result file exists and
          ``args.override_exist_file`` is false.
        * With ``args.resume``, sampling continues after the checkpointed
          batch, but only if both the resume file and the checkpoint exist.
          Otherwise leftovers of earlier runs are removed and sampling starts
          from the first batch.
        * A trailing batch smaller than
          ``args.num_sampling_per_gpu_batch_size`` is sampled as well.
        * If no resume file exists after the loop (nothing was sampled), a
          warning is logged and an existing result file is left untouched.
        * An existing result file is otherwise renamed with a timestamp suffix
          before the new one is moved in place.
        """
        args = self.args
        os.makedirs(args.output_dir, exist_ok=True)
        iter_file = os.path.join(args.output_dir, args.output_file)
        resume_file = os.path.join(args.output_dir, args.output_file + '.resume')
        tmp_file = os.path.join(args.output_dir, args.output_file + '.tmp')
        ckpt_state_file = os.path.join(args.output_dir, 'ckpt_state.json')
        if os.path.exists(iter_file) and not args.override_exist_file:
            return

        index_resume = -1
        write_mode = 'w'
        # The checkpoint index describes the content of the resume file, so the
        # two are only meaningful together.
        can_resume = bool(args.resume) and os.path.exists(resume_file) and os.path.exists(ckpt_state_file)
        if can_resume:
            write_mode = 'a'
            shutil.copyfile(resume_file, tmp_file)
            with open(ckpt_state_file, 'r', encoding='utf-8') as ckpt_state:
                index_resume = json.load(ckpt_state).get('index', -1)
            logger.info(f'Loaded index_resume: {index_resume}')
        else:
            # Fresh start: drop leftovers of earlier runs so they are neither
            # appended to nor published as the output of this run.
            for stale_file in (tmp_file, resume_file):
                if os.path.exists(stale_file):
                    os.remove(stale_file)

        dataset = self._get_dataset()
        batch_size = args.num_sampling_per_gpu_batch_size
        # Ceiling division: a trailing partial batch counts as a batch.
        total_iters = (len(dataset) + batch_size - 1) // batch_size
        if args.num_sampling_per_gpu_batches is None or args.num_sampling_per_gpu_batches > total_iters:
            args.num_sampling_per_gpu_batches = total_iters

        # NOTE: the resume copy and the checkpoint are written one after the
        # other, not atomically; an interruption between the two leaves the
        # checkpoint one batch behind the resume file.
        first_index = max(index_resume + 1, 0)
        with open(tmp_file, write_mode, encoding='utf-8') as f:
            for _index in range(first_index, args.num_sampling_per_gpu_batches):
                logger.info(f' Sampling index:{_index}')
                slices = dataset[batch_size * _index:batch_size * (_index + 1)]
                slices = self.sampler.truncate_input(slices)
                generated = self.sampler.do_sample(slices)
                f.writelines(generated)
                # Flush before copying so the resume file holds this batch.
                f.flush()
                shutil.copy(tmp_file, resume_file)
                with open(ckpt_state_file, 'w', encoding='utf-8') as ckpt_state:
                    json.dump({'index': _index}, ckpt_state)

        if not os.path.exists(resume_file):
            # Nothing was sampled (e.g. empty dataset piece or zero batches
            # requested); there is no file to publish.
            logger.warning(f'No samples were generated, {iter_file} is not written.')
            return

        if os.path.exists(iter_file):
            # Keep the previous result under a timestamped name.
            shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
        shutil.move(resume_file, iter_file)
        logger.info(f'Sample file {iter_file} generated.')
|
|
|
|
|
|
|
|
def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
    """Build a ``SwiftSampling`` pipeline from ``args`` and run it.

    Args:
        args: Passed unchanged to the ``SwiftSampling`` constructor.

    Returns:
        Whatever the pipeline's ``main()`` returns.
    """
    pipeline = SwiftSampling(args)
    return pipeline.main()
|
|
|