|
|
import os |
|
|
import random |
|
|
import torch |
|
|
import pandas as pd |
|
|
import glob |
|
|
from torch.utils.data import Dataset |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
from src.utils import setup_logger |
|
|
|
|
|
|
|
|
# Module-level logger, configured through the project's shared helper.
logger = setup_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
class ChatterboxDataset(Dataset):
    """Dataset of preprocessed ``.pt`` samples for Chatterbox training.

    Each ``.pt`` file is expected to contain ``text_tokens``,
    ``speech_tokens``, ``speaker_emb`` and ``prompt_tokens`` tensors
    (produced by the preprocessing step — TODO confirm against the
    preprocessor).  The constructor discovers all ``.pt`` files under
    ``config.preprocessed_dir``, filters them on speaker/duration/SNR
    metadata, and performs a speaker-aware train/validation split,
    falling back to a plain random file split if the metadata cannot
    be read.
    """

    def __init__(self, config, split="train"):
        """
        Args:
            config: Training configuration
            split: "train", "val", or "all" (no split)

        Raises:
            FileNotFoundError: if ``config.preprocessed_dir`` does not exist.
            RuntimeError: if no ``.pt`` files are found under it.
        """
        self.cfg = config
        self.preprocessed_dir = config.preprocessed_dir
        self.split = split

        if not os.path.exists(self.preprocessed_dir):
            raise FileNotFoundError(f"Preprocessing folder not found: {self.preprocessed_dir}.")

        # Collect every .pt file (recursively), stored as a normalized path
        # relative to the preprocessed dir so it can be matched against
        # metadata rows regardless of platform path separators.
        pattern = os.path.join(self.preprocessed_dir, "**", "*.pt")
        all_files_full = glob.glob(pattern, recursive=True)

        all_files = sorted([os.path.normpath(os.path.relpath(f, self.preprocessed_dir)) for f in all_files_full])

        if len(all_files) == 0:
            raise RuntimeError(f"There are no .pt files in the folder (including subdirectories): {self.preprocessed_dir}")

        try:
            # LJSpeech-style metadata: pipe-separated, headerless; the first
            # column is the wav filename.  quoting=3 (csv.QUOTE_NONE) keeps
            # transcripts that contain quote characters intact.
            meta = pd.read_csv(config.csv_path, sep="|", header=None, quoting=3)

            # Per-sample attributes (speaker_id, duration, snr).
            # NOTE(review): assumed to be row-aligned with the metadata
            # CSV — rows are matched purely by index; confirm upstream.
            attr = pd.read_csv(config.attribution_path)

            file_to_speaker = {}
            file_to_meta = {}

            for i in range(len(meta)):
                wav_filename = str(meta.iloc[i, 0])

                # Map the wav filename to the corresponding .pt filename.
                pt_filename = wav_filename
                if pt_filename.endswith(".wav"):
                    pt_filename = pt_filename[:-4] + ".pt"
                elif not pt_filename.endswith(".pt"):
                    pt_filename += ".pt"

                # Normalize separators so dict lookups match glob results.
                pt_filename = os.path.normpath(pt_filename)

                speaker_id = str(attr.iloc[i]["speaker_id"])
                file_to_speaker[pt_filename] = speaker_id

                file_to_meta[pt_filename] = {
                    "speaker_id": speaker_id,
                    "duration": float(attr.iloc[i].get("duration", 0)),
                    "snr": float(attr.iloc[i].get("snr", 0))
                }

            # Quality-filter thresholds (config overrides with defaults).
            ood_speakers = set(getattr(config, "ood_speakers", []))
            min_duration = getattr(config, "min_training_duration", 4.0)
            min_snr = getattr(config, "min_training_snr", 20.0)
            max_snr = getattr(config, "max_training_snr", 100.0)

            # One record per filtered-out sample and the reason, for auditing.
            lineage_data = []

            # Group surviving files by speaker for the speaker-aware split.
            speaker_to_files = {}
            for f in all_files:
                meta_info = file_to_meta.get(f)
                if meta_info is None:
                    # No metadata row for this file: silently skipped
                    # (not recorded in the lineage CSV).
                    continue

                spk_id = meta_info["speaker_id"]
                duration = meta_info["duration"]
                snr = meta_info["snr"]

                # First matching rejection reason wins.
                reason = None
                if spk_id in ood_speakers:
                    reason = "OOD_SPEAKER"
                elif duration < min_duration:
                    reason = "LOW_DURATION"
                elif snr < min_snr:
                    reason = "LOW_SNR"
                elif snr > max_snr:
                    reason = "HIGH_SNR"

                if reason:
                    lineage_data.append({
                        "file": f,
                        "speaker_id": spk_id,
                        "duration": duration,
                        "snr": snr,
                        "reason": reason
                    })
                    continue

                if spk_id not in speaker_to_files:
                    speaker_to_files[spk_id] = []
                speaker_to_files[spk_id].append(f)

            # Persist the filtering lineage once (train split only, so the
            # file is not rewritten when the val split is constructed).
            if self.split == "train":
                lineage_df = pd.DataFrame(lineage_data)
                lineage_path = os.path.join(config.output_dir, "dataset_filtering_lineage.csv")
                os.makedirs(config.output_dir, exist_ok=True)
                lineage_df.to_csv(lineage_path, index=False)
                logger.info(f"Dataset lineage saved to {lineage_path}. Filtered {len(lineage_df)} samples.")

            all_available_speakers = sorted(speaker_to_files)

            if split in ["train", "val"]:
                if len(all_available_speakers) <= 1:
                    # Single speaker: a speaker-level split is impossible,
                    # so split at the file level instead.
                    logger.info("Only one speaker detected. Splitting at file level.")
                    all_files_to_split = []
                    for spk_id in all_available_speakers:
                        all_files_to_split.extend(speaker_to_files[spk_id])

                    # Seed before shuffling so train/val partitions are
                    # reproducible and mutually exclusive across runs.
                    random.seed(config.validation_seed)
                    random.shuffle(all_files_to_split)

                    n_val = max(1, int(len(all_files_to_split) * config.validation_split))
                    if split == "train":
                        self.files = all_files_to_split[:-n_val]
                        logger.info(f"Training dataset: {len(self.files)} files (Single Speaker Mode).")
                    else:
                        self.files = all_files_to_split[-n_val:]
                        logger.info(f"Validation dataset: {len(self.files)} files (Single Speaker Mode).")
                else:
                    # Multi-speaker: hold out whole speakers for validation.
                    random.seed(config.validation_seed)
                    random.shuffle(all_available_speakers)

                    n_val_spk = max(1, int(len(all_available_speakers) * config.validation_split))
                    val_speakers = set(all_available_speakers[-n_val_spk:])
                    train_speakers = set(all_available_speakers[:-n_val_spk])

                    # NOTE(review): iterating a set makes the file order
                    # non-deterministic across interpreter runs; harmless if
                    # the DataLoader shuffles, but worth confirming.
                    self.files = []
                    if split == "train":
                        for spk_id in train_speakers:
                            self.files.extend(speaker_to_files[spk_id])
                        logger.info(f"Training dataset: {len(self.files)} files from {len(train_speakers)} speakers.")
                    else:
                        for spk_id in val_speakers:
                            self.files.extend(speaker_to_files[spk_id])
                        logger.info(f"Validation dataset: {len(self.files)} files from {len(val_speakers)} speakers.")
            else:
                # split == "all": keep every surviving file, no partition.
                self.files = []
                for spk_id in all_available_speakers:
                    self.files.extend(speaker_to_files[spk_id])
                logger.info(f"Dataset loaded: {len(self.files)} files from {len(all_available_speakers)} speakers.")

        except Exception as e:
            # Metadata missing/corrupt: degrade to a random file-level split.
            # Fix: logger.exception preserves the traceback (logger.error
            # discarded it), making the root cause diagnosable.
            logger.exception(f"Error during speaker-aware split: {e}. Falling back to random file split.")

            if split in ["train", "val"]:
                random.seed(config.validation_seed)
                random.shuffle(all_files)
                n_val = max(1, int(len(all_files) * config.validation_split))
                if split == "train":
                    self.files = all_files[:-n_val]
                else:
                    self.files = all_files[-n_val:]
            else:
                self.files = all_files

        # Start/stop-of-text token ids, prepended/appended in __getitem__.
        self.sot_token = config.start_text_token
        self.eot_token = config.stop_text_token

    def __len__(self):
        """Number of samples in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one preprocessed sample.

        Returns:
            Dict with ``text_tokens`` (SOT + truncated tokens + EOT),
            ``speech_tokens`` (truncated to ``max_speech_len``),
            ``speaker_emb`` and ``prompt_tokens`` — or ``None`` on any
            loading failure (the collator drops ``None`` entries).
        """
        filename = None
        try:
            filename = self.files[idx]

            pt_path = os.path.join(self.preprocessed_dir, filename)

            # Fix: map_location keeps tensors on CPU even if they were
            # saved from a CUDA device — required inside DataLoader workers.
            data = torch.load(pt_path, map_location="cpu")

            # Truncate text, reserving two slots for the SOT/EOT tokens.
            text_tokens = data["text_tokens"]
            if text_tokens.size(0) > self.cfg.max_text_len - 2:
                text_tokens = text_tokens[:self.cfg.max_text_len - 2]

            sot = torch.tensor([self.sot_token], dtype=torch.long)
            eot = torch.tensor([self.eot_token], dtype=torch.long)
            text_tokens = torch.cat([sot, text_tokens, eot])

            speech_tokens = data["speech_tokens"]
            if speech_tokens.size(0) > self.cfg.max_speech_len:
                speech_tokens = speech_tokens[:self.cfg.max_speech_len]

            return {
                "text_tokens": text_tokens,
                "speech_tokens": speech_tokens,
                "speaker_emb": data["speaker_emb"],
                "prompt_tokens": data["prompt_tokens"]
            }

        except Exception as e:
            # Fix: report which file failed (the message was a literal
            # "(unknown)" placeholder before).
            logger.error(f"Error loading {filename}: {e}")
            return None
|
|
|
|
|
|
|
|
def data_collator(batch):
    """Collate dataset items into a padded training batch.

    Items that failed to load (``None``) are dropped first; if nothing
    survives, an empty dict is returned so the caller can skip the step.
    Variable-length token sequences are right-padded with zeros.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return {}

    def _padded(key):
        # Zero-pad all sequences under `key` to the batch maximum.
        return pad_sequence(
            [sample[key] for sample in samples],
            batch_first=True,
            padding_value=0,
        )

    def _lengths(key):
        # Unpadded sequence lengths, needed to mask padding downstream.
        return torch.tensor([len(sample[key]) for sample in samples], dtype=torch.long)

    return {
        "text_tokens": _padded("text_tokens"),
        "text_token_lens": _lengths("text_tokens"),
        "speech_tokens": _padded("speech_tokens"),
        "speech_token_lens": _lengths("speech_tokens"),
        "speaker_emb": torch.stack([sample["speaker_emb"] for sample in samples]),
        "prompt_tokens": _padded("prompt_tokens"),
    }