Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| def encode_lyrics(text_encoder, text_tokenizer, lyrics: str, device, dtype): | |
| """Encode lyrics into hidden states.""" | |
| lyric_inputs = text_tokenizer( | |
| lyrics, | |
| padding="max_length", | |
| max_length=512, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| lyric_input_ids = lyric_inputs.input_ids.to(device) | |
| lyric_attention_mask = lyric_inputs.attention_mask.to(device).to(dtype) | |
| # Align tensor residency to the actual text encoder device to avoid | |
| # CPU/CUDA mismatch in embedding/index_select calls. | |
| text_dev = next(text_encoder.parameters()).device | |
| if lyric_input_ids.device != text_dev: | |
| lyric_input_ids = lyric_input_ids.to(text_dev) | |
| lyric_attention_mask = lyric_attention_mask.to(text_dev) | |
| with torch.no_grad(): | |
| lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype) | |
| return lyric_hidden_states, lyric_attention_mask | |