Spaces:
Running
on
Zero
Running
on
Zero
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| from torch import nn, Tensor | |
| from .perceiver import Perceiver | |
| from .t3_config import T3Config | |
@dataclass
class T3Cond:
    """
    Dataclass container for most / all conditioning info.

    TODO: serialization methods aren't used, keeping them around for convenience
    """

    speaker_emb: Tensor
    clap_emb: Optional[Tensor] = None
    cond_prompt_speech_tokens: Optional[Tensor] = None
    cond_prompt_speech_emb: Optional[Tensor] = None
    # NOTE: the default is a plain Python float rather than a Tensor; `to()`
    # only touches tensor-valued fields, so a float default is left as-is.
    emotion_adv: Optional[Tensor] = 0.5

    def to(self, *, device=None, dtype=None):
        """
        Cast all tensor fields to a device and dtype, in place.

        Dtype casting is ignored for non-floating-point (e.g. long/int)
        tensors; those are only moved to `device`. Non-tensor fields are left
        untouched.

        Returns:
            `self`, so calls can be chained.
        """
        # Snapshot the items so reassigning fields cannot disturb iteration.
        for name, value in list(self.__dict__.items()):
            if not torch.is_tensor(value):
                continue
            # Decide from the tensor's dtype rather than by inspecting an
            # element: this also works for empty and non-contiguous tensors.
            target_dtype = dtype if value.is_floating_point() else None
            setattr(self, name, value.to(device=device, dtype=target_dtype))
        return self

    def save(self, fpath):
        """Serialize all fields to `fpath` as a plain dict via `torch.save`."""
        torch.save(self.__dict__, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        """
        Build an instance from a file written by `save`.

        Args:
            fpath: path (or file-like object) accepted by `torch.load`.
            map_location: forwarded to `torch.load`; defaults to "cpu".

        Returns:
            A new instance whose fields are the entries of the saved dict.
        """
        # weights_only=True asks torch.load for its restricted unpickler.
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(**kwargs)
class T3CondEnc(nn.Module):
    """
    Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
    """

    def __init__(self, hp: "T3Config"):
        """
        Build the conditioning sub-modules selected by `hp`.

        Reads `hp.encoder_type`, `hp.speaker_embed_size`, `hp.n_channels`,
        `hp.emotion_adv` and `hp.use_perceiver_resampler`.

        Raises:
            NotImplementedError: if `hp.encoder_type` is not "voice_encoder".
        """
        super().__init__()
        self.hp = hp

        if hp.encoder_type == "voice_encoder":
            self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
        else:
            raise NotImplementedError(str(hp.encoder_type))

        # emotion adv: optional bias-free projection of a scalar to n_channels
        self.emotion_adv_fc = None
        if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

        # perceiver resampler: optional, applied to the prompt speech embeddings
        self.perceiver = None
        if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()

    def forward(self, cond: "T3Cond"):
        """
        Project and concatenate all conditioning inputs along the sequence axis.

        Args:
            cond: conditioning container. `speaker_emb` is reshaped to
                (-1, hp.speaker_embed_size). `cond_prompt_speech_tokens` and
                `cond_prompt_speech_emb` must both be set or both be None.
                `clap_emb` must be None. `emotion_adv` may be a tensor or a
                plain Python number, which is applied to every batch row.

        Returns:
            Tensor of shape (B, L, dim): speaker, CLAP (currently empty),
            prompt and emotion embeddings concatenated on dim 1.
        """
        # Validate
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "no embeddings for cond_prompt_speech_tokens"

        # Speaker embedding projection
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]  # (B, 1, dim)
        empty = torch.zeros_like(cond_spkr[:, :0])  # (B, 0, dim)

        # TODO CLAP
        assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty  # (B, 0, dim)

        # Cond prompt
        cond_prompt_speech_emb = cond.cond_prompt_speech_emb
        if cond_prompt_speech_emb is None:
            cond_prompt_speech_emb = empty  # (B, 0, dim)
        elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

        # Emotion Adv: must provide a value if this model uses emotion conditioning
        cond_emotion_adv = empty  # (B, 0, dim)
        if self.hp.emotion_adv:
            assert cond.emotion_adv is not None
            emotion_adv = cond.emotion_adv
            if not torch.is_tensor(emotion_adv):
                # A plain number has no `.view`; turn it into one value per
                # batch row, matching the speaker projection's dtype/device.
                emotion_adv = cond_spkr.new_full((cond_spkr.size(0),), float(emotion_adv))
            cond_emotion_adv = self.emotion_adv_fc(emotion_adv.view(-1, 1, 1))  # (B, 1, dim)

        # Concat and return
        cond_embeds = torch.cat((
            cond_spkr,
            cond_clap,
            cond_prompt_speech_emb,
            cond_emotion_adv,
        ), dim=1)
        return cond_embeds