Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT | |
| from .models import AudioSample | |
| from .preprocess_utils import build_metas_str | |
def build_text_prompt(sample: AudioSample, tag_position: str, use_genre: bool) -> str:
    """Assemble the prompt string that is handed to the text encoder.

    The prompt is produced by filling ``SFT_GEN_PROMPT`` with, in order:
    the default DiT instruction, the sample's training caption, and the
    sample's metadata string.

    Args:
        sample: Audio sample providing ``get_training_prompt`` and the
            fields consumed by ``build_metas_str``.
        tag_position: Forwarded unchanged to ``sample.get_training_prompt``.
        use_genre: Forwarded to ``sample.get_training_prompt`` as the
            ``use_genre`` keyword.

    Returns:
        The formatted prompt string.
    """
    # Arguments are evaluated left to right, so the caption is still built
    # before the metadata string.
    return SFT_GEN_PROMPT.format(
        DEFAULT_DIT_INSTRUCTION,
        sample.get_training_prompt(tag_position, use_genre=use_genre),
        build_metas_str(sample),
    )
def encode_text(text_encoder, text_tokenizer, text_prompt: str, device, dtype):
    """Encode caption/genre prompt into text hidden states.

    Args:
        text_encoder: Model called with the token ids only; must expose
            ``parameters()`` and return an object with ``last_hidden_state``.
        text_tokenizer: Tokenizer callable returning an object with
            ``input_ids`` and ``attention_mask`` tensors.
        text_prompt: Prompt string to tokenize (padded/truncated to 256
            tokens).
        device: Fallback device for the inputs. It is only used when the
            encoder exposes no parameters; otherwise the inputs follow the
            encoder's own device.
        dtype: dtype the returned hidden states and attention mask are cast
            to.

    Returns:
        Tuple ``(text_hidden_states, text_attention_mask)``. The mask lives
        on the device the inputs were moved to; both tensors are cast to
        ``dtype``.
    """
    text_inputs = text_tokenizer(
        text_prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )

    # Follow the encoder's actual placement (it may differ from ``device``,
    # e.g. when the model is offloaded). ``next(..., None)`` avoids a bare
    # StopIteration for encoders that expose no parameters.
    first_param = next(text_encoder.parameters(), None)
    target_device = device if first_param is None else first_param.device

    # Move straight to the final device instead of hopping via ``device``.
    text_input_ids = text_inputs.input_ids.to(target_device)
    text_attention_mask = text_inputs.attention_mask.to(target_device).to(dtype)

    # NOTE(review): the attention mask is returned to the caller but is not
    # forwarded to the encoder, so padded positions are encoded as well --
    # confirm this matches what the downstream consumer expects.
    with torch.no_grad():
        text_outputs = text_encoder(text_input_ids)
    text_hidden_states = text_outputs.last_hidden_state.to(dtype)
    return text_hidden_states, text_attention_mask