ACE-Step Custom
Deploy ACE-Step Custom Edition with bug fixes
a602628
import torch
from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT
from .models import AudioSample
from .preprocess_utils import build_metas_str
def build_text_prompt(sample: AudioSample, tag_position: str, use_genre: bool) -> str:
    """Assemble the prompt string that is fed to the text encoder.

    The ``SFT_GEN_PROMPT`` template is filled, in order, with the default DiT
    instruction, the caption returned by ``sample.get_training_prompt`` and
    the metadata string produced by ``build_metas_str`` for the same sample.

    Args:
        sample: Audio sample supplying the caption and metadata.
        tag_position: Forwarded unchanged to ``sample.get_training_prompt``.
        use_genre: Forwarded as the ``use_genre`` keyword of
            ``sample.get_training_prompt``.

    Returns:
        The formatted prompt string.
    """
    return SFT_GEN_PROMPT.format(
        DEFAULT_DIT_INSTRUCTION,
        sample.get_training_prompt(tag_position, use_genre=use_genre),
        build_metas_str(sample),
    )
def encode_text(
    text_encoder,
    text_tokenizer,
    text_prompt: str,
    device,
    dtype,
    max_length: int = 256,
):
    """Encode a caption/genre prompt into text-encoder hidden states.

    Args:
        text_encoder: Module invoked as ``text_encoder(input_ids)``; the
            result must expose ``last_hidden_state``.
        text_tokenizer: Callable tokenizer whose result exposes ``input_ids``
            and ``attention_mask`` tensors.
        text_prompt: Prompt string to tokenize.
        device: Fallback device, used only when the encoder has no parameters
            from which its device can be read.
        dtype: dtype that the returned hidden states and attention mask are
            cast to.
        max_length: Length the prompt is padded / truncated to. Defaults to
            256.

    Returns:
        Tuple ``(text_hidden_states, text_attention_mask)``. Both tensors
        live on the device the encoder ran on and are cast to ``dtype``.
    """
    tokens = text_tokenizer(
        text_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    # The encoder may sit on a different device than the caller's ``device``
    # (e.g. when it is offloaded), so its own parameters decide where the
    # inputs go. Resolving this once avoids a needless intermediate transfer.
    first_param = next(text_encoder.parameters(), None)
    target_device = device if first_param is None else first_param.device

    text_input_ids = tokens.input_ids.to(target_device)
    text_attention_mask = tokens.attention_mask.to(device=target_device, dtype=dtype)

    # Only the encoder's output is needed here; no autograd graph is recorded.
    with torch.no_grad():
        text_hidden_states = text_encoder(text_input_ids).last_hidden_state.to(dtype)

    return text_hidden_states, text_attention_mask