import os
from typing import List, Tuple

import torch
from loguru import logger

from .models import AudioSample
from .preprocess_audio import load_audio_stereo
from .preprocess_context import build_context_latents
from .preprocess_encoder import run_encoder
from .preprocess_lyrics import encode_lyrics
from .preprocess_manifest import save_manifest
from .preprocess_text import build_text_prompt, encode_text
from .preprocess_utils import select_genre_indices
from .preprocess_vae import vae_encode
from acestep.debug_utils import (
    debug_log_for,
    debug_log_verbose_for,
    debug_start_verbose_for,
    debug_end_verbose_for,
)


class PreprocessMixin:
    """Preprocess labeled samples to tensor files."""

    def preprocess_to_tensors(
        self,
        dit_handler,
        output_dir: str,
        max_duration: float = 240.0,
        progress_callback=None,
    ) -> Tuple[List[str], str]:
        """Preprocess all labeled samples to tensor files for efficient training."""
        debug_log_for("dataset", f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}")
        if not self.samples:
            return [], "❌ No samples to preprocess"

        labeled_samples = [s for s in self.samples if s.labeled]
        if not labeled_samples:
            return [], "❌ No labeled samples to preprocess"

        if dit_handler is None or dit_handler.model is None:
            return [], "❌ Model not initialized. Please initialize the service first."

        os.makedirs(output_dir, exist_ok=True)

        output_paths: List[str] = []
        success_count = 0
        fail_count = 0

        model = dit_handler.model
        vae = dit_handler.vae
        text_encoder = dit_handler.text_encoder
        text_tokenizer = dit_handler.text_tokenizer
        silence_latent = dit_handler.silence_latent
        device = dit_handler.device
        dtype = dit_handler.dtype

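        # Fixed sample rate used when loading audio for VAE encoding.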
        target_sample_rate = 48000

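        # Decide which samples get genre tags in their prompts, per the configured genre_ratio.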
        genre_indices = select_genre_indices(labeled_samples, self.metadata.genre_ratio)
        debug_log_verbose_for("dataset", f"selected genre indices: count={len(genre_indices)}")

        for i, sample in enumerate(labeled_samples):
            try:
                debug_log_verbose_for("dataset", f"sample[{i}] id={sample.id} file={sample.filename}")
                if progress_callback:
                    progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}")

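                # Genre tags are included only for the samples selected above.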
                use_genre = i in genre_indices

                t0 = debug_start_verbose_for("dataset", f"load_audio_stereo[{i}]")
                audio, _ = load_audio_stereo(sample.audio_path, target_sample_rate, max_duration)
                debug_end_verbose_for("dataset", f"load_audio_stereo[{i}]", t0)
                debug_log_verbose_for("dataset", f"audio shape={tuple(audio.shape)} dtype={audio.dtype}")
                audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
                debug_log_verbose_for(
                    "dataset",
                    f"vae device={next(vae.parameters()).device} vae dtype={vae.dtype} "
                    f"audio device={audio.device} audio dtype={audio.dtype}",
                )

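                # Encode the waveform into target latents; gradients are not needed during preprocessing.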
                with torch.no_grad():
                    t0 = debug_start_verbose_for("dataset", f"vae_encode[{i}]")
                    target_latents = vae_encode(vae, audio, dtype)
                    debug_end_verbose_for("dataset", f"vae_encode[{i}]", t0)

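                # Every latent frame corresponds to real audio, so the attention mask is all ones.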
                latent_length = target_latents.shape[1]
                attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype)
                debug_log_verbose_for(
                    "dataset",
                    f"target_latents shape={tuple(target_latents.shape)} latent_length={latent_length}",
                )

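                # Build the caption (stored in metadata) and the prompt fed to the text encoder.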
                caption = sample.get_training_prompt(self.metadata.tag_position, use_genre=use_genre)
                text_prompt = build_text_prompt(sample, self.metadata.tag_position, use_genre)

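                # Log the full text-encoder prompt once, for the first sample only.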
                if i == 0:
                    logger.info(f"\n{'='*70}")
                    logger.info("πŸ” [DEBUG] DiT TEXT ENCODER INPUT (Training Preprocess)")
                    logger.info(f"{'='*70}")
                    logger.info(f"text_prompt:\n{text_prompt}")
                    logger.info(f"{'='*70}\n")

                t0 = debug_start_verbose_for("dataset", f"encode_text[{i}]")
                text_hidden_states, text_attention_mask = encode_text(
                    text_encoder, text_tokenizer, text_prompt, device, dtype
                )
                debug_end_verbose_for("dataset", f"encode_text[{i}]", t0)
                debug_log_verbose_for(
                    "dataset",
                    f"text_hidden_states shape={tuple(text_hidden_states.shape)} "
                    f"text_attention_mask shape={tuple(text_attention_mask.shape)}",
                )

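                # Use an explicit instrumental marker when a sample has no lyrics.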
                lyrics = sample.lyrics if sample.lyrics else "[Instrumental]"
                t0 = debug_start_verbose_for("dataset", f"encode_lyrics[{i}]")
                lyric_hidden_states, lyric_attention_mask = encode_lyrics(
                    text_encoder, text_tokenizer, lyrics, device, dtype
                )
                debug_end_verbose_for("dataset", f"encode_lyrics[{i}]", t0)
                debug_log_verbose_for(
                    "dataset",
                    f"lyric_hidden_states shape={tuple(lyric_hidden_states.shape)} "
                    f"lyric_attention_mask shape={tuple(lyric_attention_mask.shape)}",
                )

                t0 = debug_start_verbose_for("dataset", f"run_encoder[{i}]")
                # Ensure DiT encoder runs on the active residency device (GPU when loaded via
                # offload context). This prevents flash-attn CPU backend crashes.
                with dit_handler._load_model_context("model"):
                    model_device = next(model.parameters()).device
                    model_dtype = next(model.parameters()).dtype
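                    # Align encoder inputs with the model's device and dtype (attention masks keep their dtype).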
                    text_hidden_states = text_hidden_states.to(device=model_device, dtype=model_dtype)
                    text_attention_mask = text_attention_mask.to(model_device)
                    lyric_hidden_states = lyric_hidden_states.to(device=model_device, dtype=model_dtype)
                    lyric_attention_mask = lyric_attention_mask.to(model_device)

                    encoder_hidden_states, encoder_attention_mask = run_encoder(
                        model,
                        text_hidden_states=text_hidden_states,
                        text_attention_mask=text_attention_mask,
                        lyric_hidden_states=lyric_hidden_states,
                        lyric_attention_mask=lyric_attention_mask,
                        device=model_device,
                        dtype=model_dtype,
                    )
                debug_end_verbose_for("dataset", f"run_encoder[{i}]", t0)
                debug_log_verbose_for(
                    "dataset",
                    f"encoder_hidden_states shape={tuple(encoder_hidden_states.shape)} "
                    f"encoder_attention_mask shape={tuple(encoder_attention_mask.shape)}",
                )

                t0 = debug_start_verbose_for("dataset", f"build_context_latents[{i}]")
                context_latents = build_context_latents(silence_latent, latent_length, device, dtype)
                debug_end_verbose_for("dataset", f"build_context_latents[{i}]", t0)

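                # Drop batch dimensions and move tensors to CPU before serialization.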
                output_data = {
                    "target_latents": target_latents.squeeze(0).cpu(),
                    "attention_mask": attention_mask.squeeze(0).cpu(),
                    "encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(),
                    "encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(),
                    "context_latents": context_latents.squeeze(0).cpu(),
                    "metadata": {
                        "audio_path": sample.audio_path,
                        "filename": sample.filename,
                        "caption": caption,
                        "lyrics": lyrics,
                        "duration": sample.duration,
                        "bpm": sample.bpm,
                        "keyscale": sample.keyscale,
                        "timesignature": sample.timesignature,
                        "language": sample.language,
                        "is_instrumental": sample.is_instrumental,
                    },
                }

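                # One .pt file per sample, named by sample id.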
                output_path = os.path.join(output_dir, f"{sample.id}.pt")
                t0 = debug_start_verbose_for("dataset", f"torch.save[{i}]")
                torch.save(output_data, output_path)
                debug_end_verbose_for("dataset", f"torch.save[{i}]", t0)
                output_paths.append(output_path)
                success_count += 1

            except Exception as e:
                logger.exception(f"Error preprocessing {sample.filename}")
                fail_count += 1
                if progress_callback:
                    progress_callback(f"❌ Failed: {sample.filename}: {str(e)}")

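        # Write a manifest of dataset metadata and tensor paths alongside the outputs.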
        t0 = debug_start_verbose_for("dataset", "save_manifest")
        save_manifest(output_dir, self.metadata, output_paths)
        debug_end_verbose_for("dataset", "save_manifest", t0)

        status = f"βœ… Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}"
        if fail_count > 0:
            status += f" ({fail_count} failed)"

        return output_paths, status