Spaces:
Running
on
Zero
Running
on
Zero
File size: 944 Bytes
a602628 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
def encode_lyrics(
    text_encoder,
    text_tokenizer,
    lyrics: str,
    device,
    dtype,
    max_length: int = 512,
):
    """Encode lyrics into token-embedding hidden states.

    The lyrics are tokenized to a fixed ``max_length`` (padded and truncated)
    and passed through ``text_encoder.embed_tokens`` only; the rest of the
    encoder stack is not run.

    Args:
        text_encoder: Module exposing an ``embed_tokens`` submodule.
        text_tokenizer: Tokenizer callable accepting ``padding``,
            ``max_length``, ``truncation`` and ``return_tensors`` keyword
            arguments and returning an object with ``input_ids`` and
            ``attention_mask`` tensors.
        lyrics: Raw lyric text to encode.
        device: Fallback device, used only when neither ``embed_tokens`` nor
            ``text_encoder`` has any parameters to infer a device from.
        dtype: Dtype for the returned hidden states and attention mask.
        max_length: Token length to pad/truncate to (default 512).

    Returns:
        Tuple ``(hidden_states, attention_mask)``, both cast to ``dtype`` and
        placed on the embedding table's device. NOTE(review): that device can
        differ from ``device``; callers needing a specific device must move
        the outputs themselves.
    """
    lyric_inputs = text_tokenizer(
        lyrics,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    # The embedding lookup (index_select) runs where the embedding weights
    # live, so the token ids must be on that device to avoid a CPU/CUDA
    # mismatch. Probe the embedding module itself first, because its device
    # can differ from the encoder's first parameter (e.g. when the model is
    # split across devices). Moving straight to the target also avoids a
    # redundant round trip through `device`.
    embed = text_encoder.embed_tokens
    param = None
    if isinstance(embed, torch.nn.Module):
        param = next(embed.parameters(), None)
    if param is None:
        param = next(text_encoder.parameters(), None)
    target_dev = param.device if param is not None else device

    lyric_input_ids = lyric_inputs.input_ids.to(target_dev)
    # The mask is returned in the working dtype (not integer) for callers.
    lyric_attention_mask = lyric_inputs.attention_mask.to(target_dev, dtype=dtype)

    with torch.no_grad():
        lyric_hidden_states = embed(lyric_input_ids).to(dtype)
    return lyric_hidden_states, lyric_attention_mask
|