File size: 534 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Union
from ..argument import MegatronTrainArguments
from .sft import MegatronSft
class MegatronPt(MegatronSft):
    """Megatron pipeline variant built on top of ``MegatronSft``.

    It uses ``MegatronTrainArguments`` as its argument class and overrides
    only template preparation; everything else is inherited unchanged.

    NOTE(review): the ``Pt`` suffix presumably stands for pre-training,
    given the changes below — confirm with the project docs.
    """

    args_class = MegatronTrainArguments
    # Annotation only (no attribute is created); it refers to the same class
    # object as ``args_class``.
    args: MegatronTrainArguments

    def _prepare_template(self) -> None:
        """Prepare ``self.template`` without the chat template format.

        ``use_chat_template`` is switched off on ``self.args`` before the
        parent implementation runs. Afterwards, the resulting template's
        ``loss_scale`` is overwritten with ``'all'``.
        """
        # Flip the flag before delegating, presumably so the parent's template
        # construction observes it.
        self.args.use_chat_template = False
        super(MegatronPt, self)._prepare_template()
        template = self.template
        template.loss_scale = 'all'
def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):
    """Construct a ``MegatronPt`` pipeline from ``args`` and run it.

    Args:
        args: a list of strings, a ready-made ``MegatronTrainArguments``
            instance, or ``None``. The value is forwarded unchanged to the
            ``MegatronPt`` constructor.

    Returns:
        Whatever ``MegatronPt.main()`` returns.
    """
    pipeline = MegatronPt(args)
    return pipeline.main()
|