# pardi-speech / tts / train_playhead.py
# Author: Mehdi Lakbar — initial demo of Lina-speech (pardi-speech), commit 56cfa73
import math
import hydra
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig
from super_monotonic_align import maximum_path
from torch.optim.lr_scheduler import LambdaLR
from model.config import PlayHeadConfig
from playhead import PlayHead
from train_tts import TrainARTTS
def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Build an LR lambda for ``LambdaLR``: linear warmup, then cosine decay.

    The returned callable maps a global step to a multiplicative factor of
    the optimizer's base LR (``start_lr``): it ramps linearly from 0 to 1
    over ``warmup_steps``, then follows a cosine from 1 down to
    ``end_lr / start_lr`` at ``total_steps``, and stays there afterwards.

    Args:
        warmup_steps: number of linear warmup steps.
        total_steps: step at which the decay reaches ``end_lr``.
        start_lr: peak learning rate (the optimizer's base LR).
        end_lr: final learning rate after decay.

    Returns:
        A ``step -> float`` callable suitable for ``LambdaLR(lr_lambda=...)``.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup; max(1, ...) guards against warmup_steps == 0.
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so the LR stays at end_lr instead of climbing back up the
        # cosine if training runs past total_steps.
        progress = min(progress, 1.0)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
def expand(x, r):
    """Repeat each time step of ``x`` (B, N, D) ``r`` times consecutively
    along the sequence axis, returning a (B, r*N, D) tensor."""
    return torch.repeat_interleave(x, r, dim=1)
class TrainPlayHead(ptl.LightningModule):
    """Train a PlayHead to track text/audio alignment from the cross-attention
    maps of a frozen, pretrained autoregressive TTS model.

    Each step runs the frozen TTS model to populate its cross-attention maps,
    distills the selected heads into a hard monotonic alignment (via MAS on an
    average-pooled map), and supervises the PlayHead on the resulting
    per-frame alignment position (modulo the PlayHead's cycle length).
    """

    def __init__(
        self,
        tts_checkpoint_path: str,
        playhead_config: PlayHeadConfig,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
    ):
        """
        Args:
            tts_checkpoint_path: Lightning checkpoint of a trained TrainARTTS.
            playhead_config: PlayHead architecture/feature configuration.
            learning_rate: peak learning rate after warmup.
            end_learning_rate: final LR of the cosine decay; when ``None``,
                defaults to ``learning_rate * 0.1``.
            weight_decay: AdamW weight decay applied to PlayHead parameters.
            betas: AdamW beta coefficients.
            n_warmup_steps: linear warmup length in optimizer steps.
            n_training_steps: total steps seen by the cosine schedule.
        """
        super().__init__()
        cfg = playhead_config
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.selected_cross_attention_heads = cfg.selected_cross_attention_heads
        self.avg_pool_stride = cfg.avg_pool_stride
        self.target_lag = cfg.target_lag
        self.save_hyperparameters()
        self.model = PlayHead(playhead_config)
        # Frozen teacher: the TTS model is only used to produce attention maps.
        tts_lightning_module = TrainARTTS.load_from_checkpoint(tts_checkpoint_path)
        self.tts_model = tts_lightning_module.model.eval()
        for p in self.tts_model.parameters():
            p.requires_grad = False

    def on_train_epoch_start(self):
        # Forward the epoch index to samplers that reshuffle per epoch
        # (e.g. distributed or bucketing batch samplers).
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        """Export PlayHead weights and config to ``dir``.

        TODO(review): not implemented yet — intended to write safetensors
        weights (``model_filename``) and a JSON config (``config_filename``).
        """
        pass

    def step(self, batch, batch_idx: int, validation: bool = False):
        """Run one distillation step on ``batch``.

        Args:
            batch: collated batch; requires "text_token", "audio_token",
                "crossatt_mask", "text_rel_pos", "encoder_mask", "y_mask",
                and optionally "stop_token", "text_stop_token",
                "crossatt_rel_pos".
            batch_idx: Lightning batch index (unused here).
            validation: kept for interface compatibility; currently unused.

        Returns:
            ``(losses, input_ca, target)`` — the PlayHead's loss dict, the
            pooled cross-attention features it consumed, and the per-frame
            alignment target (modulo the PlayHead cycle length).
        """
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch["crossatt_mask"]
        text_rel_pos = batch["text_rel_pos"]
        encoder_mask = batch["encoder_mask"]
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch["y_mask"]
        # Teacher forward pass solely to populate each decoder layer's
        # `crossatt.att`; the model output itself is discarded.
        with torch.inference_mode():
            _ = self.tts_model(
                text_ids=text_token,
                audio_inputs=audio_token,
                text_mask=encoder_mask,
                audio_mask=logits_mask,
                crossatt_mask=crossatt_mask,
                crossatt_rel_pos=crossatt_rel_pos,
                stop_tokens=stop_token,
                text_rel_pos=text_rel_pos,
                text_stop_tokens=text_stop_token,
            )
        atts = []
        for layer in self.tts_model.audio_decoder.decoder_layers:
            if layer.crossatt is not None:
                atts.append(layer.crossatt.att)
        num_sinks = self.tts_model.num_sink_tokens
        # Gather the (layer, head) pairs selected in the config and sum them
        # into a single cross-attention map per sample.
        selected_ca_heads = torch.stack(
            [
                atts[i][:, j].transpose(-1, -2)
                for i, j in self.selected_cross_attention_heads
            ]
        )
        summed_ca = selected_ca_heads.sum(0)
        # Downsample along the text axis (sink tokens excluded) before MAS.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            summed_ca[:, num_sinks:].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)
        # Monotonic alignment search: hard monotonic path through the pooled
        # attention map, restricted by the (strided) cross-attention mask.
        mas_from_avg_pool = maximum_path(
            avg_pool_ca.clone(),
            mask=crossatt_mask[:, :-1, :: self.avg_pool_stride].transpose(-1, -2),
        )
        target = torch.arange(mas_from_avg_pool.shape[1]).to(mas_from_avg_pool.device)
        if self.target_lag > 0:
            # Delay the alignment by `target_lag` frames so the PlayHead is
            # supervised slightly behind the true path; the first `lag`
            # frames are pinned to text position 0.
            lag = self.target_lag
            mas_from_avg_pool = torch.roll(mas_from_avg_pool, lag, dims=2)
            mas_from_avg_pool[:, 0, :lag] = 1.0
            mas_from_avg_pool[:, 1:, :lag] = 0.0
        # Index of the active (pooled) text position at every audio frame.
        target = (mas_from_avg_pool * target[:, None]).max(dim=1).values
        sink_ca = summed_ca[:, :num_sinks]
        input_ca = torch.cat((sink_ca, avg_pool_ca), dim=1)
        # The PlayHead predicts the phase within its cycle, not the absolute
        # text position.
        target = target % self.model.cycle_len
        return self.model(input_ca, target, logits_mask[:, :-1]), input_ca, target

    def _log_losses(self, losses, prefix: str):
        """Log each component loss as ``{prefix}_{name}`` plus their sum as
        ``{prefix}_loss``; return the summed loss."""
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"{prefix}_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log(f"{prefix}_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def training_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        return self._log_losses(losses, "train")

    def validation_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        return self._log_losses(losses, "val")

    def configure_optimizers(self):
        """AdamW on PlayHead parameters with a per-step warmup+cosine LR."""
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # Honor the explicit end_learning_rate when given; default to 10% of
        # the peak LR (previously the constructor argument was ignored).
        end_lr = (
            self.hparams.end_learning_rate
            if self.hparams.end_learning_rate is not None
            else self.hparams.learning_rate * 0.1
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        return [opt], [{"scheduler": scheduler, "interval": "step"}]
@hydra.main(config_path="playhead_configs/", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    """Hydra entry point: seed everything, instantiate the Lightning module,
    datamodule, and trainer from the config, then launch training (optionally
    resuming from ``cfg.ckpt_path``)."""
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    # Plain string: no placeholders, so no f-prefix needed.
    cfg.experiment_name = "PlayHead"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()