File size: 534 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Union

from ..argument import MegatronTrainArguments
from .sft import MegatronSft


class MegatronPt(MegatronSft):
    """Variant of ``MegatronSft`` that customises template preparation.

    The chat template is switched off on the arguments before the parent
    class prepares the template, and the prepared template's ``loss_scale``
    is then set to ``'all'``.

    NOTE(review): "Pt" presumably stands for pre-training -- confirm.
    """

    # Argument class associated with this pipeline, and the type of ``args``.
    args_class = MegatronTrainArguments
    args: MegatronTrainArguments

    def _prepare_template(self) -> None:
        """Prepare the template with chat formatting off and loss scale ``'all'``."""
        train_args = self.args
        # Set before delegating; presumably the parent implementation reads
        # this flag while building the template -- confirm in MegatronSft.
        train_args.use_chat_template = False
        super()._prepare_template()

        # Override the loss scale on the template after the parent call.
        prepared_template = self.template
        prepared_template.loss_scale = 'all'


def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):
    """Build a ``MegatronPt`` from ``args`` and run its ``main`` method.

    Args:
        args: Forwarded unchanged to the ``MegatronPt`` constructor. May be a
            list of strings, a ``MegatronTrainArguments`` instance, or
            ``None`` (the default).

    Returns:
        Whatever ``MegatronPt.main()`` returns.
    """
    pipeline = MegatronPt(args)
    result = pipeline.main()
    return result