import glob
import os
import random

import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from src.utils import setup_logger

logger = setup_logger(__name__)

# Column order of the filtering-lineage CSV. Passing it explicitly means an
# empty lineage still produces a header row instead of a blank file.
_LINEAGE_COLUMNS = ["file", "speaker_id", "duration", "snr", "reason"]


def _wav_to_pt_name(wav_filename):
    """Map an audio path from metadata.csv to its normalized relative ``.pt`` path.

    ``a/b.wav`` -> ``a/b.pt``; a name already ending in ``.pt`` is kept as is;
    any other name gets ``.pt`` appended. The result is ``os.path.normpath``-ed
    so it matches the keys built from the glob over the preprocessed directory.
    """
    if wav_filename.endswith(".wav"):
        pt_filename = wav_filename[:-4] + ".pt"
    elif wav_filename.endswith(".pt"):
        pt_filename = wav_filename
    else:
        pt_filename = wav_filename + ".pt"
    return os.path.normpath(pt_filename)


class ChatterboxDataset(Dataset):
    """Dataset of preprocessed ``.pt`` samples with speaker-aware train/val splitting.

    Each ``.pt`` file under ``config.preprocessed_dir`` is expected to hold a dict
    with ``text_tokens``, ``speech_tokens``, ``speaker_emb`` and ``prompt_tokens``
    (the keys read in ``__getitem__``).
    """

    def __init__(self, config, split="train"):
        """
        Args:
            config: Training configuration. Attributes used by this class:
                ``preprocessed_dir``, ``csv_path``, ``attribution_path``,
                ``output_dir``, ``validation_seed``, ``validation_split``,
                ``start_text_token``, ``stop_text_token``, ``max_text_len``,
                ``max_speech_len`` and the optional ``ood_speakers``,
                ``min_training_duration`` (default 4.0), ``min_training_snr``
                (default 20.0) and ``max_training_snr`` (default 100.0).
            split: "train", "val", or "all" (no split)

        Raises:
            FileNotFoundError: If ``config.preprocessed_dir`` does not exist.
            RuntimeError: If no ``.pt`` files are found under it.
        """
        self.cfg = config
        self.preprocessed_dir = config.preprocessed_dir
        self.split = split

        if not os.path.exists(self.preprocessed_dir):
            raise FileNotFoundError(f"Preprocessing folder not found: {self.preprocessed_dir}.")

        # List all .pt files recursively. Paths are stored relative to the
        # preprocessed directory and normalized so they match metadata keys.
        pattern = os.path.join(self.preprocessed_dir, "**", "*.pt")
        all_files = sorted(
            os.path.normpath(os.path.relpath(path, self.preprocessed_dir))
            for path in glob.glob(pattern, recursive=True)
        )
        if not all_files:
            raise RuntimeError(
                f"There are no .pt files in the folder (including subdirectories): {self.preprocessed_dir}"
            )

        try:
            self.files = self._speaker_aware_files(config, split, all_files)
        except Exception as e:
            # Best-effort: any problem with the metadata/attribution files falls
            # back to an unfiltered random file split rather than aborting.
            logger.error(
                f"Error during speaker-aware split: {e}. Falling back to random file split.",
                exc_info=True,
            )
            self.files = self._random_file_split(config, split, all_files)

        if not self.files:
            logger.warning(f"Dataset split '{split}' contains no files.")

        self.sot_token = config.start_text_token
        self.eot_token = config.stop_text_token

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    @classmethod
    def _speaker_aware_files(cls, config, split, all_files):
        """Filter ``all_files`` using speaker/quality metadata and select ``split``."""
        file_to_meta = cls._load_file_metadata(config)
        speaker_to_files, lineage_data = cls._filter_files(config, all_files, file_to_meta)
        # Write the lineage once (from the "train" instance) rather than per split.
        if split == "train":
            cls._save_lineage(config, lineage_data)
        return cls._select_split(config, split, speaker_to_files)

    @staticmethod
    def _load_file_metadata(config):
        """Return ``{relative .pt path: {"speaker_id", "duration", "snr"}}``.

        ``config.csv_path`` is pipe-separated with no header (first column: audio
        path). ``config.attribution_path`` is a headed CSV with a ``speaker_id``
        column and optional ``duration``/``snr`` columns (treated as 0.0 when the
        column is absent). The two files are paired row by row, so they must be
        in the same order.

        Raises:
            ValueError: If the attribution file has fewer rows than the metadata.
        """
        meta = pd.read_csv(config.csv_path, sep="|", header=None, quoting=3)
        attr = pd.read_csv(config.attribution_path)

        n_rows = len(meta)
        if len(attr) < n_rows:
            raise ValueError(
                f"Attribution file has {len(attr)} rows but metadata has {n_rows}; "
                "rows are paired by position and must line up."
            )
        if len(attr) > n_rows:
            logger.warning(
                f"Attribution file has {len(attr) - n_rows} more rows than metadata; "
                "extra rows are ignored."
            )
        attr = attr.iloc[:n_rows]

        def _float_column(name):
            # Missing column -> 0.0 for every row (matches a per-row default of 0).
            if name in attr.columns:
                return attr[name].astype(float).tolist()
            return [0.0] * n_rows

        wav_names = meta.iloc[:, 0].astype(str).tolist()
        speaker_ids = attr["speaker_id"].astype(str).tolist()
        durations = _float_column("duration")
        snrs = _float_column("snr")

        # Later duplicate paths overwrite earlier ones.
        return {
            _wav_to_pt_name(wav): {"speaker_id": spk, "duration": dur, "snr": snr}
            for wav, spk, dur, snr in zip(wav_names, speaker_ids, durations, snrs)
        }

    @staticmethod
    def _filter_files(config, all_files, file_to_meta):
        """Drop out-of-distribution speakers and low-quality samples.

        Returns:
            ``(speaker_to_files, lineage_data)``: the kept files grouped by speaker
            id (each list in ``all_files`` order), and one record per excluded file
            with the reason it was excluded. Files without a metadata row are
            skipped and reported in a warning.
        """
        # str() so speaker ids configured as ints still match the string ids.
        ood_speakers = {str(s) for s in (getattr(config, "ood_speakers", None) or [])}
        min_duration = getattr(config, "min_training_duration", 4.0)
        min_snr = getattr(config, "min_training_snr", 20.0)
        max_snr = getattr(config, "max_training_snr", 100.0)

        speaker_to_files = {}
        lineage_data = []
        n_unmatched = 0
        for f in all_files:
            info = file_to_meta.get(f)
            if info is None:
                n_unmatched += 1
                continue
            spk_id, duration, snr = info["speaker_id"], info["duration"], info["snr"]

            # NOTE(review): NaN duration/snr compare False here, so rows with
            # missing values in an existing column pass these filters.
            if spk_id in ood_speakers:
                reason = "OOD_SPEAKER"
            elif duration < min_duration:
                reason = "LOW_DURATION"
            elif snr < min_snr:
                reason = "LOW_SNR"
            elif snr > max_snr:
                reason = "HIGH_SNR"
            else:
                reason = None

            if reason is not None:
                # Excluded from training/validation; recorded for traceability.
                lineage_data.append({
                    "file": f,
                    "speaker_id": spk_id,
                    "duration": duration,
                    "snr": snr,
                    "reason": reason,
                })
            else:
                speaker_to_files.setdefault(spk_id, []).append(f)

        if n_unmatched:
            logger.warning(
                f"{n_unmatched} .pt files have no matching metadata row and were skipped."
            )
        return speaker_to_files, lineage_data

    @staticmethod
    def _save_lineage(config, lineage_data):
        """Write excluded-sample records to ``<output_dir>/dataset_filtering_lineage.csv``.

        Failures are logged and ignored so that they cannot discard the
        speaker-aware split.
        """
        output_dir = getattr(config, "output_dir", None)
        if not output_dir:
            logger.warning("config.output_dir is not set; dataset lineage not saved.")
            return
        lineage_df = pd.DataFrame(lineage_data, columns=_LINEAGE_COLUMNS)
        lineage_path = os.path.join(output_dir, "dataset_filtering_lineage.csv")
        try:
            os.makedirs(output_dir, exist_ok=True)
            lineage_df.to_csv(lineage_path, index=False)
        except OSError as e:
            logger.warning(f"Could not save dataset lineage to {lineage_path}: {e}")
            return
        logger.info(f"Dataset lineage saved to {lineage_path}. Filtered {len(lineage_df)} samples.")

    @staticmethod
    def _select_split(config, split, speaker_to_files):
        """Pick the files belonging to ``split`` from the filtered per-speaker lists.

        With two or more speakers, whole speakers go to either train or val, so no
        speaker appears in both. With zero or one speaker, files are split directly.
        Shuffles use a private ``random.Random(config.validation_seed)``. This
        yields the same order as seeding the global RNG, but leaves the global
        ``random`` state untouched.
        """
        speakers = sorted(speaker_to_files)

        if split not in ("train", "val"):  # "all": no split
            files = [f for spk in speakers for f in speaker_to_files[spk]]
            logger.info(f"Dataset loaded: {len(files)} files from {len(speakers)} speakers.")
            return files

        rng = random.Random(config.validation_seed)

        if len(speakers) <= 1:
            # A speaker-level split is impossible, so split at the file level.
            logger.info("Only one speaker detected. Splitting at file level.")
            files = [f for spk in speakers for f in speaker_to_files[spk]]
            rng.shuffle(files)
            n_val = max(1, int(len(files) * config.validation_split))
            if split == "train":
                files = files[:-n_val]
                logger.info(f"Training dataset: {len(files)} files (Single Speaker Mode).")
            else:  # val
                files = files[-n_val:]
                logger.info(f"Validation dataset: {len(files)} files (Single Speaker Mode).")
            return files

        rng.shuffle(speakers)
        n_val_spk = max(1, int(len(speakers) * config.validation_split))
        # Iterate the shuffled list (not a set) so file order does not depend on
        # string-hash randomization and is reproducible across runs.
        if split == "train":
            chosen = speakers[:-n_val_spk]
            files = [f for spk in chosen for f in speaker_to_files[spk]]
            logger.info(f"Training dataset: {len(files)} files from {len(chosen)} speakers.")
        else:  # val
            chosen = speakers[-n_val_spk:]
            files = [f for spk in chosen for f in speaker_to_files[spk]]
            logger.info(f"Validation dataset: {len(files)} files from {len(chosen)} speakers.")
        return files

    @staticmethod
    def _random_file_split(config, split, all_files):
        """Fallback: seeded random split over individual files, with no filtering."""
        files = list(all_files)
        if split not in ("train", "val"):
            return files
        random.Random(config.validation_seed).shuffle(files)
        n_val = max(1, int(len(files) * config.validation_split))
        return files[:-n_val] if split == "train" else files[-n_val:]

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample and wrap its text tokens with start/stop tokens.

        Returns:
            A dict with the following keys, or ``None`` if the file cannot be
            loaded (``data_collator`` drops ``None`` samples):

            - ``text_tokens``: ``[sot] + text + [eot]``, at most
              ``cfg.max_text_len`` long.
            - ``speech_tokens``: at most ``cfg.max_speech_len`` long.
            - ``speaker_emb``.
            - ``prompt_tokens``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Look up the filename outside the try: a bad index should raise rather
        # than be swallowed, and the except clause can always name the file.
        filename = self.files[idx]
        pt_path = os.path.join(self.preprocessed_dir, filename)
        try:
            # map_location="cpu" keeps DataLoader workers off CUDA even if the
            # tensors were saved from a GPU. NOTE(review): assumes the training
            # loop moves batches to the target device.
            data = torch.load(pt_path, map_location="cpu")

            # 1. Text tokens: truncate to leave room for the start/stop tokens.
            text_tokens = data["text_tokens"]
            max_text_body = self.cfg.max_text_len - 2
            if text_tokens.size(0) > max_text_body:
                text_tokens = text_tokens[:max_text_body]
            sot = torch.tensor([self.sot_token], dtype=torch.long)
            eot = torch.tensor([self.eot_token], dtype=torch.long)
            text_tokens = torch.cat([sot, text_tokens, eot])

            # 2. Speech tokens: truncate to max_speech_len.
            speech_tokens = data["speech_tokens"]
            if speech_tokens.size(0) > self.cfg.max_speech_len:
                speech_tokens = speech_tokens[:self.cfg.max_speech_len]

            return {
                "text_tokens": text_tokens,
                "speech_tokens": speech_tokens,
                "speaker_emb": data["speaker_emb"],
                "prompt_tokens": data["prompt_tokens"],
            }
        except Exception as e:
            logger.error(f"Error loading {filename}: {e}")
            return None


