"""Helpers for building the text prompt and encoding it with the text encoder."""

import torch

from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT

from .models import AudioSample
from .preprocess_utils import build_metas_str


def build_text_prompt(sample: AudioSample, tag_position: str, use_genre: bool) -> str:
    """Build the text prompt for the text encoder.

    Args:
        sample: Sample whose ``get_training_prompt`` supplies the caption and
            which is passed to ``build_metas_str`` for the metadata string.
        tag_position: Forwarded unchanged to ``sample.get_training_prompt``.
        use_genre: Forwarded to ``sample.get_training_prompt`` as the
            ``use_genre`` keyword.

    Returns:
        ``SFT_GEN_PROMPT`` formatted positionally with
        ``DEFAULT_DIT_INSTRUCTION``, the caption and the metadata string,
        in that order.
    """
    caption = sample.get_training_prompt(tag_position, use_genre=use_genre)
    metas_str = build_metas_str(sample)
    return SFT_GEN_PROMPT.format(DEFAULT_DIT_INSTRUCTION, caption, metas_str)


def encode_text(
    text_encoder,
    text_tokenizer,
    text_prompt: str,
    device,
    dtype,
    max_length: int = 256,
):
    """Encode caption/genre prompt into text hidden states.

    The prompt is tokenized with padding and truncation to ``max_length``
    tokens, run through ``text_encoder`` with gradients disabled, and the
    encoder's ``last_hidden_state`` is cast to ``dtype``.

    Args:
        text_encoder: Callable module exposing ``parameters()`` and returning
            an object with a ``last_hidden_state`` attribute.
        text_tokenizer: Callable tokenizer returning an object with
            ``input_ids`` and ``attention_mask`` tensors.
        text_prompt: Prompt text to encode.
        device: Device used only when ``text_encoder`` has no parameters to
            infer a device from.
        dtype: Dtype of the returned hidden states and attention mask.
        max_length: Length the token sequence is padded/truncated to.

    Returns:
        A ``(text_hidden_states, text_attention_mask)`` tuple. The mask is on
        the device the encoder was run on and is cast to ``dtype``.
    """
    text_inputs = text_tokenizer(
        text_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    # The encoder's own device is authoritative: inputs must live where its
    # weights are. Move there directly instead of hopping through ``device``
    # first, and fall back to ``device`` for a parameterless encoder rather
    # than letting ``next()`` raise StopIteration.
    first_param = next(text_encoder.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    text_input_ids = text_inputs.input_ids.to(target_device)
    text_attention_mask = text_inputs.attention_mask.to(
        device=target_device, dtype=dtype
    )

    # NOTE(review): the encoder is called with input ids only, so the
    # attention mask is not applied inside the encoder; presumably the
    # returned mask is consumed downstream -- confirm.
    with torch.no_grad():
        text_outputs = text_encoder(text_input_ids)
        text_hidden_states = text_outputs.last_hidden_state.to(dtype)

    return text_hidden_states, text_attention_mask