|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import argbind |
|
|
import torch |
|
|
from datasets import Audio, Dataset, DatasetDict, load_dataset |
|
|
from torch.utils.data import Dataset as TorchDataset |
|
|
|
|
|
from ..model.voxcpm import VoxCPMConfig |
|
|
from ..modules.audiovae import AudioVAE |
|
|
from .packers import AudioFeatureProcessingPacker |
|
|
|
|
|
|
|
|
# Canonical column names used throughout this module; manifest columns with
# other names are renamed to these when the datasets are loaded.
DEFAULT_TEXT_COLUMN: str = "text"
DEFAULT_AUDIO_COLUMN: str = "audio"
DEFAULT_ID_COLUMN: str = "dataset_id"
|
|
|
|
|
|
|
|
@argbind.bind()
def load_audio_text_datasets(
    train_manifest: str,
    val_manifest: str = "",
    text_column: str = DEFAULT_TEXT_COLUMN,
    audio_column: str = DEFAULT_AUDIO_COLUMN,
    dataset_id_column: str = DEFAULT_ID_COLUMN,
    sample_rate: int = 16_000,
    num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
    """Load JSON manifests and normalise their column names.

    Each split is loaded with the HuggingFace ``json`` builder, its audio
    column is cast to an ``Audio`` feature at ``sample_rate``, and the
    text / audio / dataset-id columns are renamed to the module defaults
    (``DEFAULT_TEXT_COLUMN``, ``DEFAULT_AUDIO_COLUMN``, ``DEFAULT_ID_COLUMN``).

    Args:
        train_manifest: Path (or pattern) of the training manifest.
        val_manifest: Path of the validation manifest; an empty string
            means no validation split is loaded.
        text_column: Name of the transcript column in the manifest.
        audio_column: Name of the audio column in the manifest.
        dataset_id_column: Name of the per-sample dataset id column. When it
            is empty or absent from the manifest, a constant id of ``0`` is
            added for every row.
        sample_rate: Sampling rate passed to the ``Audio`` feature.
        num_proc: Accepted for interface compatibility; it is not used by
            this function at the moment.

    Returns:
        ``(train_dataset, val_dataset)`` where ``val_dataset`` is ``None``
        when no validation manifest was given.

    Raises:
        ValueError: If the audio or text column is missing from a manifest.
    """
    data_files = {"train": train_manifest}
    if val_manifest:
        data_files["validation"] = val_manifest

    dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)

    def prepare(ds: Dataset) -> Dataset:
        """Validate one split and rename its columns to the module defaults."""
        # Fail fast on both mandatory columns. Previously only the audio
        # column was checked, so a missing text column under its default
        # name went unnoticed until a later processing stage.
        for required in (audio_column, text_column):
            if required not in ds.column_names:
                raise ValueError(f"Expected '{required}' column in manifest.")

        ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))

        if audio_column != DEFAULT_AUDIO_COLUMN:
            ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
        if text_column != DEFAULT_TEXT_COLUMN:
            ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)

        if dataset_id_column and dataset_id_column in ds.column_names:
            if dataset_id_column != DEFAULT_ID_COLUMN:
                ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
        else:
            # No usable id column: treat every sample as dataset 0.
            ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
        return ds

    train_ds = prepare(dataset_dict["train"])
    val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
    return train_ds, val_ds
|
|
|
|
|
|
|
|
def compute_sample_lengths(
    ds: Dataset,
    audio_vae_fps: int = 25,
    patch_size: int = 1,
) -> List[int]:
    """Estimate the packed sequence length (text + audio) of every sample.

    The estimate is intended for filtering out over-long samples before
    packing. Per sample it is computed as::

        t_vae     = ceil(duration_seconds * audio_vae_fps)   # approx. VAE frames
        t_seq     = ceil(t_vae / patch_size)                 # patches
        total_len = len(text_ids) + t_seq + 2

    The duration is read from a ``"duration"`` column when the dataset has
    one; otherwise it is derived from the decoded audio array and its
    sampling rate.

    Args:
        ds: Dataset whose rows contain ``"text_ids"`` and either a
            ``"duration"`` value or a decoded audio entry.
        audio_vae_fps: Number of VAE frames per second of audio.
        patch_size: Number of VAE frames grouped into one sequence position.

    Returns:
        One estimated length per row, in dataset order.

    Raises:
        ValueError: If ``audio_vae_fps`` or ``patch_size`` is not positive.
    """
    # Reject values that would otherwise surface as ZeroDivisionError or as
    # silently meaningless (zero / negative) lengths.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}.")
    if audio_vae_fps <= 0:
        raise ValueError(f"audio_vae_fps must be a positive integer, got {audio_vae_fps}.")

    has_duration = "duration" in ds.column_names

    lengths: List[int] = []
    for item in ds:
        text_len = len(item["text_ids"])

        if has_duration:
            duration = float(item["duration"])
        else:
            audio = item[DEFAULT_AUDIO_COLUMN]
            duration = len(audio["array"]) / float(audio["sampling_rate"])

        t_vae = math.ceil(duration * audio_vae_fps)
        t_seq = math.ceil(t_vae / patch_size)

        # The "+ 2" is kept from the original estimate; presumably it covers
        # extra positions added by the packer (TODO confirm against packer).
        lengths.append(text_len + t_seq + 2)

    return lengths
