|
|
""" |
|
|
Dataset Builder for LoRA Training |
|
|
|
|
|
Provides functionality to: |
|
|
1. Scan directories for audio files |
|
|
2. Auto-label audio using LLM |
|
|
3. Preview and edit metadata |
|
|
4. Save datasets in JSON format |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from dataclasses import dataclass, field, asdict |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
from loguru import logger |
|
|
|
|
|
|
|
|
|
|
|
# Lower-case file extensions (leading dot included) that are accepted as audio
# input when a directory is scanned.
SUPPORTED_AUDIO_FORMATS = {
    ".wav",
    ".mp3",
    ".flac",
    ".ogg",
    ".opus",
}
|
|
|
|
|
|
|
|
@dataclass
class AudioSample:
    """Represents a single audio sample with its metadata.

    Attributes:
        id: Unique identifier for the sample. When left empty, the first
            8 characters of a random UUID4 are assigned in ``__post_init__``.
        audio_path: Path to the audio file
        filename: Original filename
        caption: Generated or user-provided caption describing the music
        lyrics: Lyrics or "[Instrumental]" for instrumental tracks
        bpm: Beats per minute
        keyscale: Musical key (e.g., "C Major", "Am")
        timesignature: Time signature (e.g., "4" for 4/4)
        duration: Duration in seconds
        language: Vocal language or "instrumental"
        is_instrumental: Whether the track is instrumental
        custom_tag: User-defined activation tag for LoRA
        labeled: Whether the sample has been labeled
    """

    id: str = ""
    audio_path: str = ""
    filename: str = ""
    caption: str = ""
    lyrics: str = "[Instrumental]"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = 0.0
    language: str = "instrumental"
    is_instrumental: bool = True
    custom_tag: str = ""
    labeled: bool = False

    def __post_init__(self):
        """Assign a short random id when the caller did not provide one."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (one key per dataclass field)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
        """Create a sample from a dictionary.

        Keys that do not correspond to a dataclass field are ignored, so a
        dataset file carrying extra per-sample keys still loads instead of
        failing with ``TypeError``. Missing keys fall back to the field
        defaults.

        Args:
            data: Mapping of field name to value, e.g. the output of
                :meth:`to_dict`.

        Returns:
            A new ``AudioSample``.
        """
        # Function-scope import: the module only imports dataclass/field/asdict.
        from dataclasses import fields

        known_fields = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known_fields})

    def get_full_caption(self, tag_position: str = "prepend") -> str:
        """Get caption with custom tag applied.

        Args:
            tag_position: Where to place the custom tag ("prepend", "append",
                "replace"). Any other value leaves the caption untouched.

        Returns:
            Caption with custom tag applied. Without a custom tag the plain
            caption is returned; with a tag but an empty caption, "prepend"
            and "append" return the tag alone (no dangling ", ").
        """
        if not self.custom_tag:
            return self.caption

        if tag_position == "prepend":
            return f"{self.custom_tag}, {self.caption}" if self.caption else self.custom_tag
        if tag_position == "append":
            return f"{self.caption}, {self.custom_tag}" if self.caption else self.custom_tag
        if tag_position == "replace":
            return self.custom_tag

        # Unrecognised position: fall back to the untagged caption.
        return self.caption
|
|
|
|
|
|
|
|
@dataclass
class DatasetMetadata:
    """Dataset-wide settings stored alongside the samples.

    Attributes:
        name: Dataset name
        custom_tag: Default custom tag for all samples
        tag_position: Where to place custom tag ("prepend", "append", "replace")
        created_at: Creation timestamp (ISO-8601 string); filled in with the
            current local time when left empty
        num_samples: Number of samples in the dataset
        all_instrumental: Whether all tracks are instrumental
    """

    name: str = "untitled_dataset"
    custom_tag: str = ""
    tag_position: str = "prepend"
    created_at: str = ""
    num_samples: int = 0
    all_instrumental: bool = True

    def __post_init__(self):
        """Stamp the record with the current time unless a timestamp was given."""
        self.created_at = self.created_at or datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dictionary keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping
|
|
|
|
|
|
|
|
class DatasetBuilder: |
|
|
"""Builder for creating training datasets from audio files. |
|
|
|
|
|
This class handles: |
|
|
- Scanning directories for audio files |
|
|
- Auto-labeling using LLM |
|
|
- Managing sample metadata |
|
|
- Saving/loading datasets |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize the dataset builder.""" |
|
|
self.samples: List[AudioSample] = [] |
|
|
self.metadata = DatasetMetadata() |
|
|
self._current_dir: str = "" |
|
|
|
|
|
def scan_directory(self, directory: str) -> Tuple[List[AudioSample], str]: |
|
|
"""Scan a directory for audio files. |
|
|
|
|
|
Args: |
|
|
directory: Path to directory containing audio files |
|
|
|
|
|
Returns: |
|
|
Tuple of (list of AudioSample objects, status message) |
|
|
""" |
|
|
if not os.path.exists(directory): |
|
|
return [], f"❌ Directory not found: {directory}" |
|
|
|
|
|
if not os.path.isdir(directory): |
|
|
return [], f"❌ Not a directory: {directory}" |
|
|
|
|
|
self._current_dir = directory |
|
|
self.samples = [] |
|
|
|
|
|
|
|
|
audio_files = [] |
|
|
for root, dirs, files in os.walk(directory): |
|
|
for file in files: |
|
|
ext = os.path.splitext(file)[1].lower() |
|
|
if ext in SUPPORTED_AUDIO_FORMATS: |
|
|
audio_files.append(os.path.join(root, file)) |
|
|
|
|
|
if not audio_files: |
|
|
return [], f"❌ No audio files found in {directory}\nSupported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}" |
|
|
|
|
|
|
|
|
audio_files.sort() |
|
|
|
|
|
|
|
|
for audio_path in audio_files: |
|
|
try: |
|
|
|
|
|
duration = self._get_audio_duration(audio_path) |
|
|
|
|
|
sample = AudioSample( |
|
|
audio_path=audio_path, |
|
|
filename=os.path.basename(audio_path), |
|
|
duration=duration, |
|
|
is_instrumental=self.metadata.all_instrumental, |
|
|
custom_tag=self.metadata.custom_tag, |
|
|
) |
|
|
self.samples.append(sample) |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to process {audio_path}: {e}") |
|
|
|
|
|
self.metadata.num_samples = len(self.samples) |
|
|
|
|
|
status = f"✅ Found {len(self.samples)} audio files in {directory}" |
|
|
return self.samples, status |
|
|
|
|
|
def _get_audio_duration(self, audio_path: str) -> float: |
|
|
"""Get the duration of an audio file in seconds. |
|
|
|
|
|
Args: |
|
|
audio_path: Path to audio file |
|
|
|
|
|
Returns: |
|
|
Duration in seconds |
|
|
""" |
|
|
try: |
|
|
info = torchaudio.info(audio_path) |
|
|
return info.num_frames / info.sample_rate |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to get duration for {audio_path}: {e}") |
|
|
return 0.0 |
|
|
|
|
|
def label_sample( |
|
|
self, |
|
|
sample_idx: int, |
|
|
dit_handler, |
|
|
llm_handler, |
|
|
progress_callback=None, |
|
|
) -> Tuple[AudioSample, str]: |
|
|
"""Label a single sample using the LLM. |
|
|
|
|
|
Args: |
|
|
sample_idx: Index of sample to label |
|
|
dit_handler: DiT handler for audio encoding |
|
|
llm_handler: LLM handler for caption generation |
|
|
progress_callback: Optional callback for progress updates |
|
|
|
|
|
Returns: |
|
|
Tuple of (updated AudioSample, status message) |
|
|
""" |
|
|
if sample_idx < 0 or sample_idx >= len(self.samples): |
|
|
return None, f"❌ Invalid sample index: {sample_idx}" |
|
|
|
|
|
sample = self.samples[sample_idx] |
|
|
|
|
|
try: |
|
|
if progress_callback: |
|
|
progress_callback(f"Processing: {sample.filename}") |
|
|
|
|
|
|
|
|
audio_codes = self._get_audio_codes(sample.audio_path, dit_handler) |
|
|
|
|
|
if not audio_codes: |
|
|
return sample, f"❌ Failed to encode audio: {sample.filename}" |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(f"Generating metadata for: {sample.filename}") |
|
|
|
|
|
|
|
|
metadata, status = llm_handler.understand_audio_from_codes( |
|
|
audio_codes=audio_codes, |
|
|
temperature=0.7, |
|
|
use_constrained_decoding=True, |
|
|
) |
|
|
|
|
|
if not metadata: |
|
|
return sample, f"❌ LLM labeling failed: {status}" |
|
|
|
|
|
|
|
|
sample.caption = metadata.get('caption', '') |
|
|
sample.bpm = self._parse_int(metadata.get('bpm')) |
|
|
sample.keyscale = metadata.get('keyscale', '') |
|
|
sample.timesignature = metadata.get('timesignature', '') |
|
|
sample.language = metadata.get('vocal_language', 'instrumental') |
|
|
|
|
|
|
|
|
if sample.is_instrumental: |
|
|
sample.lyrics = "[Instrumental]" |
|
|
sample.language = "instrumental" |
|
|
else: |
|
|
sample.lyrics = metadata.get('lyrics', '') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sample.labeled = True |
|
|
self.samples[sample_idx] = sample |
|
|
|
|
|
return sample, f"✅ Labeled: {sample.filename}" |
|
|
|
|
|
except Exception as e: |
|
|
logger.exception(f"Error labeling sample {sample.filename}") |
|
|
return sample, f"❌ Error: {str(e)}" |
|
|
|
|
|
def label_all_samples( |
|
|
self, |
|
|
dit_handler, |
|
|
llm_handler, |
|
|
progress_callback=None, |
|
|
) -> Tuple[List[AudioSample], str]: |
|
|
"""Label all samples in the dataset. |
|
|
|
|
|
Args: |
|
|
dit_handler: DiT handler for audio encoding |
|
|
llm_handler: LLM handler for caption generation |
|
|
progress_callback: Optional callback for progress updates |
|
|
|
|
|
Returns: |
|
|
Tuple of (list of updated samples, status message) |
|
|
""" |
|
|
if not self.samples: |
|
|
return [], "❌ No samples to label. Please scan a directory first." |
|
|
|
|
|
success_count = 0 |
|
|
fail_count = 0 |
|
|
|
|
|
for i, sample in enumerate(self.samples): |
|
|
if progress_callback: |
|
|
progress_callback(f"Labeling {i+1}/{len(self.samples)}: {sample.filename}") |
|
|
|
|
|
_, status = self.label_sample(i, dit_handler, llm_handler, progress_callback) |
|
|
|
|
|
if "✅" in status: |
|
|
success_count += 1 |
|
|
else: |
|
|
fail_count += 1 |
|
|
|
|
|
status_msg = f"✅ Labeled {success_count}/{len(self.samples)} samples" |
|
|
if fail_count > 0: |
|
|
status_msg += f" ({fail_count} failed)" |
|
|
|
|
|
return self.samples, status_msg |
|
|
|
|
|
def _get_audio_codes(self, audio_path: str, dit_handler) -> Optional[str]: |
|
|
"""Encode audio to get semantic codes for LLM understanding. |
|
|
|
|
|
Args: |
|
|
audio_path: Path to audio file |
|
|
dit_handler: DiT handler with VAE and tokenizer |
|
|
|
|
|
Returns: |
|
|
Audio codes string or None if failed |
|
|
""" |
|
|
try: |
|
|
|
|
|
if not hasattr(dit_handler, 'convert_src_audio_to_codes'): |
|
|
logger.error("DiT handler missing convert_src_audio_to_codes method") |
|
|
return None |
|
|
|
|
|
|
|
|
codes_string = dit_handler.convert_src_audio_to_codes(audio_path) |
|
|
|
|
|
if codes_string and not codes_string.startswith("❌"): |
|
|
return codes_string |
|
|
else: |
|
|
logger.warning(f"Failed to convert audio to codes: {codes_string}") |
|
|
return None |
|
|
|
|
|
except Exception as e: |
|
|
logger.exception(f"Error encoding audio {audio_path}") |
|
|
return None |
|
|
|
|
|
def _parse_int(self, value: Any) -> Optional[int]: |
|
|
"""Safely parse an integer value.""" |
|
|
if value is None or value == "N/A" or value == "": |
|
|
return None |
|
|
try: |
|
|
return int(value) |
|
|
except (ValueError, TypeError): |
|
|
return None |
|
|
|
|
|
def update_sample(self, sample_idx: int, **kwargs) -> Tuple[AudioSample, str]: |
|
|
"""Update a sample's metadata. |
|
|
|
|
|
Args: |
|
|
sample_idx: Index of sample to update |
|
|
**kwargs: Fields to update |
|
|
|
|
|
Returns: |
|
|
Tuple of (updated sample, status message) |
|
|
""" |
|
|
if sample_idx < 0 or sample_idx >= len(self.samples): |
|
|
return None, f"❌ Invalid sample index: {sample_idx}" |
|
|
|
|
|
sample = self.samples[sample_idx] |
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
if hasattr(sample, key): |
|
|
setattr(sample, key, value) |
|
|
|
|
|
self.samples[sample_idx] = sample |
|
|
return sample, f"✅ Updated: {sample.filename}" |
|
|
|
|
|
def set_custom_tag(self, custom_tag: str, tag_position: str = "prepend"): |
|
|
"""Set the custom tag for all samples. |
|
|
|
|
|
Args: |
|
|
custom_tag: Custom activation tag |
|
|
tag_position: Where to place tag ("prepend", "append", "replace") |
|
|
""" |
|
|
self.metadata.custom_tag = custom_tag |
|
|
self.metadata.tag_position = tag_position |
|
|
|
|
|
for sample in self.samples: |
|
|
sample.custom_tag = custom_tag |
|
|
|
|
|
def set_all_instrumental(self, is_instrumental: bool): |
|
|
"""Set instrumental flag for all samples. |
|
|
|
|
|
Args: |
|
|
is_instrumental: Whether all tracks are instrumental |
|
|
""" |
|
|
self.metadata.all_instrumental = is_instrumental |
|
|
|
|
|
for sample in self.samples: |
|
|
sample.is_instrumental = is_instrumental |
|
|
if is_instrumental: |
|
|
sample.lyrics = "[Instrumental]" |
|
|
sample.language = "instrumental" |
|
|
|
|
|
def get_sample_count(self) -> int: |
|
|
"""Get the number of samples in the dataset.""" |
|
|
return len(self.samples) |
|
|
|
|
|
def get_labeled_count(self) -> int: |
|
|
"""Get the number of labeled samples.""" |
|
|
return sum(1 for s in self.samples if s.labeled) |
|
|
|
|
|
def save_dataset(self, output_path: str, dataset_name: str = None) -> str: |
|
|
"""Save the dataset to a JSON file. |
|
|
|
|
|
Args: |
|
|
output_path: Path to save the dataset JSON |
|
|
dataset_name: Optional name for the dataset |
|
|
|
|
|
Returns: |
|
|
Status message |
|
|
""" |
|
|
if not self.samples: |
|
|
return "❌ No samples to save" |
|
|
|
|
|
if dataset_name: |
|
|
self.metadata.name = dataset_name |
|
|
|
|
|
self.metadata.num_samples = len(self.samples) |
|
|
self.metadata.created_at = datetime.now().isoformat() |
|
|
|
|
|
|
|
|
dataset = { |
|
|
"metadata": self.metadata.to_dict(), |
|
|
"samples": [] |
|
|
} |
|
|
|
|
|
for sample in self.samples: |
|
|
sample_dict = sample.to_dict() |
|
|
|
|
|
sample_dict["caption"] = sample.get_full_caption(self.metadata.tag_position) |
|
|
dataset["samples"].append(sample_dict) |
|
|
|
|
|
try: |
|
|
|
|
|
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True) |
|
|
|
|
|
with open(output_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(dataset, f, indent=2, ensure_ascii=False) |
|
|
|
|
|
return f"✅ Dataset saved to {output_path}\n{len(self.samples)} samples, tag: '{self.metadata.custom_tag}'" |
|
|
except Exception as e: |
|
|
logger.exception("Error saving dataset") |
|
|
return f"❌ Failed to save dataset: {str(e)}" |
|
|
|
|
|
def load_dataset(self, dataset_path: str) -> Tuple[List[AudioSample], str]: |
|
|
"""Load a dataset from a JSON file. |
|
|
|
|
|
Args: |
|
|
dataset_path: Path to the dataset JSON file |
|
|
|
|
|
Returns: |
|
|
Tuple of (list of samples, status message) |
|
|
""" |
|
|
if not os.path.exists(dataset_path): |
|
|
return [], f"❌ Dataset not found: {dataset_path}" |
|
|
|
|
|
try: |
|
|
with open(dataset_path, 'r', encoding='utf-8') as f: |
|
|
data = json.load(f) |
|
|
|
|
|
|
|
|
if "metadata" in data: |
|
|
meta_dict = data["metadata"] |
|
|
self.metadata = DatasetMetadata( |
|
|
name=meta_dict.get("name", "untitled"), |
|
|
custom_tag=meta_dict.get("custom_tag", ""), |
|
|
tag_position=meta_dict.get("tag_position", "prepend"), |
|
|
created_at=meta_dict.get("created_at", ""), |
|
|
num_samples=meta_dict.get("num_samples", 0), |
|
|
all_instrumental=meta_dict.get("all_instrumental", True), |
|
|
) |
|
|
|
|
|
|
|
|
self.samples = [] |
|
|
for sample_dict in data.get("samples", []): |
|
|
sample = AudioSample.from_dict(sample_dict) |
|
|
self.samples.append(sample) |
|
|
|
|
|
return self.samples, f"✅ Loaded {len(self.samples)} samples from {dataset_path}" |
|
|
|
|
|
except Exception as e: |
|
|
logger.exception("Error loading dataset") |
|
|
return [], f"❌ Failed to load dataset: {str(e)}" |
|
|
|
|
|
def get_samples_dataframe_data(self) -> List[List[Any]]: |
|
|
"""Get samples data in a format suitable for Gradio DataFrame. |
|
|
|
|
|
Returns: |
|
|
List of rows for DataFrame display |
|
|
""" |
|
|
rows = [] |
|
|
for i, sample in enumerate(self.samples): |
|
|
rows.append([ |
|
|
i, |
|
|
sample.filename, |
|
|
f"{sample.duration:.1f}s", |
|
|
"✅" if sample.labeled else "❌", |
|
|
sample.bpm or "-", |
|
|
sample.keyscale or "-", |
|
|
sample.caption[:50] + "..." if len(sample.caption) > 50 else sample.caption or "-", |
|
|
]) |
|
|
return rows |
|
|
|
|
|
def to_training_format(self) -> List[Dict[str, Any]]: |
|
|
"""Convert dataset to format suitable for training. |
|
|
|
|
|
Returns: |
|
|
List of training sample dictionaries |
|
|
""" |
|
|
training_samples = [] |
|
|
|
|
|
for sample in self.samples: |
|
|
if not sample.labeled: |
|
|
continue |
|
|
|
|
|
training_sample = { |
|
|
"audio_path": sample.audio_path, |
|
|
"caption": sample.get_full_caption(self.metadata.tag_position), |
|
|
"lyrics": sample.lyrics, |
|
|
"bpm": sample.bpm, |
|
|
"keyscale": sample.keyscale, |
|
|
"timesignature": sample.timesignature, |
|
|
"duration": sample.duration, |
|
|
"language": sample.language, |
|
|
"is_instrumental": sample.is_instrumental, |
|
|
} |
|
|
training_samples.append(training_sample) |
|
|
|
|
|
return training_samples |
|
|
|
|
|
def preprocess_to_tensors( |
|
|
self, |
|
|
dit_handler, |
|
|
output_dir: str, |
|
|
max_duration: float = 240.0, |
|
|
progress_callback=None, |
|
|
) -> Tuple[List[str], str]: |
|
|
"""Preprocess all labeled samples to tensor files for efficient training. |
|
|
|
|
|
This method pre-computes all tensors needed by the DiT decoder: |
|
|
- target_latents: VAE-encoded audio |
|
|
- encoder_hidden_states: Condition encoder output |
|
|
- context_latents: Source context (silence_latent + zeros for text2music) |
|
|
|
|
|
Args: |
|
|
dit_handler: Initialized DiT handler with model, VAE, and text encoder |
|
|
output_dir: Directory to save preprocessed .pt files |
|
|
max_duration: Maximum audio duration in seconds (default 240s = 4 min) |
|
|
progress_callback: Optional callback for progress updates |
|
|
|
|
|
Returns: |
|
|
Tuple of (list of output paths, status message) |
|
|
""" |
|
|
if not self.samples: |
|
|
return [], "❌ No samples to preprocess" |
|
|
|
|
|
labeled_samples = [s for s in self.samples if s.labeled] |
|
|
if not labeled_samples: |
|
|
return [], "❌ No labeled samples to preprocess" |
|
|
|
|
|
|
|
|
if dit_handler is None or dit_handler.model is None: |
|
|
return [], "❌ Model not initialized. Please initialize the service first." |
|
|
|
|
|
|
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
output_paths = [] |
|
|
success_count = 0 |
|
|
fail_count = 0 |
|
|
|
|
|
|
|
|
model = dit_handler.model |
|
|
vae = dit_handler.vae |
|
|
text_encoder = dit_handler.text_encoder |
|
|
text_tokenizer = dit_handler.text_tokenizer |
|
|
silence_latent = dit_handler.silence_latent |
|
|
device = dit_handler.device |
|
|
dtype = dit_handler.dtype |
|
|
|
|
|
target_sample_rate = 48000 |
|
|
|
|
|
for i, sample in enumerate(labeled_samples): |
|
|
try: |
|
|
if progress_callback: |
|
|
progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}") |
|
|
|
|
|
|
|
|
audio, sr = torchaudio.load(sample.audio_path) |
|
|
|
|
|
|
|
|
if sr != target_sample_rate: |
|
|
resampler = torchaudio.transforms.Resample(sr, target_sample_rate) |
|
|
audio = resampler(audio) |
|
|
|
|
|
|
|
|
if audio.shape[0] == 1: |
|
|
audio = audio.repeat(2, 1) |
|
|
elif audio.shape[0] > 2: |
|
|
audio = audio[:2, :] |
|
|
|
|
|
|
|
|
max_samples = int(max_duration * target_sample_rate) |
|
|
if audio.shape[1] > max_samples: |
|
|
audio = audio[:, :max_samples] |
|
|
|
|
|
|
|
|
audio = audio.unsqueeze(0).to(device).to(vae.dtype) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
latent = vae.encode(audio).latent_dist.sample() |
|
|
|
|
|
target_latents = latent.transpose(1, 2).to(dtype) |
|
|
|
|
|
latent_length = target_latents.shape[1] |
|
|
|
|
|
|
|
|
attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype) |
|
|
|
|
|
|
|
|
caption = sample.get_full_caption(self.metadata.tag_position) |
|
|
text_inputs = text_tokenizer( |
|
|
caption, |
|
|
padding="max_length", |
|
|
max_length=256, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
text_input_ids = text_inputs.input_ids.to(device) |
|
|
text_attention_mask = text_inputs.attention_mask.to(device).to(dtype) |
|
|
|
|
|
with torch.no_grad(): |
|
|
text_outputs = text_encoder(text_input_ids) |
|
|
text_hidden_states = text_outputs.last_hidden_state.to(dtype) |
|
|
|
|
|
|
|
|
lyrics = sample.lyrics if sample.lyrics else "[Instrumental]" |
|
|
lyric_inputs = text_tokenizer( |
|
|
lyrics, |
|
|
padding="max_length", |
|
|
max_length=512, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
lyric_input_ids = lyric_inputs.input_ids.to(device) |
|
|
lyric_attention_mask = lyric_inputs.attention_mask.to(device).to(dtype) |
|
|
|
|
|
with torch.no_grad(): |
|
|
lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype) |
|
|
|
|
|
|
|
|
|
|
|
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype) |
|
|
refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
encoder_hidden_states, encoder_attention_mask = model.encoder( |
|
|
text_hidden_states=text_hidden_states, |
|
|
text_attention_mask=text_attention_mask, |
|
|
lyric_hidden_states=lyric_hidden_states, |
|
|
lyric_attention_mask=lyric_attention_mask, |
|
|
refer_audio_acoustic_hidden_states_packed=refer_audio_hidden, |
|
|
refer_audio_order_mask=refer_audio_order_mask, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src_latents = silence_latent[:, :latent_length, :].to(dtype) |
|
|
if src_latents.shape[0] < 1: |
|
|
src_latents = src_latents.expand(1, -1, -1) |
|
|
|
|
|
|
|
|
if src_latents.shape[1] < latent_length: |
|
|
pad_len = latent_length - src_latents.shape[1] |
|
|
src_latents = torch.cat([ |
|
|
src_latents, |
|
|
silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype) |
|
|
], dim=1) |
|
|
elif src_latents.shape[1] > latent_length: |
|
|
src_latents = src_latents[:, :latent_length, :] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype) |
|
|
|
|
|
context_latents = torch.cat([src_latents, chunk_masks], dim=-1) |
|
|
|
|
|
|
|
|
output_data = { |
|
|
"target_latents": target_latents.squeeze(0).cpu(), |
|
|
"attention_mask": attention_mask.squeeze(0).cpu(), |
|
|
"encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(), |
|
|
"encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(), |
|
|
"context_latents": context_latents.squeeze(0).cpu(), |
|
|
"metadata": { |
|
|
"audio_path": sample.audio_path, |
|
|
"filename": sample.filename, |
|
|
"caption": caption, |
|
|
"lyrics": lyrics, |
|
|
"duration": sample.duration, |
|
|
"bpm": sample.bpm, |
|
|
"keyscale": sample.keyscale, |
|
|
"timesignature": sample.timesignature, |
|
|
"language": sample.language, |
|
|
"is_instrumental": sample.is_instrumental, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
output_path = os.path.join(output_dir, f"{sample.id}.pt") |
|
|
torch.save(output_data, output_path) |
|
|
output_paths.append(output_path) |
|
|
success_count += 1 |
|
|
|
|
|
except Exception as e: |
|
|
logger.exception(f"Error preprocessing {sample.filename}") |
|
|
fail_count += 1 |
|
|
if progress_callback: |
|
|
progress_callback(f"❌ Failed: {sample.filename}: {str(e)}") |
|
|
|
|
|
|
|
|
manifest = { |
|
|
"metadata": self.metadata.to_dict(), |
|
|
"samples": output_paths, |
|
|
"num_samples": len(output_paths), |
|
|
} |
|
|
manifest_path = os.path.join(output_dir, "manifest.json") |
|
|
with open(manifest_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(manifest, f, indent=2) |
|
|
|
|
|
status = f"✅ Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}" |
|
|
if fail_count > 0: |
|
|
status += f" ({fail_count} failed)" |
|
|
|
|
|
return output_paths, status |
|
|
|