from typing import Any, Optional

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # When LoRA is used, keep only the LoRA weights in the checkpoint and
        # drop everything else to reduce its size.
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Split parameters into two groups: weight decay is disabled for
        # biases, normalization weights, and embeddings.
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Log the weight decay applied to each parameter group.
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

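    # Illustrative sketch (an assumption, not part of the original code): the
    # ``optimizer`` and ``lr_scheduler`` constructor arguments are expected to be
    # callables that receive, respectively, the parameter-group list and the built
    # optimizer, e.g. ``functools.partial`` objects supplied by the training config:
    #
    #     from functools import partial
    #     import torch
    #
    #     optimizer_builder = partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.01)
    #     scheduler_builder = partial(
    #         torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda step: 1.0
    #     )
    #
    # With such a setup, the builder's ``weight_decay`` applies only to the first
    # parameter group above, since the second group overrides it with 0.0.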

    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Unnormalized logits of the model. Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size, sequence_length) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # Replace ignored (-100) positions with a dummy token id; their
        # contributions are masked out below.
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

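    # Illustrative usage of ``get_batch_logps`` (assumed shapes, for reference only):
    # given an instance ``lit_module`` (name assumed) and the shapes from the
    # docstring, the reduction runs over the last label dimension.
    #
    #     logits = torch.randn(2, 8, 4, 1024)              # (batch, seq, codebooks, vocab)
    #     labels = torch.randint(0, 1024, (2, 8, 4))       # (batch, seq, codebooks)
    #     labels[:, -1] = -100                              # mask out the last position
    #     logps = lit_module.get_batch_logps(logits, labels, average_log_prob=True)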

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        if is_train:
            # Make sure the model is in training mode for the train stage
            # (e.g. so that dropout is active).
            self.model.train()

        # Forward pass. The batch provides token ids, per-position labels, and a
        # key padding mask; the labels stack the base-token row followed by one
        # row per codebook.
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Cross-entropy on the base (text) tokens, i.e. the first label row.
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Cross-entropy on the semantic codebook tokens; .mT moves the codebook
        # dimension last so the labels line up with codebook_logits.
        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy on the codebook predictions.
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        # Only count positions that are neither ignored (-100) nor codebook padding.
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        # A prediction counts as correct if the label appears in the top-5 logits.
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
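

# Illustrative wiring (an assumption, not part of this module): the LightningModule
# above is driven by a standard ``lightning.Trainer`` loop, with dataloaders that
# yield dicts containing "inputs", "labels", and "attention_masks".
#
#     import lightning as L
#
#     trainer = L.Trainer(max_steps=10_000, precision="bf16-mixed")
#     trainer.fit(lit_module, train_dataloaders=..., val_dataloaders=...)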