File size: 944 Bytes
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch


def encode_lyrics(
    text_encoder,
    text_tokenizer,
    lyrics: str,
    device,
    dtype,
    max_length: int = 512,
):
    """Tokenize lyrics and look up their token embeddings.

    The lyrics are tokenized with padding and truncation to ``max_length``
    tokens, then passed through ``text_encoder.embed_tokens`` with gradient
    tracking disabled.

    Args:
        text_encoder: Object exposing ``parameters()`` and a callable
            ``embed_tokens`` (presumably a transformer text encoder --
            confirm against the caller).
        text_tokenizer: Callable accepting ``padding``, ``max_length``,
            ``truncation`` and ``return_tensors`` keyword arguments and
            returning an object with ``input_ids`` and ``attention_mask``
            tensor attributes.
        lyrics: Text to encode.
        device: Fallback device, used only when ``text_encoder`` exposes no
            parameters. Otherwise the tensors follow the encoder's own
            device.
        dtype: Dtype applied to both returned tensors.
        max_length: Token length the tokenizer pads/truncates to. Defaults
            to 512, the previously hard-coded value.

    Returns:
        A ``(lyric_hidden_states, lyric_attention_mask)`` tuple. Both tensors
        are cast to ``dtype`` and live on the text encoder's device, which
        is not necessarily ``device``.
    """
    lyric_inputs = text_tokenizer(
        lyrics,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Resolve the destination once. The embedding lookup needs its indices on
    # the same device as the encoder weights (avoids CPU/CUDA mismatch in
    # embedding/index_select calls), so the encoder's device takes priority.
    # Moving straight there avoids a wasted hop through ``device`` when the
    # encoder has been offloaded elsewhere.
    first_param = next(text_encoder.parameters(), None)
    target_device = device if first_param is None else first_param.device

    lyric_input_ids = lyric_inputs.input_ids.to(target_device)
    lyric_attention_mask = lyric_inputs.attention_mask.to(
        device=target_device, dtype=dtype
    )

    with torch.no_grad():
        lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype)

    return lyric_hidden_states, lyric_attention_mask