from typing import Any, Dict, List

from swift.llm import SamplingArguments
from swift.plugin import orms, prms
from swift.utils import get_logger

logger = get_logger()


class Sampler:
    """Base class for samplers.

    On construction it prepares, in this order, the processor, the template
    and the optional process/outcome reward models (PRM/ORM) described by
    ``input_args``. Subclasses must implement :meth:`do_sample`.
    """

    def __init__(self, input_args: SamplingArguments):
        """Store the arguments and run the preparation steps.

        Args:
            input_args: Sampling configuration. This class reads
                ``prm_model`` and ``orm_model`` from it and calls its
                ``get_model_processor`` and ``get_template`` methods.
        """
        self.args = input_args
        self.template = None
        self.processor = None
        self.prm_model = None
        self.orm_model = None

        # Order matters: the template is built from the processor.
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_rm()

    def _prepare_model_tokenizer(self):
        """Obtain the processor from the arguments.

        ``get_model_processor`` is called with ``load_model=False``; only the
        second element of the returned pair is kept.
        """
        _, self.processor = self.args.get_model_processor(load_model=False)

    @staticmethod
    def _load_reward_model(model_name, registry, arg_name: str, max_batch_size: int = 64):
        """Resolve one reward-model setting into a usable object.

        Args:
            model_name: Value of the ``prm_model``/``orm_model`` argument.
            registry: Mapping from plugin name to a zero-argument factory
                (``prms`` or ``orms``).
            arg_name: Argument name used in the warning message.
            max_batch_size: Batch size passed to ``PtEngine`` when the value
                is not a registered plugin name.

        Returns:
            ``None`` when ``model_name`` is ``None`` (a warning is logged),
            the instantiated plugin when ``model_name`` is a key of
            ``registry``, otherwise a ``PtEngine`` built from ``model_name``.
        """
        if model_name is None:
            logger.warning(f'{arg_name} is None.')
            return None
        if model_name in registry:
            return registry[model_name]()
        # Imported lazily, as in the original code, so PtEngine is only
        # pulled in when it is actually needed.
        from swift.llm import PtEngine
        return PtEngine(model_name, max_batch_size=max_batch_size)

    def _prepare_rm(self):
        """Set ``self.prm_model`` and ``self.orm_model`` from the arguments."""
        self.prm_model = self._load_reward_model(self.args.prm_model, prms, 'prm_model')
        self.orm_model = self._load_reward_model(self.args.orm_model, orms, 'orm_model')

    def _prepare_template(self) -> None:
        """Build the template from the processor and set its mode to 'train'."""
        self.template = self.args.get_template(self.processor)
        self.template.set_mode('train')

    def truncate_input(self, slices: List[Dict[str, Any]]):
        """Truncate the input rows to avoid hitting the max length of the policy model.

        The base implementation performs no truncation and returns ``slices``
        unchanged.
        """
        return slices

    def do_sample(self, data):
        """Run sampling on ``data``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError