File size: 1,356 Bytes
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch

from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT

from .models import AudioSample
from .preprocess_utils import build_metas_str


def build_text_prompt(sample: AudioSample, tag_position: str, use_genre: bool) -> str:
    """Assemble the prompt string that is fed to the text encoder.

    Args:
        sample: Sample providing ``get_training_prompt`` and the metadata
            consumed by ``build_metas_str``.
        tag_position: Forwarded unchanged to ``sample.get_training_prompt``.
        use_genre: Forwarded as the ``use_genre`` keyword of
            ``sample.get_training_prompt``.

    Returns:
        ``SFT_GEN_PROMPT`` filled, in order, with the default DiT
        instruction, the sample's training prompt, and its metadata string.
    """
    return SFT_GEN_PROMPT.format(
        DEFAULT_DIT_INSTRUCTION,
        sample.get_training_prompt(tag_position, use_genre=use_genre),
        build_metas_str(sample),
    )


def encode_text(text_encoder, text_tokenizer, text_prompt: str, device, dtype):
    """Tokenize ``text_prompt`` and run it through ``text_encoder``.

    The tokenizer is asked to pad/truncate to ``max_length=256`` and to
    return PyTorch tensors. Token ids and mask are first placed on
    ``device``; if the encoder's first parameter lives on a different
    device, both are moved there before the forward pass.

    Args:
        text_encoder: Callable model exposing ``parameters()`` whose output
            has a ``last_hidden_state`` attribute.
        text_tokenizer: Callable tokenizer whose output exposes
            ``input_ids`` and ``attention_mask`` tensors.
        text_prompt: Prompt string to encode.
        device: Initial target device for the tokenized tensors.
        dtype: Dtype applied to the returned hidden states and mask.

    Returns:
        A ``(hidden_states, attention_mask)`` tuple, both cast to ``dtype``.
    """
    tokenized = text_tokenizer(
        text_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=256,
        padding="max_length",
    )
    ids = tokenized.input_ids.to(device)
    mask = tokenized.attention_mask.to(device).to(dtype)

    # Follow the encoder's actual placement when it differs from `device`.
    encoder_device = next(text_encoder.parameters()).device
    if ids.device != encoder_device:
        ids, mask = ids.to(encoder_device), mask.to(encoder_device)

    # Only the token ids are passed to the encoder; the mask is returned
    # to the caller alongside the hidden states.
    with torch.no_grad():
        hidden = text_encoder(ids).last_hidden_state.to(dtype)

    return hidden, mask