import json
import math
from dataclasses import asdict
from pathlib import Path
import hydra
import numpy as np
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from safetensors.torch import save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import get_cosine_schedule_with_warmup
from .model.config import TTSConfig
# DiffusionHead is referenced in step() below; assumed to live alongside VelocityHead.
from .model.prediction_head import DiffusionHead, VelocityHead
from .tts import ARTTSModel

def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
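
# Rough shape of the multiplier returned by lr_lambda (values follow from the
# formula above, assuming e.g. warmup_steps=500, total_steps=300_000):
#   step 0        -> 0.0                 (linear warmup from zero)
#   step 500      -> 1.0                 (lr reaches start_lr)
#   step 300_000  -> end_lr / start_lr   (cosine decay bottoms out at end_lr)
# LambdaLR multiplies the optimizer's base lr by this factor at every scheduler step.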

class TrainARTTS(ptl.LightningModule):
    def __init__(
        self,
        config: TTSConfig,
        quant_layer: list[int],
        tie_embed: bool = False,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
        mask_text_p: float = 0.0,
        load_weights: str | None = None,
        stop_token_weight: float | None = None,
        stop_loss_factor: float = 0.1,
        stop_loss_warmup: tuple[int, int] | None = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.stop_token_weight = stop_token_weight
        self.stop_loss_factor = stop_loss_factor
        self.save_hyperparameters()
        self.model = ARTTSModel(config)
        if load_weights is not None:
            model = torch.load(load_weights)
            self.load_state_dict(model["state_dict"], strict=False)

    def on_train_epoch_start(self):
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
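    # Forwarding the epoch index above lets samplers that implement `set_epoch`
    # (e.g. distributed or bucketing batch samplers) reshuffle deterministically
    # each epoch instead of repeating the same batch order.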

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        def to_builtin(obj):
            if isinstance(obj, dict):
                return {k: to_builtin(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, ListConfig):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, DictConfig):
                return {k: to_builtin(v) for k, v in obj.items()}
            else:
                return obj

        cfg = asdict(self.hparams.config)
        cfg = to_builtin(cfg)
        for k, v in cfg.items():
            # Fallback for any OmegaConf containers that survived to_builtin.
            if isinstance(v, (ListConfig, DictConfig)):
                cfg[k] = OmegaConf.to_container(v, resolve=True)
        Path(dir).mkdir(exist_ok=True)
        model_path = Path(dir) / model_filename
        save_file(self.model.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(cfg, f, indent=2)
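    # Hedged reload sketch for the artifacts written above (assumes TTSConfig is a
    # plain dataclass that accepts the dumped JSON fields; "export_dir" is a placeholder):
    #   from safetensors.torch import load_file
    #   with open("export_dir/config.json") as f:
    #       cfg = TTSConfig(**json.load(f))
    #   model = ARTTSModel(cfg)
    #   model.load_state_dict(load_file("export_dir/model.st"))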

    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch.get("crossatt_mask")
        text_rel_pos = batch.get("text_rel_pos")
        encoder_mask = batch.get("encoder_mask")
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch.get("y_mask")
        pre_logits = self.model(
            text_ids=text_token,
            audio_inputs=audio_token,
            text_mask=encoder_mask,
            audio_mask=logits_mask,
            crossatt_mask=crossatt_mask,
            crossatt_rel_pos=crossatt_rel_pos,
            stop_tokens=stop_token,
            text_rel_pos=text_rel_pos,
            text_stop_tokens=text_stop_token,
        )
        losses = {}
        if validation and isinstance(self.model.prediction_head, DiffusionHead):
            # deterministic time conditioning during validation
            t = (
                torch.ones(pre_logits.shape[0], device=pre_logits.device)
                * batch_idx
                / self.trainer.num_val_batches[0]
            )
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
                t=t,
            )
        else:
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
            )
        if self.model.stop_prediction_head is not None and logits_mask is not None:
            if stop_token is None:
                # Derive stop targets from the padding mask: a frame whose next
                # position is padding (i.e. the last valid frame) gets a positive label.
                stop_token = nn.functional.pad(
                    (~logits_mask)[:, 2:].to(pre_logits), (0, 1)
                )
            else:
                stop_token = stop_token[:, 1:]
            mask = logits_mask[:, 1:]
            losses |= self.model.stop_prediction_head.compute_loss(
                pre_logits[mask],
                stop_token[mask],
            )
        return losses
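    # `step` returns a dict of named partial losses (prediction head, optional stop
    # head); training_step and validation_step below log each entry and sum them
    # into the scalar that is backpropagated / reported.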

    def training_step(self, batch, idx):
        losses = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            if "stop" in name:
                if self.hparams.stop_loss_warmup is not None:
                    alpha, beta = self.hparams.stop_loss_warmup
                    warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
                else:
                    warmup = 1.0
                loss *= self.stop_loss_factor * warmup
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def validation_step(self, batch, idx):
        losses = self.step(batch, idx, validation=True)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # scheduler = get_cosine_schedule_with_warmup(
        #     opt,
        #     num_warmup_steps=self.n_warmup_steps,
        #     num_training_steps=self.n_training_steps,
        # )
        end_lr = (
            self.hparams.end_learning_rate
            if self.hparams.end_learning_rate is not None
            else self.hparams.learning_rate * 0.1
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        # "interval": "step" makes Lightning advance the scheduler every optimizer
        # step rather than once per epoch.
        return [opt], [{"scheduler": scheduler, "interval": "step"}]
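
# Hypothetical launch sketch (the script name is an assumption; override keys mirror
# the cfg fields read in main below):
#   python train.py seed_everything=1234 ckpt_path=/path/to/last.ckpt
# @hydra.main composes hydra_configs/config.yaml with CLI overrides into `cfg`.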

@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()