import json
import math
from dataclasses import asdict
from pathlib import Path

import hydra
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig
from safetensors.torch import save_file
from super_monotonic_align import maximum_path
from torch.optim.lr_scheduler import LambdaLR

from model.config import PlayHeadConfig
from playhead import PlayHead
from train_tts import TrainARTTS


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Return an LR multiplier for LambdaLR: linear warmup to 1.0, then a
    half-cosine decay from start_lr down to end_lr, expressed relative to
    start_lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
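
# Sanity check on the multiplier (worked numbers, using the defaults below):
# with warmup_steps=500, total_steps=300_000, start_lr=5e-4 and
# end_lr=5e-5, lr_lambda(0) == 0.0, lr_lambda(500) == 1.0, and at
# step == total_steps the cosine term vanishes, leaving
# end_lr / start_lr == 0.1 of the base learning rate.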


def expand(x, r):
    # Repeat every time step r times along the sequence axis:
    # (B, N, D) -> (B, r * N, D).
    b, n, d = x.shape
    return x.unsqueeze(2).repeat(1, 1, r, 1).reshape(b, r * n, d)
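
# Example (illustrative): for x of shape (1, 2, D) holding frames [a, b],
# expand(x, 3) returns shape (1, 6, D) ordered [a, a, a, b, b, b] — each
# frame is repeated in place rather than the whole sequence being tiled.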


class TrainPlayHead(ptl.LightningModule):
    def __init__(
        self,
        tts_checkpoint_path: str,
        playhead_config: PlayHeadConfig,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
    ):
        super().__init__()
        cfg = playhead_config
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.selected_cross_attention_heads = cfg.selected_cross_attention_heads
        self.avg_pool_stride = cfg.avg_pool_stride
        self.target_lag = cfg.target_lag
        self.save_hyperparameters()
        self.model = PlayHead(playhead_config)
        # Load the pretrained TTS model whose cross-attention maps supervise
        # the PlayHead, and freeze it: only self.model is trained here.
        tts_lightning_module = TrainARTTS.load_from_checkpoint(tts_checkpoint_path)
        self.tts_model = tts_lightning_module.model.eval()
        for p in self.tts_model.parameters():
            p.requires_grad = False

    def on_train_epoch_start(self):
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        # Export the PlayHead weights as safetensors alongside its config.
        cfg = self.hparams.playhead_config
        Path(dir).mkdir(exist_ok=True)
        model_path = Path(dir) / model_filename
        save_file(self.model.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)

    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch["crossatt_mask"]
        text_rel_pos = batch["text_rel_pos"]
        encoder_mask = batch["encoder_mask"]
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch["y_mask"]
        # Run the frozen TTS model only for its side effect: populating the
        # cross-attention maps on each decoder layer.
        with torch.inference_mode():
            _ = self.tts_model(
                text_ids=text_token,
                audio_inputs=audio_token,
                text_mask=encoder_mask,
                audio_mask=logits_mask,
                crossatt_mask=crossatt_mask,
                crossatt_rel_pos=crossatt_rel_pos,
                stop_tokens=stop_token,
                text_rel_pos=text_rel_pos,
                text_stop_tokens=text_stop_token,
            )
        atts = []
        for layer in self.tts_model.audio_decoder.decoder_layers:
            if layer.crossatt is not None:
                atts.append(layer.crossatt.att)
        num_sinks = self.tts_model.num_sink_tokens
        # Gather the configured (layer, head) alignment heads and sum them.
        selected_ca_heads = torch.stack(
            [
                atts[i][:, j].transpose(-1, -2)
                for i, j in self.selected_cross_attention_heads
            ]
        )
        summed_ca = selected_ca_heads.sum(0)
        # Average-pool the text axis (sink tokens excluded) to coarsen the
        # attention before alignment search.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            summed_ca[:, num_sinks:].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)
        # Monotonic alignment search over the pooled attention yields a hard
        # text-to-audio path.
        mas_from_avg_pool = maximum_path(
            avg_pool_ca.clone(),
            mask=crossatt_mask[:, :-1, :: self.avg_pool_stride].transpose(-1, -2),
        )
        target = torch.arange(mas_from_avg_pool.shape[1]).to(mas_from_avg_pool.device)
        if self.target_lag > 0:
            # Shift the path `lag` frames to the right so the target trails
            # the audio, pinning the first `lag` frames to text position 0.
            lag = self.target_lag
            mas_from_avg_pool = torch.roll(mas_from_avg_pool, lag, dims=2)
            mas_from_avg_pool[:, 0, :lag] = 1.0
            mas_from_avg_pool[:, 1:, :lag] = 0.0
            # logits_mask[:, :lag] = False
        # Collapse the path to one pooled text index per audio frame.
        target = (mas_from_avg_pool * target[:, None]).max(dim=1).values
        # Re-attach the unpooled sink-token rows as extra input rows.
        sink_ca = summed_ca[:, :num_sinks]
        input_ca = torch.cat((sink_ca, avg_pool_ca), dim=1)
        target = target % self.model.cycle_len
        return self.model(input_ca, target, logits_mask[:, :-1]), input_ca, target
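
    # Worked example of the lagged target (values illustrative, not from a
    # real batch): for a pooled path over 3 text positions and 5 audio frames
    #     [[1, 1, 0, 0, 0],
    #      [0, 0, 1, 0, 0],
    #      [0, 0, 0, 1, 1]]
    # the per-frame target is [0, 0, 1, 2, 2]. With target_lag=1, rolling one
    # frame to the right and pinning the first column to text position 0 gives
    # [0, 0, 0, 1, 2]: the PlayHead learns to report the text position one
    # pooled frame behind the audio, before the modulo-cycle_len wrap.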

    def training_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def validation_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # Fall back to 10% of the base LR when no end LR was given.
        end_lr = self.hparams.end_learning_rate
        if end_lr is None:
            end_lr = self.hparams.learning_rate * 0.1
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        return [opt], [{"scheduler": scheduler, "interval": "step"}]


# NOTE: a Hydra entry point is assumed here so that main() receives a cfg;
# the actual config_path/config_name are not shown in the source.
@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = "PlayHead"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()
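
# Hypothetical launch (the script name and override keys are assumptions,
# not taken from this repo); Hydra overrides follow the usual key=value form:
#
#     python train_playhead.py model.tts_checkpoint_path=/path/to/tts.ckpt \
#         trainer.max_steps=300000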