|
|
|
|
|
|
|
|
class HFVoxCPMDataset(TorchDataset):
    """Thin wrapper around a tokenized HuggingFace dataset.

    Rows are exposed as plain dictionaries, and ``collate_fn`` turns a list
    of such dictionaries into padded PyTorch tensors.
    """

    def __init__(self, dataset: Dataset):
        """Store the underlying dataset; no copying or preprocessing is done."""
        self.dataset = dataset

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int):
        """Return row ``idx`` as a flat dictionary.

        The row must contain ``"text_ids"`` and a decoded audio entry with
        ``"array"`` and ``"sampling_rate"``. ``dataset_id`` defaults to ``0``
        and ``is_prompt`` to ``False`` when the row lacks those fields.
        """
        item = self.dataset[idx]
        audio = item[DEFAULT_AUDIO_COLUMN]
        return {
            "text_ids": item["text_ids"],
            "audio_array": audio["array"],
            "audio_sampling_rate": audio["sampling_rate"],
            "dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
            "is_prompt": item.get("is_prompt", False),
        }

    @staticmethod
    def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
        """Right-pad tensors along dimension 0 and stack them into a batch.

        Args:
            seqs: Tensors that agree in every dimension except the first.
            pad_value: Fill value for the padded positions.

        Returns:
            A tensor of shape ``(len(seqs), max_len, *trailing_dims)``, or an
            empty tensor when ``seqs`` is empty.
        """
        if not seqs:
            return torch.empty(0)

        max_len = max(seq.shape[0] for seq in seqs)
        padded = []
        for seq in seqs:
            deficit = max_len - seq.shape[0]
            if deficit > 0:
                # F.pad lists (left, right) pairs starting from the LAST
                # dimension, so leave every trailing dimension untouched and
                # pad only dimension 0, the one max_len was measured on.
                # For 1-D input this reduces to (0, deficit).
                pad_width = (0, 0) * (seq.dim() - 1) + (0, deficit)
                seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
            padded.append(seq)
        return torch.stack(padded)

    @classmethod
    def collate_fn(cls, batch: List[Dict]):
        """Collate samples produced by ``__getitem__`` into batch tensors.

        Text ids become ``int32`` and audio becomes ``float32``; both are
        right-padded to the longest sample in the batch with ``-100``.
        ``task_ids`` is a vector of ones with one entry per sample, and
        ``is_prompts`` stays a plain Python list of booleans.
        """
        text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
        audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
        dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
        is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]

        # -100 marks padded positions in both streams; how it is interpreted
        # is up to the downstream consumer (not visible in this module).
        text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
        audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
        task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)

        return {
            "text_tokens": text_padded,
            "audio_tokens": audio_padded,
            "task_ids": task_ids,
            "dataset_ids": dataset_ids,
            "is_prompts": is_prompts,
        }
|
|
|
|
|
|
|
|
class BatchProcessor:
    """Callable adapter between collated batches and the feature packer.

    Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
    the minicpm-audio mechanics: tensors are moved to the target device and
    handed to the packer in one call.
    """

    # Batch entries that are tensors and must be moved to ``self.device``,
    # in the order they are transferred.
    _TENSOR_KEYS = ("audio_tokens", "text_tokens", "task_ids", "dataset_ids")

    def __init__(
        self,
        *,
        config: VoxCPMConfig,
        audio_vae: AudioVAE,
        dataset_cnt: int,
        device: torch.device,
    ):
        """Build the packer from the model config and remember the device.

        Args:
            config: Supplies ``max_length``, ``patch_size`` and ``feat_dim``.
            audio_vae: Passed through to the packer.
            dataset_cnt: Number of datasets, passed through to the packer.
            device: Device the batch tensors are moved to before packing.
        """
        self.device = device
        self.dataset_cnt = dataset_cnt

        packer_options = dict(
            dataset_cnt=dataset_cnt,
            max_len=config.max_length,
            patch_size=config.patch_size,
            feat_dim=config.feat_dim,
            audio_vae=audio_vae,
        )
        self.packer = AudioFeatureProcessingPacker(**packer_options)

    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move the batch tensors to the device and return the packer output.

        ``batch["is_prompts"]`` is forwarded unchanged (it is not moved to
        the device).
        """
        on_device = {key: batch[key].to(self.device) for key in self._TENSOR_KEYS}
        return self.packer(is_prompts=batch["is_prompts"], **on_device)
|
|
|
|
|
|
|
|
def build_dataloader(
    hf_dataset: Dataset,
    *,
    accelerator,
    batch_size: int,
    num_workers: int,
    drop_last: bool = False,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader:
    """Wrap a HuggingFace dataset and build a dataloader through the accelerator.

    The dataset is wrapped in ``HFVoxCPMDataset`` and batches are assembled
    with ``HFVoxCPMDataset.collate_fn``. The dataloader itself is created by
    ``accelerator.prepare_dataloader``.

    Args:
        hf_dataset: Tokenized dataset to iterate over.
        accelerator: Object exposing ``prepare_dataloader(dataset, **kwargs)``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        drop_last: Whether to drop the final incomplete batch.
        shuffle: Whether to shuffle the samples. Defaults to ``True`` (the
            previously hard-coded behaviour); pass ``False`` for loaders
            that should keep dataset order, such as validation.

    Returns:
        Whatever ``accelerator.prepare_dataloader`` returns.
    """
    torch_dataset = HFVoxCPMDataset(hf_dataset)

    return accelerator.prepare_dataloader(
        torch_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        collate_fn=HFVoxCPMDataset.collate_fn,
        drop_last=drop_last,
    )
|
|
|
|
|
|