Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,356 Bytes
a602628 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT
from .models import AudioSample
from .preprocess_utils import build_metas_str
def build_text_prompt(sample: AudioSample, tag_position: str, use_genre: bool) -> str:
    """Assemble the full text-encoder prompt for one audio sample.

    Combines the shared DiT instruction, the sample's training caption
    (tag placement and genre usage controlled by the arguments), and the
    sample's metadata string via the SFT generation template.
    """
    return SFT_GEN_PROMPT.format(
        DEFAULT_DIT_INSTRUCTION,
        sample.get_training_prompt(tag_position, use_genre=use_genre),
        build_metas_str(sample),
    )
def encode_text(text_encoder, text_tokenizer, text_prompt: str, device, dtype, *, max_length: int = 256):
    """Encode a caption/genre prompt into text hidden states.

    Args:
        text_encoder: Module whose forward returns an object with
            ``last_hidden_state`` (e.g. a HF text encoder).
        text_tokenizer: Callable tokenizer accepting ``padding``,
            ``max_length``, ``truncation`` and ``return_tensors`` kwargs.
        text_prompt: The prompt string to encode.
        device: Requested device for the tokenized tensors. NOTE(review):
            the encoder's own parameter device takes precedence when it
            differs, matching the original behavior.
        dtype: Target floating dtype for the returned tensors.
        max_length: Padded/truncated token length (default 256, the
            original hard-coded value).

    Returns:
        Tuple of ``(text_hidden_states, text_attention_mask)`` — hidden
        states cast to ``dtype``; the mask cast to ``dtype`` and living on
        the encoder's device.
    """
    text_inputs = text_tokenizer(
        text_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Move tensors straight to the encoder's actual parameter device.
    # The original code moved them to `device` first and then again to
    # the encoder's device when the two differed — a wasted transfer.
    # The final placement is identical in all cases.
    text_dev = next(text_encoder.parameters()).device
    text_input_ids = text_inputs.input_ids.to(text_dev)
    text_attention_mask = text_inputs.attention_mask.to(text_dev).to(dtype)
    # NOTE(review): the attention mask is not passed to the encoder here
    # (original behavior preserved) — confirm the encoder handles padding
    # internally or does not need the mask.
    with torch.no_grad():
        text_outputs = text_encoder(text_input_ids)
    text_hidden_states = text_outputs.last_hidden_state.to(dtype)
    return text_hidden_states, text_attention_mask
|