# Chatterbox-Finnish / src/dataset.py
# Author: RASMUS — "Upload Finnish Chatterbox model" (commit 67ea4ca, verified).
import os
import random
import torch
import pandas as pd
import glob
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from src.utils import setup_logger
logger = setup_logger(__name__)
class ChatterboxDataset(Dataset):
    """Dataset of preprocessed TTS samples (.pt files) with speaker-aware splitting.

    Each .pt file is expected to contain "text_tokens", "speech_tokens",
    "speaker_emb" and "prompt_tokens" tensors (presumably produced by the
    project's preprocessing step — TODO confirm against the preprocessor).
    """

    def __init__(self, config, split="train"):
        """
        Args:
            config: Training configuration
            split: "train", "val", or "all" (no split)

        Raises:
            FileNotFoundError: if config.preprocessed_dir does not exist.
            RuntimeError: if no .pt files are found under it.
        """
        self.cfg = config
        self.preprocessed_dir = config.preprocessed_dir
        self.split = split

        # List all .pt files recursively
        if not os.path.exists(self.preprocessed_dir):
            raise FileNotFoundError(f"Preprocessing folder not found: {self.preprocessed_dir}.")
        pattern = os.path.join(self.preprocessed_dir, "**", "*.pt")
        all_files_full = glob.glob(pattern, recursive=True)
        # Store relative paths to the preprocessed directory, normalized for consistent matching
        all_files = sorted([os.path.normpath(os.path.relpath(f, self.preprocessed_dir)) for f in all_files_full])
        if len(all_files) == 0:
            raise RuntimeError(f"There are no .pt files in the folder (including subdirectories): {self.preprocessed_dir}")

        # --- Speaker-Aware Splitting & Filtering Logic ---
        try:
            # 1. Load mappings
            # metadata.csv: wav_path|raw_text|norm_text
            meta = pd.read_csv(config.csv_path, sep="|", header=None, quoting=3)
            # attribution: audio_file,resolved_path,text,speaker_id,...
            attr = pd.read_csv(config.attribution_path)

            # 2. Build filename -> sample metadata used for filtering.
            # NOTE(review): assumes meta and attr are row-aligned (same order);
            # there is no join key verified here.
            file_to_meta = {}  # For traceability
            for i in range(len(meta)):
                wav_filename = str(meta.iloc[i, 0])
                # Convert wav filename to pt filename while preserving directory structure
                pt_filename = wav_filename
                if pt_filename.endswith(".wav"):
                    pt_filename = pt_filename[:-4] + ".pt"
                elif not pt_filename.endswith(".pt"):
                    pt_filename += ".pt"
                # Normalize path for consistent matching
                pt_filename = os.path.normpath(pt_filename)
                speaker_id = str(attr.iloc[i]["speaker_id"])
                # Store speaker, duration and SNR for the filtering logic below
                file_to_meta[pt_filename] = {
                    "speaker_id": speaker_id,
                    "duration": float(attr.iloc[i].get("duration", 0)),
                    "snr": float(attr.iloc[i].get("snr", 0)),
                }

            # 3. Filter OOD speakers and low-quality samples
            ood_speakers = set(getattr(config, "ood_speakers", []))
            min_duration = getattr(config, "min_training_duration", 4.0)
            min_snr = getattr(config, "min_training_snr", 20.0)
            max_snr = getattr(config, "max_training_snr", 100.0)

            lineage_data = []      # records of excluded files and why
            speaker_to_files = {}  # surviving files grouped by speaker_id
            for f in all_files:
                meta_info = file_to_meta.get(f)
                if meta_info is None:
                    # .pt file without a metadata row: silently skipped
                    continue
                spk_id = meta_info["speaker_id"]
                duration = meta_info["duration"]
                snr = meta_info["snr"]

                # First matching rule wins: OOD speaker > duration > SNR bounds
                reason = None
                if spk_id in ood_speakers:
                    reason = "OOD_SPEAKER"
                elif duration < min_duration:
                    reason = "LOW_DURATION"
                elif snr < min_snr:
                    reason = "LOW_SNR"
                elif snr > max_snr:
                    reason = "HIGH_SNR"

                if reason:
                    lineage_data.append({
                        "file": f,
                        "speaker_id": spk_id,
                        "duration": duration,
                        "snr": snr,
                        "reason": reason,
                    })
                    continue  # Exclude from training/validation

                speaker_to_files.setdefault(spk_id, []).append(f)

            # Save lineage only once, when the "train" split is built
            if self.split == "train":
                lineage_df = pd.DataFrame(lineage_data)
                lineage_path = os.path.join(config.output_dir, "dataset_filtering_lineage.csv")
                os.makedirs(config.output_dir, exist_ok=True)
                lineage_df.to_csv(lineage_path, index=False)
                logger.info(f"Dataset lineage saved to {lineage_path}. Filtered {len(lineage_df)} samples.")

            all_available_speakers = sorted(list(speaker_to_files.keys()))

            if split in ["train", "val"]:
                # If we only have one speaker, we MUST split at the file level instead of the speaker level
                if len(all_available_speakers) <= 1:
                    logger.info("Only one speaker detected. Splitting at file level.")
                    all_files_to_split = []
                    for spk_id in all_available_speakers:
                        all_files_to_split.extend(speaker_to_files[spk_id])
                    random.seed(config.validation_seed)
                    random.shuffle(all_files_to_split)
                    n_val = max(1, int(len(all_files_to_split) * config.validation_split))
                    if split == "train":
                        self.files = all_files_to_split[:-n_val]
                        logger.info(f"Training dataset: {len(self.files)} files (Single Speaker Mode).")
                    else:  # val
                        self.files = all_files_to_split[-n_val:]
                        logger.info(f"Validation dataset: {len(self.files)} files (Single Speaker Mode).")
                else:
                    # Split speakers instead of files, so val speakers are unseen during training
                    random.seed(config.validation_seed)
                    random.shuffle(all_available_speakers)
                    n_val_spk = max(1, int(len(all_available_speakers) * config.validation_split))
                    val_speakers = set(all_available_speakers[-n_val_spk:])
                    train_speakers = set(all_available_speakers[:-n_val_spk])
                    self.files = []
                    if split == "train":
                        for spk_id in train_speakers:
                            self.files.extend(speaker_to_files[spk_id])
                        logger.info(f"Training dataset: {len(self.files)} files from {len(train_speakers)} speakers.")
                    else:  # val
                        for spk_id in val_speakers:
                            self.files.extend(speaker_to_files[spk_id])
                        logger.info(f"Validation dataset: {len(self.files)} files from {len(val_speakers)} speakers.")
            else:  # all
                self.files = []
                for spk_id in all_available_speakers:
                    self.files.extend(speaker_to_files[spk_id])
                logger.info(f"Dataset loaded: {len(self.files)} files from {len(all_available_speakers)} speakers.")
        except Exception as e:
            logger.error(f"Error during speaker-aware split: {e}. Falling back to random file split.")
            # Fallback to random file split if something goes wrong with attribution
            if split in ["train", "val"]:
                random.seed(config.validation_seed)
                random.shuffle(all_files)
                n_val = max(1, int(len(all_files) * config.validation_split))
                if split == "train":
                    self.files = all_files[:-n_val]
                else:
                    self.files = all_files[-n_val:]
            else:
                self.files = all_files

        self.sot_token = config.start_text_token
        self.eot_token = config.stop_text_token

    def __len__(self):
        """Number of usable samples in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and trim one sample.

        Returns None on load failure so the collator can drop the sample.
        """
        filename = self.files[idx]
        pt_path = os.path.join(self.preprocessed_dir, filename)
        try:
            data = torch.load(pt_path)

            # 1. Text tokens: truncate to leave room for SOT/EOT, then wrap
            text_tokens = data["text_tokens"]
            if text_tokens.size(0) > self.cfg.max_text_len - 2:
                text_tokens = text_tokens[:self.cfg.max_text_len - 2]
            sot = torch.tensor([self.sot_token], dtype=torch.long)
            eot = torch.tensor([self.eot_token], dtype=torch.long)
            text_tokens = torch.cat([sot, text_tokens, eot])

            # 2. Speech Tokens: hard truncate to the configured maximum
            speech_tokens = data["speech_tokens"]
            if speech_tokens.size(0) > self.cfg.max_speech_len:
                speech_tokens = speech_tokens[:self.cfg.max_speech_len]

            return {
                "text_tokens": text_tokens,
                "speech_tokens": speech_tokens,
                "speaker_emb": data["speaker_emb"],
                "prompt_tokens": data["prompt_tokens"],
            }
        except Exception as e:
            # BUGFIX: previously logged the literal "(unknown)"; name the failing file
            logger.error(f"Error loading {pt_path}: {e}")
            return None
def data_collator(batch):
    """Collate a list of dataset items into one padded batch.

    None entries (failed loads) are dropped first; if nothing remains,
    an empty dict is returned. Token sequences are right-padded with 0 and
    the original (pre-padding) lengths are kept for masking downstream.
    """
    valid = [item for item in batch if item is not None]
    if not valid:
        return {}

    def _padded(key):
        # Right-pad the variable-length 1-D token tensors with 0
        return pad_sequence([item[key] for item in valid], batch_first=True, padding_value=0)

    # Unpadded lengths (required for masking)
    text_lens = torch.tensor([item["text_tokens"].shape[0] for item in valid], dtype=torch.long)
    speech_lens = torch.tensor([item["speech_tokens"].shape[0] for item in valid], dtype=torch.long)

    return {
        "text_tokens": _padded("text_tokens"),
        "text_token_lens": text_lens,
        "speech_tokens": _padded("speech_tokens"),
        "speech_token_lens": speech_lens,
        "speaker_emb": torch.stack([item["speaker_emb"] for item in valid]),
        "prompt_tokens": _padded("prompt_tokens"),
    }