def data_collator(batch):
    """Collate ``ChatterboxDataset`` samples into a padded batch.

    ``None`` samples (failed loads) are dropped; if none remain, an empty dict
    is returned. Text, speech and prompt token sequences are right-padded with
    0 to the longest sequence in the batch. Speaker embeddings are stacked, so
    they must share a shape.

    Returns:
        Dict with the following keys:

        - ``text_tokens`` and ``text_token_lens``.
        - ``speech_tokens`` and ``speech_token_lens``.
        - ``speaker_emb``.
        - ``prompt_tokens``.

        The ``*_lens`` entries hold the unpadded lengths, for masking.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return {}

    # Padding
    text_tokens = pad_sequence([x["text_tokens"] for x in batch], batch_first=True, padding_value=0)
    speech_tokens = pad_sequence([x["speech_tokens"] for x in batch], batch_first=True, padding_value=0)
    prompt_tokens = pad_sequence([x["prompt_tokens"] for x in batch], batch_first=True, padding_value=0)
    speaker_embs = torch.stack([x["speaker_emb"] for x in batch])

    # Lengths (required for masking, since 0 is also used as the pad value)
    text_lens = torch.tensor([len(x["text_tokens"]) for x in batch], dtype=torch.long)
    speech_lens = torch.tensor([len(x["speech_tokens"]) for x in batch], dtype=torch.long)

    return {
        "text_tokens": text_tokens,
        "text_token_lens": text_lens,
        "speech_tokens": speech_tokens,
        "speech_token_lens": speech_lens,
        "speaker_emb": speaker_embs,
        "prompt_tokens": prompt_tokens,
    }