# Student0809's picture
# Add files using upload-large-folder tool
# cb2428f verified
from typing import Any, Dict, List
from swift.llm import SamplingArguments
from swift.plugin import orms, prms
from swift.utils import get_logger
# Module-level logger; used below to warn when a reward model is not configured.
logger = get_logger()
class Sampler:
    """Base class that performs the shared setup for sampling.

    On construction it stores the sampling arguments, obtains a processor
    from them, builds a template switched to ``'train'`` mode, and resolves
    the optional ``prm_model`` / ``orm_model`` (presumably the process and
    outcome reward models -- confirm against ``SamplingArguments``).

    ``do_sample`` is not implemented in this class.
    """

    def __init__(self, input_args: SamplingArguments):
        """Store ``input_args`` and run the three preparation steps in order.

        Args:
            input_args: Sampling configuration; it is kept as ``self.args``
                and consulted by every preparation step.
        """
        self.args = input_args

        # Placeholders so every attribute exists before the prepare steps run.
        self.template = self.processor = None
        self.prm_model = self.orm_model = None

        # Order matters: the template is built from the processor.
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_rm()

    def _prepare_model_tokenizer(self):
        """Fetch the processor from the arguments and keep it on the instance.

        ``get_model_processor`` is called with ``load_model=False``; its first
        return value is discarded and only the second (the processor) is kept.
        """
        _unused_model, processor = self.args.get_model_processor(load_model=False)
        self.processor = processor

    def _prepare_rm(self):
        """Resolve ``self.prm_model`` and ``self.orm_model`` from the arguments.

        For each of the two settings:

        * ``None``: the attribute stays ``None`` and a warning is logged;
        * a key of the matching plugin registry (``prms`` / ``orms``): the
          registered entry is called with no arguments;
        * anything else: the value is handed to ``PtEngine`` with
          ``max_batch_size=64``.
        """

        def _resolve(spec, registry, missing_message):
            # One shared decision path for both reward-model settings.
            if spec is None:
                logger.warning(missing_message)
                return None
            if spec in registry:
                return registry[spec]()
            # Imported lazily, only when an engine is actually needed.
            from swift.llm import PtEngine
            return PtEngine(spec, max_batch_size=64)

        self.prm_model = _resolve(self.args.prm_model, prms, 'prm_model is None.')
        self.orm_model = _resolve(self.args.orm_model, orms, 'orm_model is None.')

    def _prepare_template(self) -> None:
        """Build the template from the processor and switch it to 'train' mode."""
        self.template = self.args.get_template(self.processor)
        self.template.set_mode('train')

    def truncate_input(self, slices: List[Dict[str, Any]]):
        """Hook for truncating input rows to stay under the policy model's max length.

        This implementation performs no truncation.

        Args:
            slices: The input rows.

        Returns:
            The same ``slices`` object, unchanged.
        """
        return slices

    def do_sample(self, data):
        """Sample from ``data``; not implemented in this class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError