Student0809's picture
Add files using upload-large-folder tool
cb2428f verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import time
from typing import List, Union
import json
from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
from swift.utils import get_logger
# Module-level logger used by the sampling pipeline and entry point in this file.
logger = get_logger()
class SwiftSampling(SwiftPipeline):
    """Pipeline that samples generations for one shard of a dataset with a configurable sampler.

    Output is appended batch by batch to ``<output_file>.tmp``. After every batch the temp file
    is copied to ``<output_file>.resume`` and the index of the last finished batch is written to
    ``ckpt_state.json``, so an interrupted run can be continued with ``args.resume``.
    """
    args_class = SamplingArguments
    args: args_class

    def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
        """Initialise the pipeline, prepare the output directory and build the sampler.

        Args:
            args: CLI-style argument list, a ``SamplingArguments`` instance, or ``None``;
                forwarded unchanged to ``SwiftPipeline.__init__``.

        Raises:
            ValueError: if ``args.sampler_type`` is not ``'sample'``, ``'mcts'`` or ``'distill'``.
        """
        super().__init__(args)
        self.args.save_args()
        os.makedirs(self.args.output_dir, exist_ok=True)
        # ``data_range`` is unpacked as (cur_piece, total_piece): only shard ``cur_piece`` out of
        # ``total_piece`` contiguous shards is sampled. Default: whole dataset as one shard.
        self.cur_piece = 0
        self.total_piece = 1
        if self.args.data_range:
            self.cur_piece, self.total_piece = self.args.data_range

        # Sampler implementations are imported lazily, presumably so that only the selected
        # backend's dependencies have to be importable.
        if self.args.sampler_type == 'sample':
            from swift.llm.sampling.vanilla_sampler import VanillaSampler
            self.sampler = VanillaSampler(self.args)
        elif self.args.sampler_type == 'mcts':
            from swift.llm.sampling.mcts import MctsSampler
            self.sampler = MctsSampler(self.args)
        elif self.args.sampler_type == 'distill':
            from swift.llm.sampling.distill_sampler import DistillSampler
            self.sampler = DistillSampler(self.args)
        else:
            raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')

    def _get_dataset(self):
        """Load ``args.dataset`` (no validation split) and return the shard chosen by ``data_range``.

        Shards are contiguous slices of ``len(dataset) // total_piece`` rows. The last shard also
        receives the ``len(dataset) % total_piece`` leftover rows so that no row is skipped.
        """
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        sampling_dataset, _ = load_dataset(
            args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
        logger.info(f'Sampling_dataset: {sampling_dataset}')
        dataset_len = len(sampling_dataset)
        piece_len = dataset_len // self.total_piece
        start = piece_len * self.cur_piece
        # Give the remainder to the last shard; with a plain ``piece_len`` slice those trailing
        # rows were never sampled by any shard.
        if self.cur_piece == self.total_piece - 1:
            end = dataset_len
        else:
            end = piece_len * (self.cur_piece + 1)
        return sampling_dataset.select(range(start, end))

    def run(self):
        """Sample the selected shard batch by batch and publish the result as ``args.output_file``.

        Returns immediately if the output file already exists and ``args.override_exist_file``
        is false. Otherwise an existing output file is renamed with a ``.<unix-seconds>`` suffix
        before the new file is moved into place. Note: caps
        ``args.num_sampling_per_gpu_batches`` at the number of available batches (mutates args).
        """
        args = self.args
        os.makedirs(args.output_dir, exist_ok=True)
        iter_file = os.path.join(args.output_dir, args.output_file)
        resume_file = os.path.join(args.output_dir, args.output_file + '.resume')
        tmp_file = os.path.join(args.output_dir, args.output_file + '.tmp')
        ckpt_state_file = os.path.join(args.output_dir, 'ckpt_state.json')
        if os.path.exists(iter_file) and not args.override_exist_file:
            return

        index_resume = -1
        write_mode = 'w'
        if args.resume:
            # Continue appending onto a copy of the last checkpointed output, skipping batches
            # whose index is <= the recorded one.
            write_mode = 'a'
            if os.path.exists(resume_file):
                shutil.copyfile(resume_file, tmp_file)
            if os.path.exists(ckpt_state_file):
                with open(ckpt_state_file, 'r', encoding='utf-8') as ckpt_state:
                    index_resume = json.load(ckpt_state).get('index', -1)
                logger.info(f'Loaded index_resume: {index_resume}')
        elif os.path.exists(tmp_file):
            os.remove(tmp_file)

        dataset = self._get_dataset()
        batch_size = args.num_sampling_per_gpu_batch_size
        # Ceil division so the trailing partial batch is sampled as well; floor division dropped
        # it and yielded zero batches for shards smaller than one batch.
        total_iters = int((len(dataset) + batch_size - 1) // batch_size)
        if args.num_sampling_per_gpu_batches is None or args.num_sampling_per_gpu_batches > total_iters:
            args.num_sampling_per_gpu_batches = total_iters

        with open(tmp_file, write_mode, encoding='utf-8') as f:
            for _index in range(args.num_sampling_per_gpu_batches):
                if _index <= index_resume:
                    continue  # already sampled by a previous, interrupted run
                logger.info(f' Sampling index:{_index}')
                start = batch_size * _index
                slices = dataset[start:start + batch_size]
                slices = self.sampler.truncate_input(slices)
                generated = self.sampler.do_sample(slices)
                f.writelines(generated)
                f.flush()
                # Checkpoint: snapshot everything written so far and record the finished batch.
                shutil.copy(tmp_file, resume_file)
                with open(ckpt_state_file, 'w', encoding='utf-8') as ckpt_state:
                    json.dump({'index': _index}, ckpt_state)

        # Refresh the snapshot from the now-closed temp file. If no batch ran in this invocation,
        # ``resume_file`` could be missing (shutil.move would raise) or stale from an older run.
        shutil.copyfile(tmp_file, resume_file)
        if os.path.exists(iter_file):
            shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
        shutil.move(resume_file, iter_file)
        logger.info(f'Sample file {iter_file} generated.')
def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
    """Command entry point: build a ``SwiftSampling`` pipeline from ``args`` and run it.

    Args:
        args: CLI-style argument list, a ``SamplingArguments`` instance, or ``None``.

    Returns:
        Whatever ``SwiftSampling.main()`` returns.
    """
    pipeline = SwiftSampling(args)
    return pipeline.main()