File size: 1,418 Bytes
4853fdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import torch
import torch.nn as nn
# Choose the device-type string handed to torch.amp.autocast further down.
# "npu" is used when the optional torch_npu package imports cleanly;
# otherwise the module falls back to "cuda".
try:
    import torch_npu
    # NOTE(review): transfer_to_npu is not referenced by name in the visible
    # code; presumably it is imported for its side effects — confirm.
    from torch_npu.contrib import transfer_to_npu
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
from .text_encoder import T5TextEncoder
class SketchT5TextEncoder(T5TextEncoder):
    """T5 text encoder extended with projections for f0 and energy inputs.

    Text encoding is delegated to ``T5TextEncoder.encode``. ``encode_sketch``
    passes f0 and energy tensors through the extra layers defined here.
    """

    def __init__(
        self,
        f0_dim: int,
        energy_dim: int,
        latent_dim: int,
        embed_dim: int,
        model_name: str = "google/flan-t5-large",
    ):
        """Build the parent encoder and the sketch projection layers.

        Args:
            f0_dim: Feature size expected by the f0 LayerNorm and projection.
            energy_dim: Feature size expected by the energy projection.
            latent_dim: Output feature size of both projections.
            embed_dim: Forwarded unchanged to ``T5TextEncoder.__init__``.
            model_name: Forwarded unchanged to ``T5TextEncoder.__init__``.
        """
        super().__init__(embed_dim=embed_dim, model_name=model_name)
        # Submodules are created in this order on purpose: it fixes the
        # registration order of parameters in the module's state.
        self.f0_proj = nn.Linear(f0_dim, latent_dim)
        self.f0_norm = nn.LayerNorm(f0_dim)
        self.energy_proj = nn.Linear(energy_dim, latent_dim)

    def encode(self, text: list[str]):
        """Return the parent's ``encode(text)`` result.

        The parent call runs with gradient tracking disabled and with
        autocast explicitly turned off for ``DEVICE_TYPE``.
        """
        with torch.no_grad():
            with torch.amp.autocast(device_type=DEVICE_TYPE, enabled=False):
                return super().encode(text)

    def encode_sketch(self, f0, energy):
        """Project f0 and energy and join them on a new trailing axis.

        Args:
            f0: Tensor fed through ``f0_norm`` and then ``f0_proj``.
                Assumes its last dimension equals ``f0_dim`` — TODO confirm
                against callers.
            energy: Tensor fed through ``energy_proj``. Assumes its last
                dimension equals ``energy_dim`` — TODO confirm.

        Returns:
            A dict with the single key ``"output"``. Its value is the
            concatenation, along the last axis, of the two projections after
            each has been given a trailing singleton axis.
        """
        normalized_f0 = self.f0_norm(f0)
        f0_latent = self.f0_proj(normalized_f0)
        energy_latent = self.energy_proj(energy)

        # Each projection gains a size-1 last axis, so the concatenation
        # below yields a last axis of size 2 (f0 first, energy second).
        joined = torch.cat(
            [f0_latent.unsqueeze(-1), energy_latent.unsqueeze(-1)],
            dim=-1,
        )
        return {"output": joined}
if __name__ == "__main__":
    # Manual smoke run: builds the base T5TextEncoder (not the sketch
    # subclass) and calls it on two captions. The result is not inspected.
    sample_captions = [
        "a man is speaking",
        "a woman is singing while a dog is barking",
    ]
    base_encoder = T5TextEncoder(embed_dim=512)
    encoded = base_encoder(sample_captions)
|