File size: 1,785 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import Any, Dict, List

from swift.llm import SamplingArguments
from swift.plugin import orms, prms
from swift.utils import get_logger

# Module-level logger; Sampler uses it to warn when `prm_model` / `orm_model` is None.
logger = get_logger()


class Sampler:
    """Base class for samplers configured by a ``SamplingArguments`` object.

    Construction loads the processor, builds the template (switched to
    ``'train'`` mode) and resolves the optional ``prm_model`` / ``orm_model``
    arguments into usable objects. Subclasses must implement
    :meth:`do_sample`.
    """

    def __init__(self, input_args: SamplingArguments):
        """Store *input_args* and prepare processor, template and reward models."""
        self.args = input_args
        self.template = None
        self.processor = None
        self.prm_model = None
        self.orm_model = None
        # Order matters: the template is built from the processor.
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_rm()

    def _prepare_model_tokenizer(self) -> None:
        """Load only the processor; the model itself is not loaded here."""
        _, self.processor = self.args.get_model_processor(load_model=False)

    @staticmethod
    def _resolve_rm(name, registry, arg_name: str):
        """Build the object selected by one of the ``*_model`` arguments.

        Args:
            name: Value of the argument: ``None``, a key of *registry*, or any
                other value, which is handed to ``PtEngine``.
            registry: Mapping from plugin name to a zero-argument factory.
            arg_name: Argument name used in the warning when *name* is ``None``.

        Returns:
            ``None`` when *name* is ``None``; otherwise the instantiated
            registry entry or a ``PtEngine`` built from *name*.
        """
        if name is None:
            logger.warning(f'{arg_name} is None.')
            return None
        if name in registry:
            return registry[name]()
        # Not a registered plugin: let PtEngine load it. The import is kept
        # local so that it only happens when this path is actually taken.
        from swift.llm import PtEngine
        return PtEngine(name, max_batch_size=64)

    def _prepare_rm(self) -> None:
        """Resolve ``args.prm_model`` and ``args.orm_model`` (in that order)."""
        self.prm_model = self._resolve_rm(self.args.prm_model, prms, 'prm_model')
        self.orm_model = self._resolve_rm(self.args.orm_model, orms, 'orm_model')

    def _prepare_template(self) -> None:
        """Create the template from the processor and put it in ``'train'`` mode."""
        self.template = self.args.get_template(self.processor)
        self.template.set_mode('train')

    def truncate_input(self, slices: List[Dict[str, Any]]):
        """Hook for truncating input rows to stay under the policy model's max length.

        This base implementation performs no truncation and returns *slices*
        unchanged (the very same list object).
        """
        return slices

    def do_sample(self, data):
        """Run sampling on *data*; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError