"""
Dataset Builder for LoRA Training
Provides functionality to:
1. Scan directories for audio files
2. Auto-label audio using LLM
3. Preview and edit metadata
4. Save datasets in JSON format
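
Typical usage (illustrative sketch; the dit_handler/llm_handler objects are
assumed to be provided by the host application):

    builder = DatasetBuilder()
    samples, status = builder.scan_directory("/path/to/audio")
    builder.set_custom_tag("mytag", tag_position="prepend")
    builder.label_all_samples(dit_handler, llm_handler)
    builder.save_dataset("dataset.json", dataset_name="my_dataset")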
"""
import os
import json
import uuid
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
import torch
import torchaudio
from loguru import logger
# Supported audio formats
SUPPORTED_AUDIO_FORMATS = {'.wav', '.mp3', '.flac', '.ogg', '.opus'}
@dataclass
class AudioSample:
"""Represents a single audio sample with its metadata.
Attributes:
id: Unique identifier for the sample
audio_path: Path to the audio file
filename: Original filename
caption: Generated or user-provided caption describing the music
lyrics: Lyrics or "[Instrumental]" for instrumental tracks
bpm: Beats per minute
keyscale: Musical key (e.g., "C Major", "Am")
timesignature: Time signature (e.g., "4" for 4/4)
duration: Duration in seconds
language: Vocal language or "instrumental"
is_instrumental: Whether the track is instrumental
custom_tag: User-defined activation tag for LoRA
labeled: Whether the sample has been labeled
"""
id: str = ""
audio_path: str = ""
filename: str = ""
caption: str = ""
lyrics: str = "[Instrumental]"
bpm: Optional[int] = None
keyscale: str = ""
timesignature: str = ""
duration: float = 0.0
language: str = "instrumental"
is_instrumental: bool = True
custom_tag: str = ""
labeled: bool = False
def __post_init__(self):
if not self.id:
self.id = str(uuid.uuid4())[:8]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
"""Create from dictionary."""
return cls(**data)
def get_full_caption(self, tag_position: str = "prepend") -> str:
"""Get caption with custom tag applied.
Args:
tag_position: Where to place the custom tag ("prepend", "append", "replace")
Returns:
Caption with custom tag applied
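
        Example (illustrative):
            >>> AudioSample(caption="lo-fi piano", custom_tag="mytag").get_full_caption()
            'mytag, lo-fi piano'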
"""
if not self.custom_tag:
return self.caption
if tag_position == "prepend":
return f"{self.custom_tag}, {self.caption}" if self.caption else self.custom_tag
elif tag_position == "append":
return f"{self.caption}, {self.custom_tag}" if self.caption else self.custom_tag
elif tag_position == "replace":
return self.custom_tag
else:
return self.caption
@dataclass
class DatasetMetadata:
"""Metadata for the entire dataset.
Attributes:
name: Dataset name
custom_tag: Default custom tag for all samples
tag_position: Where to place custom tag ("prepend", "append", "replace")
created_at: Creation timestamp
num_samples: Number of samples in the dataset
all_instrumental: Whether all tracks are instrumental
"""
name: str = "untitled_dataset"
custom_tag: str = ""
tag_position: str = "prepend"
created_at: str = ""
num_samples: int = 0
all_instrumental: bool = True
def __post_init__(self):
if not self.created_at:
self.created_at = datetime.now().isoformat()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
class DatasetBuilder:
"""Builder for creating training datasets from audio files.
This class handles:
- Scanning directories for audio files
- Auto-labeling using LLM
- Managing sample metadata
- Saving/loading datasets
"""
def __init__(self):
"""Initialize the dataset builder."""
self.samples: List[AudioSample] = []
self.metadata = DatasetMetadata()
self._current_dir: str = ""
def scan_directory(self, directory: str) -> Tuple[List[AudioSample], str]:
"""Scan a directory for audio files.
Args:
directory: Path to directory containing audio files
Returns:
Tuple of (list of AudioSample objects, status message)
"""
if not os.path.exists(directory):
return [], f"β Directory not found: {directory}"
if not os.path.isdir(directory):
return [], f"β Not a directory: {directory}"
self._current_dir = directory
self.samples = []
# Scan for audio files
audio_files = []
for root, dirs, files in os.walk(directory):
for file in files:
ext = os.path.splitext(file)[1].lower()
if ext in SUPPORTED_AUDIO_FORMATS:
audio_files.append(os.path.join(root, file))
if not audio_files:
return [], f"β No audio files found in {directory}\nSupported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
# Sort files by name
audio_files.sort()
# Create AudioSample objects
for audio_path in audio_files:
try:
# Get duration
duration = self._get_audio_duration(audio_path)
sample = AudioSample(
audio_path=audio_path,
filename=os.path.basename(audio_path),
duration=duration,
is_instrumental=self.metadata.all_instrumental,
custom_tag=self.metadata.custom_tag,
)
self.samples.append(sample)
except Exception as e:
logger.warning(f"Failed to process {audio_path}: {e}")
self.metadata.num_samples = len(self.samples)
status = f"β
Found {len(self.samples)} audio files in {directory}"
return self.samples, status
def _get_audio_duration(self, audio_path: str) -> float:
"""Get the duration of an audio file in seconds.
Args:
audio_path: Path to audio file
Returns:
Duration in seconds
"""
try:
info = torchaudio.info(audio_path)
return info.num_frames / info.sample_rate
except Exception as e:
logger.warning(f"Failed to get duration for {audio_path}: {e}")
return 0.0
def label_sample(
self,
sample_idx: int,
dit_handler,
llm_handler,
progress_callback=None,
    ) -> Tuple[Optional[AudioSample], str]:
"""Label a single sample using the LLM.
Args:
sample_idx: Index of sample to label
dit_handler: DiT handler for audio encoding
llm_handler: LLM handler for caption generation
progress_callback: Optional callback for progress updates
Returns:
Tuple of (updated AudioSample, status message)
"""
if sample_idx < 0 or sample_idx >= len(self.samples):
return None, f"β Invalid sample index: {sample_idx}"
sample = self.samples[sample_idx]
try:
if progress_callback:
progress_callback(f"Processing: {sample.filename}")
# Step 1: Load and encode audio to get audio codes
audio_codes = self._get_audio_codes(sample.audio_path, dit_handler)
if not audio_codes:
return sample, f"β Failed to encode audio: {sample.filename}"
if progress_callback:
progress_callback(f"Generating metadata for: {sample.filename}")
# Step 2: Use LLM to understand the audio
metadata, status = llm_handler.understand_audio_from_codes(
audio_codes=audio_codes,
temperature=0.7,
use_constrained_decoding=True,
)
if not metadata:
return sample, f"β LLM labeling failed: {status}"
# Step 3: Update sample with generated metadata
sample.caption = metadata.get('caption', '')
sample.bpm = self._parse_int(metadata.get('bpm'))
sample.keyscale = metadata.get('keyscale', '')
sample.timesignature = metadata.get('timesignature', '')
sample.language = metadata.get('vocal_language', 'instrumental')
# Handle lyrics based on instrumental flag
if sample.is_instrumental:
sample.lyrics = "[Instrumental]"
sample.language = "instrumental"
else:
sample.lyrics = metadata.get('lyrics', '')
# NOTE: Duration is NOT overwritten from LM metadata.
# We keep the real audio duration obtained from torchaudio during scan.
sample.labeled = True
self.samples[sample_idx] = sample
return sample, f"β
Labeled: {sample.filename}"
except Exception as e:
logger.exception(f"Error labeling sample {sample.filename}")
return sample, f"β Error: {str(e)}"
def label_all_samples(
self,
dit_handler,
llm_handler,
progress_callback=None,
) -> Tuple[List[AudioSample], str]:
"""Label all samples in the dataset.
Args:
dit_handler: DiT handler for audio encoding
llm_handler: LLM handler for caption generation
progress_callback: Optional callback for progress updates
Returns:
Tuple of (list of updated samples, status message)
"""
if not self.samples:
return [], "β No samples to label. Please scan a directory first."
success_count = 0
fail_count = 0
for i, sample in enumerate(self.samples):
if progress_callback:
progress_callback(f"Labeling {i+1}/{len(self.samples)}: {sample.filename}")
_, status = self.label_sample(i, dit_handler, llm_handler, progress_callback)
if "β
" in status:
success_count += 1
else:
fail_count += 1
status_msg = f"β
Labeled {success_count}/{len(self.samples)} samples"
if fail_count > 0:
status_msg += f" ({fail_count} failed)"
return self.samples, status_msg
def _get_audio_codes(self, audio_path: str, dit_handler) -> Optional[str]:
"""Encode audio to get semantic codes for LLM understanding.
Args:
audio_path: Path to audio file
dit_handler: DiT handler with VAE and tokenizer
Returns:
Audio codes string or None if failed
"""
try:
# Check if handler has required methods
if not hasattr(dit_handler, 'convert_src_audio_to_codes'):
logger.error("DiT handler missing convert_src_audio_to_codes method")
return None
# Use handler's method to convert audio to codes
codes_string = dit_handler.convert_src_audio_to_codes(audio_path)
            if codes_string and not codes_string.startswith("❌"):
return codes_string
else:
logger.warning(f"Failed to convert audio to codes: {codes_string}")
return None
except Exception as e:
logger.exception(f"Error encoding audio {audio_path}")
return None
def _parse_int(self, value: Any) -> Optional[int]:
"""Safely parse an integer value."""
if value is None or value == "N/A" or value == "":
return None
try:
return int(value)
except (ValueError, TypeError):
return None
    def update_sample(self, sample_idx: int, **kwargs) -> Tuple[Optional[AudioSample], str]:
"""Update a sample's metadata.
Args:
sample_idx: Index of sample to update
**kwargs: Fields to update
Returns:
Tuple of (updated sample, status message)
"""
if sample_idx < 0 or sample_idx >= len(self.samples):
return None, f"β Invalid sample index: {sample_idx}"
sample = self.samples[sample_idx]
for key, value in kwargs.items():
if hasattr(sample, key):
setattr(sample, key, value)
self.samples[sample_idx] = sample
return sample, f"β
Updated: {sample.filename}"
def set_custom_tag(self, custom_tag: str, tag_position: str = "prepend"):
"""Set the custom tag for all samples.
Args:
custom_tag: Custom activation tag
tag_position: Where to place tag ("prepend", "append", "replace")
"""
self.metadata.custom_tag = custom_tag
self.metadata.tag_position = tag_position
for sample in self.samples:
sample.custom_tag = custom_tag
def set_all_instrumental(self, is_instrumental: bool):
"""Set instrumental flag for all samples.
Args:
is_instrumental: Whether all tracks are instrumental
"""
self.metadata.all_instrumental = is_instrumental
for sample in self.samples:
sample.is_instrumental = is_instrumental
if is_instrumental:
sample.lyrics = "[Instrumental]"
sample.language = "instrumental"
def get_sample_count(self) -> int:
"""Get the number of samples in the dataset."""
return len(self.samples)
def get_labeled_count(self) -> int:
"""Get the number of labeled samples."""
return sum(1 for s in self.samples if s.labeled)
def save_dataset(self, output_path: str, dataset_name: str = None) -> str:
"""Save the dataset to a JSON file.
Args:
output_path: Path to save the dataset JSON
dataset_name: Optional name for the dataset
Returns:
Status message
"""
if not self.samples:
return "β No samples to save"
if dataset_name:
self.metadata.name = dataset_name
self.metadata.num_samples = len(self.samples)
self.metadata.created_at = datetime.now().isoformat()
# Build dataset with captions that include custom tags
dataset = {
"metadata": self.metadata.to_dict(),
"samples": []
}
for sample in self.samples:
sample_dict = sample.to_dict()
# Apply custom tag to caption based on position
sample_dict["caption"] = sample.get_full_caption(self.metadata.tag_position)
dataset["samples"].append(sample_dict)
try:
# Ensure output directory exists
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(dataset, f, indent=2, ensure_ascii=False)
return f"β
Dataset saved to {output_path}\n{len(self.samples)} samples, tag: '{self.metadata.custom_tag}'"
except Exception as e:
logger.exception("Error saving dataset")
return f"β Failed to save dataset: {str(e)}"
def load_dataset(self, dataset_path: str) -> Tuple[List[AudioSample], str]:
"""Load a dataset from a JSON file.
Args:
dataset_path: Path to the dataset JSON file
Returns:
Tuple of (list of samples, status message)
"""
if not os.path.exists(dataset_path):
return [], f"β Dataset not found: {dataset_path}"
try:
with open(dataset_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Load metadata
if "metadata" in data:
meta_dict = data["metadata"]
self.metadata = DatasetMetadata(
name=meta_dict.get("name", "untitled"),
custom_tag=meta_dict.get("custom_tag", ""),
tag_position=meta_dict.get("tag_position", "prepend"),
created_at=meta_dict.get("created_at", ""),
num_samples=meta_dict.get("num_samples", 0),
all_instrumental=meta_dict.get("all_instrumental", True),
)
# Load samples
self.samples = []
for sample_dict in data.get("samples", []):
sample = AudioSample.from_dict(sample_dict)
self.samples.append(sample)
return self.samples, f"β
Loaded {len(self.samples)} samples from {dataset_path}"
except Exception as e:
logger.exception("Error loading dataset")
return [], f"β Failed to load dataset: {str(e)}"
def get_samples_dataframe_data(self) -> List[List[Any]]:
"""Get samples data in a format suitable for Gradio DataFrame.
Returns:
List of rows for DataFrame display
"""
rows = []
for i, sample in enumerate(self.samples):
rows.append([
i,
sample.filename,
f"{sample.duration:.1f}s",
"β
" if sample.labeled else "β",
sample.bpm or "-",
sample.keyscale or "-",
                (sample.caption[:50] + "...") if len(sample.caption) > 50 else (sample.caption or "-"),
])
return rows
def to_training_format(self) -> List[Dict[str, Any]]:
"""Convert dataset to format suitable for training.
Returns:
List of training sample dictionaries
"""
training_samples = []
for sample in self.samples:
if not sample.labeled:
continue
training_sample = {
"audio_path": sample.audio_path,
"caption": sample.get_full_caption(self.metadata.tag_position),
"lyrics": sample.lyrics,
"bpm": sample.bpm,
"keyscale": sample.keyscale,
"timesignature": sample.timesignature,
"duration": sample.duration,
"language": sample.language,
"is_instrumental": sample.is_instrumental,
}
training_samples.append(training_sample)
return training_samples
def preprocess_to_tensors(
self,
dit_handler,
output_dir: str,
max_duration: float = 240.0,
progress_callback=None,
) -> Tuple[List[str], str]:
"""Preprocess all labeled samples to tensor files for efficient training.
This method pre-computes all tensors needed by the DiT decoder:
- target_latents: VAE-encoded audio
- encoder_hidden_states: Condition encoder output
- context_latents: Source context (silence_latent + zeros for text2music)
Args:
dit_handler: Initialized DiT handler with model, VAE, and text encoder
output_dir: Directory to save preprocessed .pt files
max_duration: Maximum audio duration in seconds (default 240s = 4 min)
progress_callback: Optional callback for progress updates
Returns:
Tuple of (list of output paths, status message)
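
        The saved .pt files can be reloaded with torch.load() at training
        time; a minimal loading sketch (an assumption, not part of this
        module's API) is included at the end of this file.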
"""
if not self.samples:
return [], "β No samples to preprocess"
labeled_samples = [s for s in self.samples if s.labeled]
if not labeled_samples:
return [], "β No labeled samples to preprocess"
# Validate handler
if dit_handler is None or dit_handler.model is None:
return [], "β Model not initialized. Please initialize the service first."
# Create output directory
os.makedirs(output_dir, exist_ok=True)
output_paths = []
success_count = 0
fail_count = 0
# Get model and components
model = dit_handler.model
vae = dit_handler.vae
text_encoder = dit_handler.text_encoder
text_tokenizer = dit_handler.text_tokenizer
silence_latent = dit_handler.silence_latent
device = dit_handler.device
dtype = dit_handler.dtype
target_sample_rate = 48000
for i, sample in enumerate(labeled_samples):
try:
if progress_callback:
progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}")
# Step 1: Load and preprocess audio to stereo @ 48kHz
audio, sr = torchaudio.load(sample.audio_path)
# Resample if needed
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
# Convert to stereo
if audio.shape[0] == 1:
audio = audio.repeat(2, 1)
elif audio.shape[0] > 2:
audio = audio[:2, :]
# Truncate to max duration
max_samples = int(max_duration * target_sample_rate)
if audio.shape[1] > max_samples:
audio = audio[:, :max_samples]
# Add batch dimension: [2, T] -> [1, 2, T]
audio = audio.unsqueeze(0).to(device).to(vae.dtype)
# Step 2: VAE encode audio to get target_latents
with torch.no_grad():
latent = vae.encode(audio).latent_dist.sample()
# [1, 64, T_latent] -> [1, T_latent, 64]
target_latents = latent.transpose(1, 2).to(dtype)
latent_length = target_latents.shape[1]
# Step 3: Create attention mask (all ones for valid audio)
attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype)
# Step 4: Encode caption text
caption = sample.get_full_caption(self.metadata.tag_position)
text_inputs = text_tokenizer(
caption,
padding="max_length",
max_length=256,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
text_attention_mask = text_inputs.attention_mask.to(device).to(dtype)
with torch.no_grad():
text_outputs = text_encoder(text_input_ids)
text_hidden_states = text_outputs.last_hidden_state.to(dtype)
# Step 5: Encode lyrics
lyrics = sample.lyrics if sample.lyrics else "[Instrumental]"
lyric_inputs = text_tokenizer(
lyrics,
padding="max_length",
max_length=512,
truncation=True,
return_tensors="pt",
)
lyric_input_ids = lyric_inputs.input_ids.to(device)
lyric_attention_mask = lyric_inputs.attention_mask.to(device).to(dtype)
with torch.no_grad():
lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype)
# Step 6: Prepare refer_audio (empty for text2music)
# Create minimal refer_audio placeholder
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long)
# Step 7: Run model.encoder to get encoder_hidden_states
with torch.no_grad():
encoder_hidden_states, encoder_attention_mask = model.encoder(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_hidden,
refer_audio_order_mask=refer_audio_order_mask,
)
# Step 8: Build context_latents for text2music
# For text2music: src_latents = silence_latent, is_covers = 0
# chunk_masks: 1 = generate, 0 = keep original
# IMPORTANT: chunk_masks must have same shape as src_latents [B, T, 64]
# For text2music, we want to generate the entire audio, so chunk_masks = all 1s
src_latents = silence_latent[:, :latent_length, :].to(dtype)
                # Ensure a batch dimension (silence_latent is expected as [1, T, 64])
                if src_latents.dim() == 2:
                    src_latents = src_latents.unsqueeze(0)
# Pad or truncate silence_latent to match latent_length
if src_latents.shape[1] < latent_length:
pad_len = latent_length - src_latents.shape[1]
src_latents = torch.cat([
src_latents,
silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)
], dim=1)
elif src_latents.shape[1] > latent_length:
src_latents = src_latents[:, :latent_length, :]
# chunk_masks = 1 means "generate this region", 0 = keep original
# Shape must match src_latents: [B, T, 64] (NOT [B, T, 1])
# For text2music, generate everything -> all 1s with shape [1, T, 64]
chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
# context_latents = [src_latents, chunk_masks] -> [B, T, 128]
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
# Step 9: Save all tensors to .pt file (squeeze batch dimension for storage)
output_data = {
"target_latents": target_latents.squeeze(0).cpu(), # [T, 64]
"attention_mask": attention_mask.squeeze(0).cpu(), # [T]
"encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(), # [L, D]
"encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(), # [L]
"context_latents": context_latents.squeeze(0).cpu(), # [T, 65]
"metadata": {
"audio_path": sample.audio_path,
"filename": sample.filename,
"caption": caption,
"lyrics": lyrics,
"duration": sample.duration,
"bpm": sample.bpm,
"keyscale": sample.keyscale,
"timesignature": sample.timesignature,
"language": sample.language,
"is_instrumental": sample.is_instrumental,
}
}
# Save with sample ID as filename
output_path = os.path.join(output_dir, f"{sample.id}.pt")
torch.save(output_data, output_path)
output_paths.append(output_path)
success_count += 1
except Exception as e:
logger.exception(f"Error preprocessing {sample.filename}")
fail_count += 1
if progress_callback:
progress_callback(f"β Failed: {sample.filename}: {str(e)}")
# Save manifest file listing all preprocessed samples
manifest = {
"metadata": self.metadata.to_dict(),
"samples": output_paths,
"num_samples": len(output_paths),
}
manifest_path = os.path.join(output_dir, "manifest.json")
with open(manifest_path, 'w', encoding='utf-8') as f:
json.dump(manifest, f, indent=2)
status = f"β
Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}"
if fail_count > 0:
status += f" ({fail_count} failed)"
return output_paths, status
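

# ---------------------------------------------------------------------------
# Minimal sketch (an assumption, not part of the original module): how the
# .pt files written by DatasetBuilder.preprocess_to_tensors() could be read
# back in a training loop. The tensor keys mirror the output_data dict saved
# above; the class name is hypothetical.
# ---------------------------------------------------------------------------
from torch.utils.data import Dataset


class PreprocessedTensorDataset(Dataset):
    """Loads preprocessed samples listed in a manifest.json file."""

    def __init__(self, manifest_path: str):
        with open(manifest_path, 'r', encoding='utf-8') as f:
            self.paths = json.load(f)["samples"]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        data = torch.load(self.paths[idx], map_location="cpu")
        return {
            "target_latents": data["target_latents"],                  # [T, 64]
            "attention_mask": data["attention_mask"],                  # [T]
            "encoder_hidden_states": data["encoder_hidden_states"],    # [L, D]
            "encoder_attention_mask": data["encoder_attention_mask"],  # [L]
            "context_latents": data["context_latents"],                # [T, 128]
        }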