# Extraction artifact (web page header), preserved as a comment so the file parses:
# Spaces: Running on Zero
# Running on Zero
import os
from typing import List, Tuple

import torch
from loguru import logger

from acestep.debug_utils import (
    debug_log_for,
    debug_log_verbose_for,
    debug_start_verbose_for,
    debug_end_verbose_for,
)

from .models import AudioSample
from .preprocess_audio import load_audio_stereo
from .preprocess_context import build_context_latents
from .preprocess_encoder import run_encoder
from .preprocess_lyrics import encode_lyrics
from .preprocess_manifest import save_manifest
from .preprocess_text import build_text_prompt, encode_text
from .preprocess_utils import select_genre_indices
from .preprocess_vae import vae_encode
class PreprocessMixin:
    """Preprocess labeled samples to tensor files."""

    def preprocess_to_tensors(
        self,
        dit_handler,
        output_dir: str,
        max_duration: float = 240.0,
        progress_callback=None,
    ) -> Tuple[List[str], str]:
        """Preprocess all labeled samples to tensor files for efficient training.

        For each labeled sample in ``self.samples`` this loads the audio,
        VAE-encodes it to target latents, encodes the text prompt and lyrics,
        runs the DiT encoder, builds context latents, and saves everything
        (plus a metadata dict) to ``<output_dir>/<sample.id>.pt``. A manifest
        covering the written files is saved at the end via ``save_manifest``.

        Args:
            dit_handler: Handler object providing ``model``, ``vae``,
                ``text_encoder``, ``text_tokenizer``, ``silence_latent``,
                ``device``, ``dtype`` and ``_load_model_context``.
            output_dir: Directory to write ``.pt`` tensor files into
                (created if missing).
            max_duration: Maximum audio duration in seconds passed to
                ``load_audio_stereo``.
            progress_callback: Optional callable taking a single status
                string; invoked per sample and on per-sample failure.

        Returns:
            Tuple of (list of written output paths, human-readable status
            string). Early-exits with an empty list and an error status when
            there are no samples, no labeled samples, or the model is not
            initialized.
        """
        debug_log_for("dataset", f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}")
        # Guard clauses: nothing to do / model not ready.
        if not self.samples:
            return [], "β No samples to preprocess"
        labeled_samples = [s for s in self.samples if s.labeled]
        if not labeled_samples:
            return [], "β No labeled samples to preprocess"
        if dit_handler is None or dit_handler.model is None:
            return [], "β Model not initialized. Please initialize the service first."
        os.makedirs(output_dir, exist_ok=True)
        output_paths: List[str] = []
        success_count = 0
        fail_count = 0
        # Resolve model components once, outside the per-sample loop.
        model = dit_handler.model
        vae = dit_handler.vae
        text_encoder = dit_handler.text_encoder
        text_tokenizer = dit_handler.text_tokenizer
        silence_latent = dit_handler.silence_latent
        device = dit_handler.device
        dtype = dit_handler.dtype
        target_sample_rate = 48000
        # Indices of samples whose genre tag will be used, chosen according to
        # the configured ratio (see select_genre_indices).
        genre_indices = select_genre_indices(labeled_samples, self.metadata.genre_ratio)
        debug_log_verbose_for("dataset", f"selected genre indices: count={len(genre_indices)}")
        for i, sample in enumerate(labeled_samples):
            # Per-sample failures are logged and counted; processing continues
            # with the next sample (see the except clause below).
            try:
                debug_log_verbose_for("dataset", f"sample[{i}] id={sample.id} file={sample.filename}")
                if progress_callback:
                    progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}")
                use_genre = i in genre_indices
                t0 = debug_start_verbose_for("dataset", f"load_audio_stereo[{i}]")
                audio, _ = load_audio_stereo(sample.audio_path, target_sample_rate, max_duration)
                debug_end_verbose_for("dataset", f"load_audio_stereo[{i}]", t0)
                debug_log_verbose_for("dataset", f"audio shape={tuple(audio.shape)} dtype={audio.dtype}")
                # Add batch dim, then match the VAE's device and dtype.
                audio = audio.unsqueeze(0).to(device).to(vae.dtype)
                debug_log_verbose_for(
                    "dataset",
                    f"vae device={next(vae.parameters()).device} vae dtype={vae.dtype} "
                    f"audio device={audio.device} audio dtype={audio.dtype}",
                )
                # NOTE(review): reconstructed nesting — no_grad appears to span
                # the remainder of this sample's encode/save pipeline; confirm
                # against the original formatting.
                with torch.no_grad():
                    t0 = debug_start_verbose_for("dataset", f"vae_encode[{i}]")
                    target_latents = vae_encode(vae, audio, dtype)
                    debug_end_verbose_for("dataset", f"vae_encode[{i}]", t0)
                    latent_length = target_latents.shape[1]
                    attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype)
                    debug_log_verbose_for(
                        "dataset",
                        f"target_latents shape={tuple(target_latents.shape)} latent_length={latent_length}",
                    )
                    # caption is stored in the output metadata; text_prompt is
                    # what actually feeds the text encoder.
                    caption = sample.get_training_prompt(self.metadata.tag_position, use_genre=use_genre)
                    text_prompt = build_text_prompt(sample, self.metadata.tag_position, use_genre)
                    # One-time debug dump of the exact encoder input.
                    if i == 0:
                        logger.info(f"\n{'='*70}")
                        logger.info("π [DEBUG] DiT TEXT ENCODER INPUT (Training Preprocess)")
                        logger.info(f"{'='*70}")
                        logger.info(f"text_prompt:\n{text_prompt}")
                        logger.info(f"{'='*70}\n")
                    t0 = debug_start_verbose_for("dataset", f"encode_text[{i}]")
                    text_hidden_states, text_attention_mask = encode_text(
                        text_encoder, text_tokenizer, text_prompt, device, dtype
                    )
                    debug_end_verbose_for("dataset", f"encode_text[{i}]", t0)
                    debug_log_verbose_for(
                        "dataset",
                        f"text_hidden_states shape={tuple(text_hidden_states.shape)} "
                        f"text_attention_mask shape={tuple(text_attention_mask.shape)}",
                    )
                    # Samples without lyrics are treated as instrumental.
                    lyrics = sample.lyrics if sample.lyrics else "[Instrumental]"
                    t0 = debug_start_verbose_for("dataset", f"encode_lyrics[{i}]")
                    lyric_hidden_states, lyric_attention_mask = encode_lyrics(
                        text_encoder, text_tokenizer, lyrics, device, dtype
                    )
                    debug_end_verbose_for("dataset", f"encode_lyrics[{i}]", t0)
                    debug_log_verbose_for(
                        "dataset",
                        f"lyric_hidden_states shape={tuple(lyric_hidden_states.shape)} "
                        f"lyric_attention_mask shape={tuple(lyric_attention_mask.shape)}",
                    )
                    t0 = debug_start_verbose_for("dataset", f"run_encoder[{i}]")
                    # Ensure DiT encoder runs on the active residency device (GPU when loaded via
                    # offload context). This prevents flash-attn CPU backend crashes.
                    with dit_handler._load_model_context("model"):
                        # Move encoder inputs to wherever the model currently
                        # resides, and match its parameter dtype.
                        model_device = next(model.parameters()).device
                        model_dtype = next(model.parameters()).dtype
                        if text_hidden_states.device != model_device:
                            text_hidden_states = text_hidden_states.to(model_device)
                        if text_attention_mask.device != model_device:
                            text_attention_mask = text_attention_mask.to(model_device)
                        if lyric_hidden_states.device != model_device:
                            lyric_hidden_states = lyric_hidden_states.to(model_device)
                        if lyric_attention_mask.device != model_device:
                            lyric_attention_mask = lyric_attention_mask.to(model_device)
                        if text_hidden_states.dtype != model_dtype:
                            text_hidden_states = text_hidden_states.to(model_dtype)
                        if lyric_hidden_states.dtype != model_dtype:
                            lyric_hidden_states = lyric_hidden_states.to(model_dtype)
                        encoder_hidden_states, encoder_attention_mask = run_encoder(
                            model,
                            text_hidden_states=text_hidden_states,
                            text_attention_mask=text_attention_mask,
                            lyric_hidden_states=lyric_hidden_states,
                            lyric_attention_mask=lyric_attention_mask,
                            device=model_device,
                            dtype=model_dtype,
                        )
                    debug_end_verbose_for("dataset", f"run_encoder[{i}]", t0)
                    debug_log_verbose_for(
                        "dataset",
                        f"encoder_hidden_states shape={tuple(encoder_hidden_states.shape)} "
                        f"encoder_attention_mask shape={tuple(encoder_attention_mask.shape)}",
                    )
                    t0 = debug_start_verbose_for("dataset", f"build_context_latents[{i}]")
                    context_latents = build_context_latents(silence_latent, latent_length, device, dtype)
                    debug_end_verbose_for("dataset", f"build_context_latents[{i}]", t0)
                    # Strip the batch dim and move to CPU before serializing.
                    output_data = {
                        "target_latents": target_latents.squeeze(0).cpu(),
                        "attention_mask": attention_mask.squeeze(0).cpu(),
                        "encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(),
                        "encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(),
                        "context_latents": context_latents.squeeze(0).cpu(),
                        "metadata": {
                            "audio_path": sample.audio_path,
                            "filename": sample.filename,
                            "caption": caption,
                            "lyrics": lyrics,
                            "duration": sample.duration,
                            "bpm": sample.bpm,
                            "keyscale": sample.keyscale,
                            "timesignature": sample.timesignature,
                            "language": sample.language,
                            "is_instrumental": sample.is_instrumental,
                        },
                    }
                    output_path = os.path.join(output_dir, f"{sample.id}.pt")
                    t0 = debug_start_verbose_for("dataset", f"torch.save[{i}]")
                    torch.save(output_data, output_path)
                    debug_end_verbose_for("dataset", f"torch.save[{i}]", t0)
                    output_paths.append(output_path)
                    success_count += 1
            except Exception as e:
                # Best-effort batch processing: record the failure and move on.
                logger.exception(f"Error preprocessing {sample.filename}")
                fail_count += 1
                if progress_callback:
                    progress_callback(f"β Failed: {sample.filename}: {str(e)}")
        t0 = debug_start_verbose_for("dataset", "save_manifest")
        save_manifest(output_dir, self.metadata, output_paths)
        debug_end_verbose_for("dataset", "save_manifest", t0)
        status = f"β Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}"
        if fail_count > 0:
            status += f" ({fail_count} failed)"
        return output_paths, status