File size: 5,738 Bytes
32d5b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2c5eb
32d5b2b
 
 
ba2c5eb
 
 
 
32d5b2b
ba2c5eb
32d5b2b
 
 
ba2c5eb
 
 
 
 
 
55ac664
ba2c5eb
32d5b2b
 
 
 
 
ba2c5eb
 
55ac664
ba2c5eb
 
 
 
 
 
 
a7dc8e9
ba2c5eb
 
a7dc8e9
ba2c5eb
 
 
 
 
 
 
 
 
 
 
32d5b2b
 
 
 
ca7dd21
32d5b2b
55ac664
ca7dd21
 
 
 
 
32d5b2b
b6ec358
 
 
 
55ac664
ca7dd21
 
 
 
55ac664
 
ca7dd21
b6ec358
ca7dd21
 
 
 
 
 
 
ba2c5eb
b6ec358
ca7dd21
b6ec358
ca7dd21
 
 
b6ec358
ca7dd21
 
 
 
5805255
 
ca7dd21
 
32d5b2b
ca7dd21
 
 
 
 
 
 
 
32d5b2b
 
 
a7dc8e9
ba2c5eb
 
32d5b2b
 
 
 
 
 
 
 
 
 
 
 
ba2c5eb
32d5b2b
 
ba2c5eb
32d5b2b
 
 
 
 
 
ba2c5eb
a7dc8e9
ba2c5eb
434855f
ba2c5eb
434855f
ba2c5eb
32d5b2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
from loguru import logger

from cosyvoice.utils.train_utils import (batch_backward, batch_forward,
                                         cosyvoice_join, log_per_save,
                                         log_per_step, save_model,
                                         update_parameter_and_lr)


class Executor:
    """Executor for training and cross validation.

    Holds global training state (optimizer-step and epoch counters, process
    rank, device) and drives one training epoch or one CV pass at a time.
    """

    def __init__(
        self,
        gan: bool = False,
        ref_model: Optional[torch.nn.Module] = None,
        dpo_loss: Optional[torch.nn.Module] = None,
        use_contrastive_fm: bool = False
    ):
        """Initialize training state.

        Args:
            gan: train in adversarial (GAN) mode; affects CV batch tagging.
            ref_model: frozen reference model (e.g. for DPO); kept in eval mode.
            dpo_loss: DPO loss module forwarded to ``batch_forward``, if any.
            use_contrastive_fm: enable contrastive flow-matching
                (presumably consumed elsewhere — not read in this class).
        """
        self.gan = gan
        self.ref_model = ref_model
        self.dpo_loss = dpo_loss
        self.step = 0   # counts optimizer updates, not micro-batches
        self.epoch = 0
        # RANK is set by torchrun/torch.distributed launchers; default 0 for
        # single-process runs.
        self.rank = int(os.environ.get("RANK", 0))
        self.device = torch.device(f"cuda:{self.rank}")
        self.use_contrastive_fm = use_contrastive_fm

    # NOTE(review): method name keeps the historical "epoc" spelling so
    # existing callers continue to work.
    def train_one_epoc(
        self,
        model,
        optimizer,
        scheduler,
        train_data_loader,
        experiment,
        info_dict,
        scaler,
        model_type
    ):
        """Train one epoch with gradient accumulation and periodic checkpoints.

        Args:
            model: model being trained (possibly DDP-wrapped).
            optimizer: optimizer whose first param group's lr is logged.
            scheduler: lr scheduler passed through to the update helper.
            train_data_loader: iterable of batch dicts.
            experiment: experiment/logging handle for ``log_per_step``.
            info_dict: mutable run config; must contain ``accum_grad`` and
                ``train_engine``; may contain ``save_per_step``.
            scaler: AMP grad scaler (or None).
            model_type: forwarded to ``update_parameter_and_lr``.
        """
        lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {self.epoch} TRAIN info lr {lr} rank {self.rank}"
        )
        logger.info(
            f"using accumulate grad, new batch size is {info_dict['accum_grad']} times larger than before"
        )

        model.train()
        if self.ref_model is not None:
            # Reference model only provides frozen outputs; never trained.
            self.ref_model.eval()

        use_ddp = info_dict["train_engine"] == "torch_ddp"

        for batch_idx, batch_dict in enumerate(train_data_loader):
            info_dict["tag"] = "TRAIN"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            # Debug-level only: per-batch tensor shapes (was a leftover
            # print() flooding stdout every step).
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    logger.debug(f'{key} {value.shape}')

            # Under DDP, skip the gradient all-reduce on micro-batches that
            # do not close an accumulation window; sync only on the last one.
            if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                context = model.no_sync
            else:
                context = nullcontext

            with context():
                logger.debug(f'{self.rank} batch_forward')
                info_dict = batch_forward(
                    model,
                    batch_dict,
                    scaler,
                    info_dict,
                    ref_model=self.ref_model,
                    dpo_loss=self.dpo_loss,
                )
                logger.debug(f'{self.rank} batch_backward')
                info_dict = batch_backward(model, scaler, info_dict)
            logger.debug(f'{self.rank} update_parameter_and_lr')
            info_dict = update_parameter_and_lr(
                model, optimizer, scheduler, scaler, info_dict, model_type=model_type
            )
            logger.debug(f'{self.rank} log_per_step')
            log_per_step(experiment, info_dict)

            # Periodic checkpoint: only on optimizer-update boundaries and
            # only when save_per_step is enabled (> 0).
            if (
                info_dict.get("save_per_step", -1) > 0
                and (self.step) % info_dict["save_per_step"] == 0
                and (batch_idx) % info_dict["accum_grad"] == 0
            ):
                if dist.is_initialized():
                    dist.barrier()
                model_name = (
                    f"epoch_{self.epoch}_step_{self.step + 1}"
                )
                save_model(model, model_name, info_dict)
                # save_model may flip the model to eval; restore train mode.
                model.train()

            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                self.step += 1
        # BUG FIX: the original called dist.barrier() unconditionally here,
        # which raises in non-distributed runs; guard it like the save path.
        if dist.is_initialized():
            dist.barrier()

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, experiment, info_dict, on_batch_end=True):
        """Run cross validation, log utterance-weighted mean losses, and save.

        Args:
            model: model to evaluate (switched to eval mode).
            cv_data_loader: iterable of batch dicts with an ``utts`` list.
            experiment: experiment/logging handle for ``log_per_save``.
            info_dict: mutable run config; its ``loss_dict`` is replaced by
                the aggregated CV losses.
            on_batch_end: True when called at epoch end (names the checkpoint
                ``epoch_N_whole``), else ``epoch_N_step_M``.
        """
        logger.info(f"Epoch {self.epoch} Step {self.step + 1} on_batch_end {on_batch_end} CV rank {self.rank}")
        model.eval()
        # Accumulate per-key losses weighted by utterance count; division
        # below only happens for keys seen in at least one batch.
        total_num_utts, total_loss_dict = 0, {}
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts

            if self.gan:
                # CV only evaluates the generator branch in GAN mode.
                batch_dict["turn"] = "generator"
            info_dict = batch_forward(model, batch_dict, None, info_dict)

            for k, v in info_dict["loss_dict"].items():
                if k not in total_loss_dict:
                    total_loss_dict[k] = []
                total_loss_dict[k].append(v.item() * num_utts)
            log_per_step(None, info_dict)
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = sum(v) / total_num_utts
        info_dict["loss_dict"] = total_loss_dict
        log_per_save(experiment, info_dict)
        model_name = (
            f"epoch_{self.epoch}_whole"
            if on_batch_end
            else f"epoch_{self.epoch}_step_{self.step + 1}"
        )
        save_model(model, model_name, info_dict)