# Source: NAVA-Text-to-Video / nava_src/data/dataset_train.py
# Uploaded by prithivMLmods, commit 0ff8d3d ("user --add src", verified)
import json, os, random, torch
import threading
import queue
import time
import copy
from re import T
from PIL import Image, ImageFile
from collections import defaultdict
import numpy as np
import io
from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
import traceback
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torchvision import transforms
from dataclasses import dataclass
from typing import Iterator, Tuple, Optional, List, Dict
from functools import partial
import re
import math
from collections import deque
from nava_src.utils.common import compute_tgt_ratio
import torch.nn.functional as F
import torch.distributed as dist
# Let PIL load truncated / partially written image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class PerfLogger:
    """Thread-safe timing aggregator for diagnosing data-loading slowdowns (e.g. multi-node training).

    ``record(key, seconds)`` accumulates count / total / max / min per key.
    ``maybe_print()`` prints a summary at most once every ``print_interval``
    seconds (or immediately with ``force=True``) and then resets the stats.
    """

    def __init__(self, worker_id, print_interval=60, tag="DATA_PERF"):
        self.worker_id = worker_id
        self.print_interval = print_interval
        self.tag = tag
        self._lock = threading.Lock()
        self._stats = defaultdict(lambda: {"count": 0, "total": 0.0, "max": 0.0, "min": float("inf")})
        self._last_print = time.monotonic()

    def record(self, key, elapsed):
        """Add one observation of ``elapsed`` seconds under ``key``."""
        with self._lock:
            s = self._stats[key]
            s["count"] += 1
            s["total"] += elapsed
            s["max"] = max(s["max"], elapsed)
            s["min"] = min(s["min"], elapsed)

    def maybe_print(self, force=False):
        """Print and reset the accumulated stats if the interval has elapsed (or ``force``).

        Only the snapshot/reset happens under the lock. Formatting and the
        (possibly slow) stdout write happen outside it, so IO threads calling
        ``record()`` are never blocked on console output.
        """
        now = time.monotonic()
        with self._lock:
            if not force and (now - self._last_print) < self.print_interval:
                return
            if not self._stats:
                return
            # Shallow copy is enough: cleared keys get fresh inner dicts on the next record().
            snapshot = dict(self._stats)
            self._stats.clear()
            self._last_print = now
        lines = [f"[{self.tag} {self.worker_id}] === Perf Summary ==="]
        for key in sorted(snapshot):
            s = snapshot[key]
            cnt = s["count"]
            avg = s["total"] / cnt if cnt > 0 else 0
            lines.append(
                f" {key}: avg={avg:.4f}s max={s['max']:.4f}s min={s['min']:.4f}s total={s['total']:.2f}s cnt={cnt}"
            )
        print("\n".join(lines), flush=True)
# Caption-template phrases removed by filter_video_descriptions, applied in this order.
# The first four delete a phrase starting with 这段视频 / 这段画面 / 视频 / 画面 and
# running (without crossing ',', ' ' or '。') up to a following 了 / 是 / '.' / ','.
# Raw strings: the originals were f-strings containing "\.", an invalid escape
# sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
_CAPTION_TEMPLATE_PATTERNS = tuple(
    re.compile(p)
    for p in (
        r'这段视频[^, 。]*(了|是|\.|,)',
        r'这段画面[^, 。]*(了|是|\.|,)',
        r'视频[^, 。]*(了|是|\.|,)',
        r'画面[^, 。]*(了|是|\.|,)',
        r'视频中,',
        r'视频中',
        r'画面中,',
        r'画面中',
    )
)


def filter_video_descriptions(text):
    """Remove caption-template boilerplate from a video description (for video generation).

    Drops newlines and the section labels '概述:', '细节:', '背景:', strips
    surrounding whitespace, then deletes the template phrases matched by
    ``_CAPTION_TEMPLATE_PATTERNS``.
    """
    filtered_text = text.replace('\n', '')
    for label in ('概述:', '细节:', '背景:'):
        filtered_text = filtered_text.replace(label, '')
    filtered_text = filtered_text.strip()
    for pattern in _CAPTION_TEMPLATE_PATTERNS:
        filtered_text = pattern.sub("", filtered_text)
    return filtered_text
@dataclass
class DistInfo:
    """Position of the current process in a distributed job."""

    world_rank: int = 0  # index of this process
    world_size: int = 1  # total number of processes
class _LazyJsonlSource:
"""
惰性按行读取一个JSONL文件。
支持:从指定偏移跳过、is_cycle(读到末尾后循环)、skip_ratio(抽稀)、返回坏样本数。
结果:(sample_dict, item_id, nb_bad)
"""
def __init__(
self,
path_list,
start_skip=0,
is_cycle=True,
skip_ratio=0.0,
shard_id=0,
num_shards=1,
file_idx=0,
data_file_divisor=1,
):
self.path_list = path_list
self.start_skip = max(0, start_skip) # 至少跳过0行
self.is_cycle = is_cycle
self.skip_ratio = skip_ratio
self._opened_once = False
self._line_id = 0
self.data_file_divisor = data_file_divisor
self.current_part_id = shard_id % self.data_file_divisor
assert num_shards % self.data_file_divisor == 0
self.num_shard_per_part = num_shards // self.data_file_divisor
self.shard_id_in_part = shard_id // self.data_file_divisor
while file_idx % self.data_file_divisor != self.current_part_id:
file_idx += 1
# 当前文件下标以及当前句柄
self._file_idx = file_idx % len(self.path_list)
self._f = None
def _open(self):
# 先关闭旧文件
if self._f is not None:
self._f.close()
path = self.path_list[self._file_idx]
# 只读模式打开
self._f = open(path, "r", encoding="utf-8")
self._opened_once = True
# 重置行
self._line_id = 0
def _line_iter(self):
if not self._opened_once:
self._open()
while True:
line = self._f.readline()
if line:
yield line
else:
self._f.close()
self._f = None
self._file_idx += self.data_file_divisor
if self._file_idx >= len(self.path_list):
if not self.is_cycle:
return
self._file_idx = self.current_part_id
self._open()
def __iter__(self) -> Iterator[Tuple[dict, int, int]]:
# 记录跳过的有效样本
skipped = 0
sample_idx = 0 # 统计“所有合法样本”的全局计数,不区分 shard
# 跳过start_skip行,记录坏行
# TODO: @songyi05: 暂时不支持断点续训,等待后续修复
line_iter = self._line_iter()
while skipped < self.start_skip:
try:
line = next(line_iter)
except StopIteration:
return
self._line_id += 1
skipped += 1
sample_idx += 1
# 主循环
for line in line_iter:
if not line:
if not self.is_cycle:
return
continue
self._line_id += 1
if (sample_idx % self.num_shard_per_part) != self.shard_id_in_part:
sample_idx += 1
continue
# 如果抽稀
if self.skip_ratio > 0 and random.random() < self.skip_ratio:
sample_idx += 1
continue
try:
obj = json.loads(line)
except Exception:
continue
yield obj, self._file_idx, self._line_id
sample_idx += 1
# TODO: 如何加速不同 modal 读取
# 在文生图任务时,设计 prompt template 以对齐推理
# TODO: @hujiahao03 后续解耦多模态数据集读取
# TODO: @hujiahao03 未来测试使用多模态 adaptive layernorm 处理不同模态 token; 实现困难
class AudioVideoDataset(IterableDataset):
"""
每个样本:可能只有文本,也可能有 (caption, image),按 config.caption_first_ratio 决定顺序。
jsonl 格式:{"text": "..."} 或 {"text": "...", "image": "path/to.jpg"}
优化版:惰性读取 + 严格按源加权 + 逐源历史
"""
def __init__(
self,
jsonl_or_src_list,
batch_size=4, # NOTE: 仅在宽高比分桶时有效
io_workers=16, # <--- IO线程数
queue_size=32, # 队列深度,决定了能抗多久的抖动 (50个batch)
# 多数据源参数
src_id2ratios=None,
modal_prob={"text_to_audio": 1.0, "text_to_video": 0.0, "text_to_image": 0.0, "text_to_av": 0.0},
transform_config=None,
# 流式数据集参数
workers2history_dict=None,
num_shards=1,
dist_info=None, # 计算全局worker id
random_src=True, # True随机选,False轮询
upper_count=float("inf"),
is_cycle=True,
skip_ratio=0,
is_json=True,
audio_vae_server=None,
image_vae_server=None,
video_vae_server=None,
use_aspect_ratio_buckets=False, # 宽高比分桶
use_length_buckets=False, # 按长度分桶(audio/video/av)
num_length_buckets=10, # 长度分桶数量
is_packing=False, # seq packing
min_audio_duration=0.0,
max_audio_duration=15.0,
tgt_audio_duration=-1,
video_min_frames=1,
video_max_frames=5,
video_tgt_frames=-1,
video_fps=8,
add_spk_emb=False,
spk_emb_prob=0.9,
use_speech_special_token=False,
data_file_divisor=1,
split_wav_mode=False,
audio_tokens_per_sec=31.25,
enable_perf_log=False,
enable_ddp_bucket_sync=False, # DDP分桶同步:所有rank每步从同一个bucket取数据
):
"""
jsonl_or_src_list: json file or list of json file regex prob(0.2) and modal
src_id2ratios: 配比 {"regex": [prob(0.2), modal]}
"""
super().__init__()
self.batch_size = batch_size
# 处理多数据源或单数据源输入
if isinstance(jsonl_or_src_list, list):
# 多数据源模式
self.src_id2pathes = defaultdict(list)
# TODO:@songyi05: 待优化,目前使用naive for循环 原始字符串匹配
for tp in jsonl_or_src_list:
if len(tp) == 1:
jsonl = tp[0]
for key in src_id2ratios.keys():
if key in jsonl:
self.src_id2pathes[key].append(jsonl)
break
elif len(tp) == 2:
name, jsonl = tp
self.src_id2pathes[name].append(jsonl)
else:
assert False
else:
# 单数据源兼容模式
self.src_id2pathes = {"default": [jsonl_or_src_list]}
self.src_id2ratios = src_id2ratios or {k: 1.0 for k in self.src_id2pathes}
self.src_ids = sorted(
self.src_id2pathes.keys()
) # "81000200000381,81000200000382,....."
# 各个数据源权重
self.ratios = [float(self.src_id2ratios[s][0]) for s in self.src_ids]
self.io_workers = io_workers
self.queue_size = queue_size
self.io_pool = ThreadPoolExecutor(max_workers=self.io_workers)
self._producer_running = False
self.fetch_lock = threading.Lock()
self._thread_local = threading.local()
self._local_iter_count = 0
# 目标模态 batch 比例
# 示例:支持三种模态
self.modality_configs = {
"text_to_audio": {
"weight": modal_prob.get("text_to_audio", 1.0),
},
"text_to_video": {
"weight": modal_prob.get("text_to_video", 0.0),
},
"text_to_image": {
"weight": modal_prob.get("text_to_image", 0.0),
},
"text_to_av": {
"weight": modal_prob.get("text_to_av", 0.0)
},
}
# self.io_workers = io_workers
# self.queue_size = queue_size
# === 为每个模态创建一个独立的 Queue ===
self.modality_queues = {}
self.producer_threads = [] # 存放线程句柄
for m in self.modality_configs.keys():
# 过滤掉权重为0的模态
if self.modality_configs[m]['weight'] <= 0:
continue
# 文本队列容量可以大一点
if m == "text_only":
limit = self.queue_size * self.batch_size * 5
else:
limit = self.queue_size * self.batch_size
self.modality_queues[m] = queue.Queue(maxsize=limit)
self.modality_mappers = {
"text_to_audio": 0,
"text_to_video": 1,
"text_to_image": 2,
"text_to_av": 3,
}
# t2i buffer
self.t2a_buffer = deque()
self.t2v_buffer = deque()
self.t2i_buffer = deque()
self.t2av_buffer = deque()
self.modality_buffers_mapper = {
"text_to_audio": self.t2a_buffer,
"text_to_video": self.t2v_buffer,
"text_to_image": self.t2i_buffer,
"text_to_av": self.t2av_buffer,
}
# 宽高比分桶
self.use_aspect_ratio_buckets = use_aspect_ratio_buckets
if self.use_aspect_ratio_buckets:
self.aspect_ratio_buckets_aud_gen = defaultdict(list)
self.aspect_ratio_buckets_vid_gen = defaultdict(list)
self.aspect_ratio_buckets_img_gen = defaultdict(list)
self.aspect_ratio_buckets_av_gen = defaultdict(list)
self.aspect_ratio_buckets_mapper = {
"text_to_audio": self.aspect_ratio_buckets_aud_gen,
"text_to_video": self.aspect_ratio_buckets_vid_gen,
"text_to_image": self.aspect_ratio_buckets_img_gen,
"text_to_av": self.aspect_ratio_buckets_av_gen,
}
self.is_packing = is_packing
# 长度分桶(audio/video/av 按 latent 长度等距分为 N 个桶)
self.use_length_buckets = use_length_buckets
self.num_length_buckets = num_length_buckets
# DDP 分桶同步:需要 use_length_buckets 一起开启
self.enable_ddp_bucket_sync = enable_ddp_bucket_sync and self.use_length_buckets
if self.enable_ddp_bucket_sync:
self._bucket_sync_buf = torch.tensor([-1], dtype=torch.long)
# 每个桶的容量上限:桶满时停止从 Queue 拉数据,让 Queue 背压阻塞 producer
self._per_bucket_cap = self.queue_size * self.batch_size if self.use_length_buckets else 0
# 新增多worker参数
self.num_shards = num_shards
self.dist_info = dist_info or DistInfo(0, 1)
self.random_src = random_src # bool
self.upper_count = upper_count
self.is_cycle = is_cycle # bool
self.skip_ratio = skip_ratio # float
self.is_json = is_json
self.worker_id = None
self.data_file_divisor = data_file_divisor
self.audio_vae_server = audio_vae_server
self.image_vae_server = image_vae_server
self.video_vae_server = video_vae_server
self.min_audio_duration = min_audio_duration
self.max_audio_duration = max_audio_duration
self.tgt_audio_duration = tgt_audio_duration
self.add_spk_emb = add_spk_emb
self.spk_emb_prob = spk_emb_prob
self.use_speech_special_token = use_speech_special_token
self.video_min_frames = video_min_frames
self.video_max_frames = video_max_frames
self.video_tgt_frames = video_tgt_frames
self.video_fps = video_fps
self.split_wav_mode = split_wav_mode
self.audio_tokens_per_sec = audio_tokens_per_sec
print(self.audio_tokens_per_sec)
self.enable_perf_log = enable_perf_log
if self.use_length_buckets:
self.length_buckets_aud_gen = defaultdict(deque)
self.length_buckets_vid_gen = defaultdict(deque)
self.length_buckets_av_gen = defaultdict(deque)
self.length_buckets_mapper = {
"text_to_audio": self.length_buckets_aud_gen,
"text_to_video": self.length_buckets_vid_gen,
"text_to_av": self.length_buckets_av_gen,
}
# 预计算各模态的桶边界(等距划分)
# Audio: latent 长度范围
aud_min_len = math.ceil(self.min_audio_duration * self.audio_tokens_per_sec)
aud_max_len = math.ceil(self.max_audio_duration * self.audio_tokens_per_sec)
self._aud_bucket_boundaries = np.linspace(aud_min_len, aud_max_len, self.num_length_buckets + 1)
# Video / AV: latent 帧数范围,原始帧数->latent帧数: (f-1)//4+1
vid_min_latent = (self.video_min_frames - 1) // 4 + 1
vid_max_latent = (self.video_tgt_frames - 1) // 4 + 1
self._vid_bucket_boundaries = np.linspace(vid_min_latent, vid_max_latent, self.num_length_buckets + 1)
# 历史记录:[num_workers, num_sources]
if workers2history_dict is None:
self.workers2history_dict = {
wid: {src_id: {"file_idx": 0, "line_id": 0} for src_id in self.src_ids}
for wid in range(self.num_shards)
}
elif isinstance(workers2history_dict, list):
self.workers2history_dict = {
wid: {src_id: {"file_idx": 0, "line_id": 0} for src_id in self.src_ids}
for wid in range(self.num_shards)
}
for wid in range(self.num_shards):
row = workers2history_dict[
wid
] # 一个 worker 的所有源数据,例如 [f1,l1,f2,l2,..]
for si, src_id in enumerate(self.src_ids):
f = row[2 * si]
l = row[2 * si + 1]
self.workers2history_dict[wid][src_id]["file_idx"] = f
self.workers2history_dict[wid][src_id]["line_id"] = l
elif isinstance(workers2history_dict, int):
per = workers2history_dict // max(1, self.num_shards)
self.workers2history_dict = {
wid: {
src_id: {"file_idx": 0, "line_id": per} for src_id in self.src_ids
}
for wid in range(self.num_shards)
}
else:
self.workers2history_dict = workers2history_dict
# 对齐当前 src_ids,给“新增的数据源”补上默认历史
for wid in range(self.num_shards):
if wid not in self.workers2history_dict:
self.workers2history_dict[wid] = {}
for src_id in self.src_ids:
if src_id not in self.workers2history_dict[wid]:
self.workers2history_dict[wid][src_id] = {
"file_idx": 0,
"line_id": 0,
}
self._sources = []
self._inited = False
self._perf = None
# print("[DEBUG __init__] src_id2pathes =", self.src_id2pathes)
# print("[DEBUG __init__] src_ids =", getattr(self, "src_ids", None))
def __len__(self):
# 惰性读取模式下,返回upper_count作为长度估计
return int(self.upper_count) * len(self.src_ids)
def set_distributed_info(self, dist_info: DistInfo):
"""设置分布式训练信息"""
self.dist_info = dist_info
return self
def _get_tokenizer(self):
"""
获取当前线程专属的 Tokenizer 副本。
如果没有,就从 self.tok 深拷贝一份。
"""
# 检查当前线程是否已经有了私有副本
if not hasattr(self._thread_local, 'tokenizer'):
self._thread_local.tokenizer = copy.deepcopy(self.tok)
return self._thread_local.tokenizer
# 获取worker的id
def get_current_worker(self):
"""获取当前worker ID,支持分布式训练"""
if self.dist_info is None:
process_index = 0
num_processes = 1
else:
process_index, num_processes = (
self.dist_info.world_rank,
self.dist_info.world_size,
)
# 获取当前的worker信息
worker_info = torch.utils.data.get_worker_info()
num_processes = max(1, num_processes)
if worker_info is not None:
# 进程数*每个进程的子进程=总进程数
assert (
num_processes * worker_info.num_workers == self.num_shards
), f"{num_processes} {worker_info.num_workers} {self.num_shards}"
return process_index * worker_info.num_workers + worker_info.id
# 单进程
else:
assert num_processes == self.num_shards, f"process: {num_processes} shards: {self.num_shards}"
return process_index
# 只进行一次
def _init_sources_once(self):
if self._inited:
return
self.worker_id = self.get_current_worker()
self._perf = PerfLogger(worker_id=f"R{self.dist_info.world_rank}-W{self.worker_id}") if self.enable_perf_log else None
# 设置随机种子
base_seed = 20251111 + self.worker_id * 97
random.seed(base_seed)
torch.manual_seed(base_seed)
np.random.seed(base_seed % (2**32 - 1))
# 为每个数据源创建一个LazyJsonlSource对象,并跳过历史位置
for src_id in self.src_ids:
modal = self.src_id2ratios[src_id][1]
file_idx = self.workers2history_dict[self.worker_id][src_id]["file_idx"]
start_skip = self.workers2history_dict[self.worker_id][src_id]["line_id"]
src = _LazyJsonlSource(
path_list=self.src_id2pathes[src_id],
start_skip=start_skip,
is_cycle=self.is_cycle,
skip_ratio=self.skip_ratio, # 按比例随机跳过某系行
shard_id=self.worker_id,
num_shards=self.num_shards,
file_idx=file_idx,
data_file_divisor=self.data_file_divisor,
)
self._sources.append((src_id, iter(src), modal))
# self._sources 是 [(src_id, iter, modal), ...]
# 构建模态 -> 数据源列表 的映射
self._sources_by_modal = defaultdict(list)
for item in self._sources:
src_id, src_iter, modal = item
self._sources_by_modal[modal].append(item)
# 同时记录每个模态对应的采样权重(如果 self.ratios 存在)
if hasattr(self, "ratios") and self.ratios:
# 假设 self.ratios 与 self._sources 顺序一一对应
self._modal_ratios = defaultdict(list)
self._modal_sources_list = defaultdict(list) # 保持顺序,用于对齐 ratios
for item, weight in zip(self._sources, self.ratios):
modal = item[2]
self._modal_sources_list[modal].append(item)
self._modal_ratios[modal].append(weight)
else:
# 无权重,等概率
self._modal_ratios = None
self._inited = True
def _fetch_raw_jsons(self, needed_modality, count):
items = []
# === 加锁:防止多个 Producer 线程同时操作 generator ===
with self.fetch_lock:
for _ in range(count):
try:
# 1. 检查所需模态是否有可用源,如果没有直接break
# [WARN]: 如果数据模式没有采用cyclic,可能会陷入死锁
if needed_modality not in self._sources_by_modal or not self._sources_by_modal[needed_modality]:
break
# 2. 采样:选择一个 Source (src_id, iterator, modal)
if self.random_src:
# 优先使用带权重的采样 (如果在 __init__ 里初始化了 ratios)
if hasattr(self, '_modal_ratios') and self._modal_ratios and needed_modality in self._modal_ratios:
candidates = self._modal_sources_list[needed_modality]
cand_weights = self._modal_ratios[needed_modality]
# random.choices 返回的是 list,取 [0]
# 注意:这里 candidates 和 weights 的长度必须对齐,由 __init__ 保证
src_id, src_iter, modal = random.choices(candidates, weights=cand_weights, k=1)[0]
else:
# 无权重,纯随机
modal_sources = self._sources_by_modal[needed_modality]
src_id, src_iter, modal = random.choice(modal_sources)
else:
# 轮询模式
modal_sources = self._sources_by_modal[needed_modality]
src_id, src_iter, modal = random.choice(modal_sources)
# 3. 执行 next(),获取数据
# 如果 generator 耗尽,会在这里抛出 StopIteration
_t0_fetch = time.monotonic()
sample, file_idx, line_id = next(src_iter)
if self._perf:
self._perf.record("fetch_jsonl_next", time.monotonic() - _t0_fetch)
sample["src_name"] = src_id
# 4. 更新 Worker 的历史记录 (断点续训用)
self.workers2history_dict[self.worker_id][src_id]["file_idx"] = file_idx
self.workers2history_dict[self.worker_id][src_id]["line_id"] = line_id
# 5. 构造 data_state (扁平化的状态向量)
flat_hist = []
for sid in self.src_ids:
# 必须按 src_ids 的固定顺序拼接,否则 resume 会错位
pos = self.workers2history_dict[self.worker_id].get(sid, {"file_idx": 0, "line_id": 0})
flat_hist.append(pos["file_idx"])
flat_hist.append(pos["line_id"])
# data_state 格式: [worker_id, file1, line1, file2, line2, ...]
data_state = [self.worker_id] + flat_hist
items.append((sample, needed_modality, data_state))
except StopIteration:
# === 源耗尽清理逻辑 ===
# 1. 从主列表移除
self._sources = [(ii, it, mm) for (ii, it, mm) in self._sources if not (ii == src_id and it is src_iter)]
# 2. 从模态映射表移除
if hasattr(self, '_sources_by_modal'):
self._sources_by_modal[modal] = [x for x in self._sources_by_modal[modal] if not (x[0] == src_id and x[1] is src_iter)]
# 这里我们假设模态永远不会耗尽,所以不做多余处理
pass
# 当前这次 count 没取到数据,continue 继续尝试取下一个源
continue
except Exception as e:
print(f"[Fetch Error in Worker {self.worker_id}] {e}")
traceback.print_exc()
continue
return items
def _process_and_dispatch(self, args, modality, target_q):
"""
[新增函数] 在子线程中运行:生成数据 -> 校验 -> 直接 Put 进队列
这样主线程就不需要 acquire GIL 来做搬运工了。
"""
sample_obj, _, data_state = args
max_size = target_q.maxsize
try:
# 1. 执行繁重的 IO/CPU 生成工作
# 注意:这里我们强制把 generator 转为 list,确保在线程内执行完毕
results = list(self._build_out(sample_obj, modality))
# 2. 直接在线程内分发数据 (Distributed Putting)
for s in results:
# --- 基础校验逻辑 (从原代码迁移过来) ---
if modality == "text_to_audio" and s.get("audio_latents") is None: continue
if modality == "text_to_video" and (s["video_latents"] is None or s["video_latents"].shape[-1] == 1):
print(f"[Warn] {modality} return bad data, skip.")
continue
if modality == "text_to_image" and (s["image_latents"] is None or s["image_latents"].shape[-1] == 1):
print(f"[Warn] {modality} return bad data, skip.")
continue
if modality == "text_to_av" and (s["audio_latents"] is None or s["video_latents"] is None or s["video_latents"].shape[-1] == 1):
print(f"[Warn] {modality} return bad data, skip.")
continue
# 附加状态信息
s["data_state"] = torch.tensor(data_state, dtype=torch.long)
target_q.put(s)
except Exception as e:
pass
def _single_modality_producer_loop(self, modality):
"""
单模态生产者线程:只负责生产指定 modality 的数据。
"""
print(f"[DEBUG] Worker {self.worker_id}: Producer for [{modality}] STARTED")
try:
# 1. 错峰启动
time.sleep(random.uniform(0.1, 2.0))
self._init_sources_once()
target_queue = self.modality_queues[modality]
BASE_WORKERS = self.io_workers
MAX_WINDOW = int(BASE_WORKERS * 2.0)
flying_futures = []
while True:
# ==========================
# A. 优先处理已完成的任务,释放出队列压力
# ==========================
done, not_done = wait(flying_futures, timeout=0.01, return_when=FIRST_COMPLETED)
for f in done:
results, data_state, err = f.result()
if err is not None:
print(f"Error in future: {err}")
continue
for s in results:
# 校验逻辑
if modality == "text_to_audio" and s.get("audio_latents") is None: continue
if modality == "text_to_video" and (s.get("video_latents") is None or s["video_latents"].shape[-1] == 1): continue
if modality == "text_to_image" and (s.get("image_latents") is None or s["image_latents"].shape[-1] == 1): continue
if modality == "text_to_av" and (s.get("audio_latents") is None or s.get("video_latents") is None or s["video_latents"].shape[-1] == 1): continue
s["data_state"] = torch.tensor(data_state, dtype=torch.long)
# 【核心修复】:阻塞转移到模态的主线程,彻底释放 io_pool 资源
_t0_qput = time.monotonic()
target_queue.put(s)
_elapsed_qput = time.monotonic() - _t0_qput
if self._perf and _elapsed_qput > 0.01:
self._perf.record(f"producer_q_put_wait_{modality}", _elapsed_qput)
flying_futures = list(not_done)
# ==========================
# B. 检查队列是否满 (背压)
# ==========================
curr_size = target_queue.qsize()
max_size = target_queue.maxsize
pending_count = len(flying_futures)
if curr_size + pending_count >= max_size:
time.sleep(0.02)
continue
# ==========================
# C. 填补窗口
# ==========================
raw_items = self._fetch_raw_jsons(modality, count=1)
if not raw_items:
# 【核心修复2】: 耗尽时,必须等待剩余任务完成并入队,否则丢失尾部数据
for f in as_completed(flying_futures):
results, data_state, err = f.result()
if err is not None: continue
for s in results:
if s.get("captions") is None: continue
s["data_state"] = torch.tensor(data_state, dtype=torch.long)
target_queue.put(s)
break # 跳出 while,当前模态读取完毕
item = raw_items[0]
# 注意:这里改用 _process_item_concurrently
fut = self.io_pool.submit(self._process_item_concurrently, item)
flying_futures.append(fut)
if self._perf:
self._perf.maybe_print()
except Exception as e:
print(f"!!! FATAL ERROR IN PRODUCER [{modality}] (Worker {self.worker_id}): {e}")
traceback.print_exc()
finally:
# 【核心修复3】:正常或异常退出时,必须发 None 信号,否则消费者 get() 永久死锁
target_queue.put(None)
def _single_modality_producer_loop_old(self, modality):
"""
单模态生产者线程:只负责生产指定 modality 的数据。
"""
print(f"[DEBUG] Worker {self.worker_id}: Producer for [{modality}] STARTED")
try:
# 1. 错峰启动
time.sleep(random.uniform(0.1, 2.0))
self._init_sources_once()
# 2. 是否同步
synchronous = False
target_queue = self.modality_queues[modality]
# 3. 动态窗口参数 (仅针对 heavy IO)
BASE_WORKERS = self.io_workers
# 如果有多个 heavy producer (例如同时有图和视频),需要共享 worker 配额
MAX_WINDOW = int(BASE_WORKERS * 2.0)
flying_futures = []
while True:
# ==========================
# A. 检查队列是否满 (背压)
# ==========================
curr_size = target_queue.qsize()
max_size = target_queue.maxsize
if flying_futures:
flying_futures = [f for f in flying_futures if not f.done()]
pending_count = len(flying_futures)
# 核心逻辑:只有当 (库存 + 在途) < 总容量 时,才允许新生产。
# 这样保证了每一个发出的任务,回来时都有“停车位”,绝对不会卡住。
if curr_size + pending_count >= max_size:
# 此时已经饱和,必须休息。
# 由于 Worker 不会阻塞,它们会很快完成并释放资源。
time.sleep(0.02)
continue
# 2. 填补窗口
raw_items = self._fetch_raw_jsons(modality, count=1)
if not raw_items: break
item = raw_items[0]
# 提交任务
fut = self.io_pool.submit(self._process_and_dispatch, item, modality, target_queue)
flying_futures.append(fut)
if len(flying_futures) > MAX_WINDOW:
done, _ = wait(flying_futures, return_when=FIRST_COMPLETED)
flying_futures = [f for f in flying_futures if not f.done()]
# 3. 收割结果 (Sliding Window)
if not flying_futures:
time.sleep(0.01)
continue
except Exception as e:
print(f"!!! FATAL ERROR IN PRODUCER [{modality}] (Worker {self.worker_id}): {e}")
traceback.print_exc()
target_queue.put(None)
def _process_item_concurrently(self, args):
"""
这个函数会在子线程运行。
必须在这里把 generator 消耗完 (list化),否则 IO 操作可能延迟到主线程迭代时才发生。
"""
sample_obj, modal, data_state = args
try:
# 调用 _build_out (包含 BOS 下载/VAE 请求)
# 关键:使用 list() 强制立即执行生成器里的代码
_t0_build = time.monotonic()
results = list(self._build_out(sample_obj, modal))
if self._perf:
self._perf.record(f"build_out_{modal}", time.monotonic() - _t0_build)
for r in results:
if isinstance(r, dict):
for k in ("audio_latents", "video_latents", "image_latents"):
if k in r and r[k] is not None and hasattr(r[k], "element_size"):
self._perf.record(f"tensor_bytes_{k}", r[k].element_size() * r[k].numel())
return results, data_state, None # None 代表无异常
except Exception as e:
# 捕获异常,避免单个样本搞挂整个线程池
return [], data_state, e
# --- 辅助:根据 latent 长度计算所在桶编号 ---
def _get_length_bucket_id(self, length, boundaries):
"""根据 length 和预计算的 boundaries 返回桶编号 (0 ~ num_length_buckets-1)"""
idx = int(np.searchsorted(boundaries, length, side='right')) - 1
return max(0, min(idx, self.num_length_buckets - 1))
# --- 辅助:分发单个样本到桶 ---
def _distribute_single_item(self, out):
"""分发单个样本到桶。返回 True 已放入,False 表示目标桶已满(调用方应 put back Queue)。"""
# 音/图/视频
data_mode = None
if "audio_latents" in out and out["audio_latents"] is not None:
if "video_latents" in out and out["video_latents"] is not None:
data_mode = "text_to_av"
else:
data_mode = "text_to_audio"
elif "video_latents" in out and out["video_latents"] is not None:
data_mode = "text_to_video"
elif "image_latents" in out and out["image_latents"] is not None:
data_mode = "text_to_image"
if data_mode and data_mode in self.modality_mappers:
if self.use_length_buckets and data_mode in self.length_buckets_mapper:
# 按 latent 长度分桶
if data_mode == "text_to_audio":
seq_len = out["audio_latents"].shape[0]
bucket_id = self._get_length_bucket_id(seq_len, self._aud_bucket_boundaries)
elif data_mode == "text_to_video":
seq_len = out["video_latents"].shape[0]
bucket_id = self._get_length_bucket_id(seq_len, self._vid_bucket_boundaries)
else: # text_to_av — 以视频 latent 长度为主
seq_len = out["video_latents"].shape[0]
bucket_id = self._get_length_bucket_id(seq_len, self._vid_bucket_boundaries)
bucket = self.length_buckets_mapper[data_mode][bucket_id]
if self._per_bucket_cap > 0 and len(bucket) >= self._per_bucket_cap:
return False # 桶满,拒收
bucket.append(out)
elif self.use_aspect_ratio_buckets:
aspect_ratio = str(out.get('aspect_ratio', '1.0'))
if '.' in aspect_ratio and len(aspect_ratio) < 10: pass
self.aspect_ratio_buckets_mapper[data_mode][aspect_ratio].append(out)
else:
self.modality_buffers_mapper[data_mode].append(out)
return True
def _buckets_need_refill(self, mod):
"""检查是否还需要从 Queue 拉数据。
put-back 机制保证单桶不溢出,所以这里只在所有已有桶都满时才停止拉取。
此时 Queue 里全是满桶 item,继续拉也只会 put back,浪费 CPU。
"""
if not self.use_length_buckets or mod not in self.length_buckets_mapper or self._per_bucket_cap <= 0:
return True
mode_buckets = self.length_buckets_mapper[mod]
if not mode_buckets:
return True
has_ready = any(len(b) >= self.batch_size for b in mode_buckets.values())
if not has_ready:
return True # 没有桶能出 batch,必须继续拉
# 所有已有桶(有数据的)都满了 → 停止
non_empty = [b for b in mode_buckets.values() if len(b) > 0]
if non_empty and all(len(b) >= self._per_bucket_cap for b in non_empty):
return False
return True
def _check_buckets_and_pop_packed(self, target_modality):
    """Try to pop one full batch (``self.batch_size`` samples) for ``target_modality``.

    Returns a list of samples, or None if no batch is ready.

    - Length-bucket mode (``use_length_buckets`` and the modality has length
      buckets): pick among buckets holding >= batch_size samples, weighted by
      bucket size, and pop the batch from that bucket.
      * With ``enable_ddp_bucket_sync`` and an initialised process group,
        every call performs exactly one ``dist.broadcast`` and one
        ``dist.all_reduce``. Rank 0 proposes a bucket. Each rank uses it if
        it can fill a batch from it, otherwise it falls back to a local
        weighted choice. A MIN all-reduce makes every rank return None unless
        all ranks can serve a batch. Because of the fallback, ranks are not
        guaranteed to pop from the same bucket.
        NOTE(review): since this issues collectives, it presumably has to be
        called the same number of times on every rank -- confirm with caller.
    - Otherwise: pop ``batch_size`` samples FIFO from the modality's plain
      deque buffer.
    """
    bs = self.batch_size
    # Length-bucket mode: weighted random choice over buckets by sample count.
    if self.use_length_buckets and target_modality in self.length_buckets_mapper:
        mode_buckets = self.length_buckets_mapper[target_modality]
        if self.enable_ddp_bucket_sync and dist.is_initialized():
            # --- DDP-synchronised path ---
            # Protocol: exactly 1 broadcast + 1 all_reduce per call, so ranks cannot hang.
            buf = self._bucket_sync_buf
            if not buf.is_cuda:
                # Moved to the current CUDA device on first use.
                # NOTE(review): assumes a CUDA-capable backend (e.g. NCCL) -- confirm.
                buf = buf.cuda()
                self._bucket_sync_buf = buf
            # Step 1: rank 0 picks a bucket (weighted by size) and broadcasts it (-1 = none ready).
            if dist.get_rank() == 0:
                ready_keys = [k for k, v in mode_buckets.items() if len(v) >= bs]
                if ready_keys:
                    weights = [len(mode_buckets[k]) for k in ready_keys]
                    chosen_key = random.choices(ready_keys, weights=weights, k=1)[0]
                    buf[0] = chosen_key
                else:
                    buf[0] = -1
            dist.broadcast(buf, src=0)
            chosen = buf[0].item()
            # Step 2: each rank decides whether it can serve a batch.
            local_result = None
            if chosen >= 0 and chosen in mode_buckets and len(mode_buckets[chosen]) >= bs:
                local_result = chosen
            else:
                # Fallback: local weighted random choice.
                ready_keys = [k for k, v in mode_buckets.items() if len(v) >= bs]
                if ready_keys:
                    weights = [len(mode_buckets[k]) for k in ready_keys]
                    local_result = random.choices(ready_keys, weights=weights, k=1)[0]
            # Step 3: all_reduce MIN -> serve only if every rank can produce a batch.
            can_serve = torch.tensor([1 if local_result is not None else 0],
                                     dtype=torch.long, device=buf.device)
            dist.all_reduce(can_serve, op=dist.ReduceOp.MIN)
            if can_serve.item() == 0:
                return None
            return [mode_buckets[local_result].popleft() for _ in range(bs)]
        else:
            # --- Unsynchronised path: local weighted random choice by bucket size ---
            ready_keys = [k for k, v in mode_buckets.items() if len(v) >= bs]
            if ready_keys:
                weights = [len(mode_buckets[k]) for k in ready_keys]
                chosen = random.choices(ready_keys, weights=weights, k=1)[0]
                return [mode_buckets[chosen].popleft() for _ in range(bs)]
            return None
    bucket = self.modality_buffers_mapper[target_modality]  # deque
    if len(bucket) < bs:
        return None
    # O(bs), independent of the total buffer length.
    return [bucket.popleft() for _ in range(bs)]
def _monitor_loop(self):
    """Debug monitor thread body; never returns.

    After a 5 s warm-up, every 30 s it prints producer-side queue depths and
    consumer-side bucket/buffer backlogs, then forces a PerfLogger summary
    when perf logging is enabled.
    """
    time.sleep(5)  # give the producers a moment to start
    while True:
        time.sleep(30)
        # Q: queue fill per modality (producer). Buf: bucket backlog (consumer).
        print(self._format_monitor_line())
        if self._perf:
            self._perf.maybe_print(force=True)

def _format_monitor_line(self):
    """Build the ``[MONITOR W-<id>] Q:[...] | Buf:[...]`` status line printed by ``_monitor_loop``."""
    wid = getattr(self, 'worker_id', '?')
    if hasattr(self, "modality_queues"):
        queue_parts = [f"{name}={q.qsize()}" for name, q in self.modality_queues.items()]
    else:
        queue_parts = ["No_Queues"]

    if self.use_length_buckets:
        # Length buckets: total per modality plus the non-empty buckets.
        buf_parts = []
        for name, buckets in self.length_buckets_mapper.items():
            total = sum(len(b) for b in buckets.values())
            detail = ",".join(f"b{k}:{len(v)}" for k, v in sorted(buckets.items()) if len(v) > 0)
            buf_parts.append(f"{name}={total}({detail})")
        # Images are not length-bucketed; they sit in the plain buffer.
        if "text_to_image" in self.modality_buffers_mapper:
            buf_parts.append(f"text_to_image={len(self.modality_buffers_mapper['text_to_image'])}")
    elif self.use_aspect_ratio_buckets:
        # Aspect-ratio buckets: sum of all ratio sub-buckets per modality.
        buf_parts = [
            f"{name}={sum(len(b) for b in buckets.values())}"
            for name, buckets in self.aspect_ratio_buckets_mapper.items()
        ]
    else:
        buf_parts = [f"{name}={len(b)}" for name, b in self.modality_buffers_mapper.items()]

    queue_str = " | ".join(queue_parts)
    buf_str = " | ".join(buf_parts)
    return f"[MONITOR W-{wid}] Q:[{queue_str}] | Buf:[{buf_str}]"
def _refill_buffer_nonblocking(self, mod, cap=256):
"""非阻塞补货:有多少拿多少,绝不等 Queue。桶满的 item 放回 Queue。"""
if not self._buckets_need_refill(mod):
return 0
q = self.modality_queues[mod]
moved = 0
putback_streak = 0
for _ in range(cap):
if not self._buckets_need_refill(mod):
break
try:
item = q.get_nowait()
except queue.Empty:
break
if item is None:
return -1 # 终止信号
if not self._distribute_single_item(item):
q.put(item) # 桶满,放回 Queue
putback_streak += 1
if putback_streak >= 3:
break # Queue 里大多是满桶 item,停止白转
else:
putback_streak = 0
moved += 1
return moved
def _refill_buffer_to_target(self, mod, target, cap):
"""把 mod 的 buffer 补到 target(最多搬 cap 个)。桶满的 item 放回 Queue。"""
q = self.modality_queues[mod]
if not self._buckets_need_refill(mod):
return 0
# 根据分桶模式选择正确的水位统计
if self.use_length_buckets and mod in self.length_buckets_mapper:
mode_buckets = self.length_buckets_mapper[mod]
cur = sum(len(b) for b in mode_buckets.values())
has_ready_bucket = any(len(b) >= self.batch_size for b in mode_buckets.values())
if cur >= target and has_ready_bucket:
return 0
else:
cur = len(self.modality_buffers_mapper[mod])
if cur >= target:
return 0
need = max(1, min(target - cur, cap))
moved = 0
# 1) 先阻塞拿 1 个,保证推进
_t0_bget = time.monotonic()
item = q.get()
if self._perf:
self._perf.record(f"consumer_q_get_blocking_{mod}", time.monotonic() - _t0_bget)
if item is None:
return -1 # 终止信号
if not self._distribute_single_item(item):
q.put(item) # 桶满,放回
else:
moved += 1
need -= 1
# 2) 再非阻塞补齐,水位够了就停
putback_streak = 0
for _ in range(need):
if not self._buckets_need_refill(mod):
break
try:
item = q.get_nowait()
except queue.Empty:
break
if item is None:
return -1
if not self._distribute_single_item(item):
q.put(item)
putback_streak += 1
if putback_streak >= 3:
break
else:
putback_streak = 0
moved += 1
return moved
def __iter__(self):
    """Stream batches: each step samples a modality, then pops a ready batch from its buckets.

    Background producer threads (one per modality queue, plus a monitor thread)
    are started only on the first call. Iteration stops after roughly
    ``self.upper_count`` samples (counted as ``batch_size`` per yielded batch),
    or when a modality queue delivers the termination sentinel.
    """
    # 1. Start the background threads only once.
    if not self._producer_running:
        self._producer_running = True
        # One producer thread per modality that has a queue.
        for modality in self.modality_queues.keys():
            t = threading.Thread(
                target=self._single_modality_producer_loop,
                args=(modality,),
                daemon=True,
            )
            t.start()
            self.producer_threads.append(t)
        # Monitor thread (debugging aid).
        t_mon = threading.Thread(target=self._monitor_loop, daemon=True)
        t_mon.start()

    produced = 0
    batch_size = self.batch_size
    target = batch_size * 16
    cap = 256
    modalities = sorted(self.modality_queues.keys())
    weights = [self.modality_configs[m]["weight"] for m in modalities]

    # 2./3. Per-worker RNG for modality selection. The seed ignores the rank, so
    # local worker k on every rank draws the same modality sequence; it includes
    # the iteration count so each new __iter__ call gets a fresh sequence.
    worker_info = torch.utils.data.get_worker_info()
    local_worker_id = worker_info.id if worker_info is not None else 0
    current_seed = 2026 + local_worker_id * 1000 + self._local_iter_count
    modality_rng = random.Random(current_seed)
    self._local_iter_count += 1

    # Main loop: only consume from buffers/queues; production happens elsewhere.
    _iter_batch_count = 0
    _iter_t0 = time.monotonic()
    stop_after_yield = False
    while produced < self.upper_count:
        _batch_assemble_t0 = time.monotonic()
        target_modality = modality_rng.choices(modalities, weights=weights, k=1)[0]
        target_q = self.modality_queues[target_modality]
        batch_data = None
        while batch_data is None:
            # 1) Buffer first: if a batch is ready, emit it without waiting on the queue.
            _t0_buf_hit = time.monotonic()
            batch_data = self._check_buckets_and_pop_packed(target_modality)
            if batch_data:
                if self._perf:
                    self._perf.record("batch_from_buffer_hit", time.monotonic() - _t0_buf_hit)
                # Opportunistic non-blocking refill from whatever is already queued.
                _t0_refill = time.monotonic()
                moved = self._refill_buffer_nonblocking(target_modality, cap=cap)
                if self._perf:
                    self._perf.record(f"refill_buffer_{target_modality}", time.monotonic() - _t0_refill)
                if moved == -1:
                    # Termination sentinel seen: emit this batch, then stop.
                    stop_after_yield = True
                break

            # 2) Buffer could not serve a batch: refill from the queue (blocking).
            _t0_refill2 = time.monotonic()
            moved = self._refill_buffer_to_target(target_modality, target=target, cap=cap)
            if self._perf:
                self._perf.record(f"refill_buffer_{target_modality}", time.monotonic() - _t0_refill2)
            if moved == -1:
                return

            # 3) Retry the pop after refilling. `produced` is counted once, below.
            batch_data = self._check_buckets_and_pop_packed(target_modality)
            if batch_data:
                break

            # 4) Fallback: still no batch (queue too empty / producers behind).
            # If the buckets are already full enough (e.g. a DDP bucket sync
            # retry), don't add data; wait for the other ranks instead.
            if not self._buckets_need_refill(target_modality):
                time.sleep(0.01)
                continue
            _t0_fallback = time.monotonic()
            item = target_q.get()
            if self._perf:
                self._perf.record(f"consumer_fallback_get_{target_modality}", time.monotonic() - _t0_fallback)
            if item is None:
                return
            if not self._distribute_single_item(item):
                target_q.put(item)  # bucket full, hand it back

        if self._perf:
            self._perf.record("batch_assemble_total", time.monotonic() - _batch_assemble_t0)

        # 5) Account for and yield the batch.
        produced += batch_size
        _iter_batch_count += 1
        # Report iteration throughput every 100 batches.
        if self._perf and _iter_batch_count % 100 == 0:
            _iter_elapsed = time.monotonic() - _iter_t0
            print(f"[DATA_PERF {self._perf.worker_id}] iter throughput: "
                  f"{_iter_batch_count} batches in {_iter_elapsed:.1f}s "
                  f"({_iter_batch_count / _iter_elapsed:.2f} batch/s)", flush=True)
            self._perf.maybe_print()
        yield batch_data
        if stop_after_yield:
            return
def _build_out_aud(self, obj: dict):
    """Build text-to-audio samples from one metadata record.

    The data type is the last ``_``-separated token of ``obj["src_name"]``
    (e.g. ``"tts"``). Audio entries come from ``obj["audio_splits_info"]``,
    paired one-to-one with ``obj["text_list"]``. Otherwise they come from
    ``obj["video_info"]``, paired with the first text entry only.

    Entries are skipped when they are flagged invalid, have no ``duration``, or
    fall outside ``[min_audio_duration, max_audio_duration]``.

    Returns:
        A list of dicts with keys ``captions``, ``audio_latents`` and
        ``spk_embs``. Returns ``[]`` if anything in the record raises.

    Raises:
        KeyError: if ``obj`` has no ``src_name`` (read outside the guarded block).
    """
    samples = []
    src_name = obj["src_name"]
    data_type = src_name.split("_")[-1]
    try:
        if "audio_splits_info" in obj:
            audio_splits_info = obj["audio_splits_info"]
            text_list = obj["text_list"]
        elif "video_info" in obj:
            audio_splits_info = obj["video_info"]
            text_list = obj["text_list"][:1]
        else:
            raise KeyError("Neither audio_splits_info nor video_info found in obj.")
        assert len(text_list) == len(audio_splits_info), (
            f"Length of text_list ({len(text_list)}) must equal "
            f"length of audio_splits_info ({len(audio_splits_info)})."
        )
        for sample_audio_info, sample_text_info in zip(audio_splits_info, text_list):
            # Caption / speech content for this entry.
            text = sample_text_info["text"]
            speech_starts = sample_text_info.get("speech_start", None)
            speech_ends = sample_text_info.get("speech_end", None)
            is_valid = sample_text_info.get("is_valid", True)

            sample_spk_embs = []
            data_path = sample_audio_info["data_path"]
            duration = sample_audio_info.get("duration", None)
            # The audio entry's own flag, when present, overrides the text entry's.
            is_valid = sample_audio_info.get("is_valid", is_valid)
            if not is_valid or duration is None:
                # A missing duration cannot be range-checked. Skip this entry
                # rather than let the comparison raise and drop the whole record.
                continue
            if duration > self.max_audio_duration or duration < self.min_audio_duration:
                continue

            if data_type == "tts":
                # TTS text: wrap it in speech markers unless it already has them.
                if "<S>" not in text:
                    text = "<S>" + text + "<E>"
            elif speech_starts and speech_ends:
                num_speech_starts = text.count("<S>")
                num_speech_ends = text.count("<E>")
                assert num_speech_starts == num_speech_ends and num_speech_ends == len(speech_starts), \
                    f"Error: starts {num_speech_starts} not match with ends {num_speech_ends}, text: {text}"
            if self.add_spk_emb and "<S>" in text and "<E>" in text:
                # Insert <extra_id_2> after each speech start; presumably the slot
                # the speaker embedding is bound to downstream (confirm in the model).
                text = text.replace("<S>", "<S><extra_id_2>")
            if self.use_speech_special_token:
                text = text.replace("<S>", "<extra_id_0>")
                text = text.replace("<E>", "<extra_id_1>")

            # Encode the full clip.
            assert self.audio_vae_server, "audio server is not ready."
            query = {
                "data_path": data_path,
                "use_spk_emb": (self.add_spk_emb and random.random() < self.spk_emb_prob) and data_type == "tts",
            }
            _t0_aud_enc = time.monotonic()
            result = self.audio_vae_server.encode(
                query, rank=self.dist_info.world_rank,
            ).latent_dist.sample()
            if self._perf:
                self._perf.record("vae_encode_audio", time.monotonic() - _t0_aud_enc)
            audio_latents = result["audio_latents"][0].permute(1, 0)  # [L, 20]
            if self.add_spk_emb and data_type == "tts":
                sample_spk_embs.append(result["spk_embs"])

            # One speaker-embedding encode per annotated speech span.
            if self.add_spk_emb and speech_starts and speech_ends:
                for speech_start, speech_end in zip(speech_starts, speech_ends):
                    query.update({
                        "start": speech_start,
                        "duration": speech_end - speech_start,
                        "use_spk_emb": self.add_spk_emb and random.random() < self.spk_emb_prob
                    })
                    _t0_spk_enc = time.monotonic()
                    result = self.audio_vae_server.encode(
                        query, rank=self.dist_info.world_rank,
                    ).latent_dist.sample()
                    if self._perf:
                        self._perf.record("vae_encode_audio_spk", time.monotonic() - _t0_spk_enc)
                    sample_spk_embs.append(result["spk_embs"])

            samples.append({
                "captions": text,
                "audio_latents": audio_latents,
                "spk_embs": sample_spk_embs,
            })
        return samples
    except Exception as e:
        print(f"Error in building audio sample: {e}")
        traceback.print_exc()
        return []
def _build_out_img(self, obj: dict):
    """Encode every valid image entry of ``obj`` into a text-to-image sample.

    ``obj["image_info"]`` and ``obj["text_list"]`` are paired by position. When
    aspect-ratio bucketing is enabled, the bucket string from
    ``compute_tgt_ratio`` is flipped if its orientation disagrees with the
    encoded latent's height/width.

    Returns:
        A list of dicts with keys ``captions``, ``image_latents`` and
        ``aspect_ratio``. On any error, the error and the record are printed
        and ``[]`` is returned.
    """
    results = []
    aspect_ratio = "0/0"
    try:
        infos = obj["image_info"]
        texts = obj["text_list"]
        assert len(infos) == len(texts), (
            f"Length of image_info ({len(infos)}) must equal "
            f"length of text_list ({len(texts)})."
        )
        for img_meta, txt_meta in zip(infos, texts):
            caption = txt_meta["text"]
            path = img_meta["data_path"]
            height = img_meta["image_height"]
            width = img_meta["image_width"]
            # The image entry's flag, if any, takes precedence over the text entry's.
            if not img_meta.get("is_valid", txt_meta.get("is_valid", True)):
                continue

            assert self.image_vae_server, "image server is not ready."
            t_start = time.monotonic()
            latents = self.image_vae_server.encode(
                path, rank=self.dist_info.world_rank
            ).latent_dist.sample()
            if self._perf:
                self._perf.record("vae_encode_image", time.monotonic() - t_start)

            if self.use_aspect_ratio_buckets:
                ratio = compute_tgt_ratio(height, width)
                _, lat_h, lat_w, _ = latents.shape
                w_part, h_part = (int(tok) for tok in ratio.split("/"))
                landscape_mismatch = lat_h < lat_w and h_part >= w_part
                portrait_mismatch = lat_h > lat_w and h_part <= w_part
                if landscape_mismatch or portrait_mismatch:
                    aspect_ratio = f"{h_part}/{w_part}"
                else:
                    aspect_ratio = ratio

            results.append({
                "captions": caption,
                "image_latents": latents,
                "aspect_ratio": aspect_ratio,
            })
        return results
    except Exception as e:
        print(f"Error in building image sample: {e}")
        print(f"obj: {obj}")
        return []
def _build_out_vid(self, obj: dict):
    """Encode every usable video entry of ``obj`` into a text-to-video sample.

    ``obj["video_info"]`` and ``obj["text_list"]`` are paired by position. An
    entry is skipped (with a printed note) when it is flagged invalid, or when
    ``int(duration * video_fps)`` lies outside
    ``[video_min_frames, video_max_frames]``.

    Returns:
        A list of dicts with keys ``captions``, ``video_latents`` and
        ``aspect_ratio``. On any error, the error and traceback are printed
        and ``[]`` is returned.
    """
    collected = []
    aspect_ratio = "0, 0/0"
    try:
        entries = obj["video_info"]
        captions = obj["text_list"]
        assert len(entries) == len(captions), (
            f"Length of video_info ({len(entries)}) must equal "
            f"length of text_list ({len(captions)})."
        )
        for vid_meta, txt_meta in zip(entries, captions):
            caption = txt_meta["text"]
            fallback_valid = txt_meta.get("is_valid", True)
            path = vid_meta["data_path"]
            height = vid_meta["image_height"]
            width = vid_meta["image_width"]
            n_frames = int(vid_meta["duration"] * self.video_fps)

            if not vid_meta.get("is_valid", fallback_valid):
                print(f"invalid video sample: {obj}")
                continue
            lo, hi = self.video_min_frames, self.video_max_frames
            if not lo <= n_frames <= hi:
                print(f"skip video due to frame number {n_frames}, not in [{lo}, {hi}]")
                continue

            if self.use_aspect_ratio_buckets:
                aspect_ratio = compute_tgt_ratio(height, width)

            assert self.video_vae_server, "video server is not ready."
            t_start = time.monotonic()
            latents = self.video_vae_server.encode(
                path,
                rank=self.dist_info.world_rank,
                frame_length=self.video_tgt_frames,
                fps=self.video_fps
            ).latent_dist.sample()
            if self._perf:
                self._perf.record("vae_encode_video", time.monotonic() - t_start)

            collected.append({
                "captions": caption,
                "video_latents": latents,
                "aspect_ratio": aspect_ratio,
            })
        return collected
    except Exception as e:
        print(f"Error in building video sample: {e}")
        traceback.print_exc()
        return []
def _build_out_av(self, obj: dict):
    """Build joint audio+video samples from one metadata record.

    Note: AV samples are synthesized from video records. Each
    ``obj["video_info"]`` entry is paired with the matching ``obj["text_list"]``
    entry. The video is encoded with the video VAE. The audio is encoded from
    the same file, or from ``obj["audio_info"][i]`` in ``split_wav_mode``, and is
    then zero-padded or trimmed along dim 0 to cover the encoded video span.

    Returns:
        A list with one dict per valid entry, each with keys ``captions``,
        ``audio_latents``, ``spk_embs``, ``video_latents`` and ``aspect_ratio``.
        Returns ``[]`` if anything in the record raises.
    """
    samples = []
    aspect_ratio = "0, 0/0"
    try:
        video_info = obj["video_info"]
        text_list = obj["text_list"]
        assert len(video_info) == len(text_list), (
            f"Length of video_info ({len(video_info)}) must equal "
            f"length of text_list ({len(text_list)})."
        )
        for s_idx, (sample_video_info, sample_text_info) in enumerate(zip(video_info, text_list)):
            text = sample_text_info["text"]
            is_valid = sample_text_info.get("is_valid", True)
            sample_spk_embs = []
            speech_starts = sample_text_info.get("speech_start", None)
            speech_ends = sample_text_info.get("speech_end", None)
            if speech_starts and speech_ends:
                num_speech_starts = text.count("<S>")
                num_speech_ends = text.count("<E>")
                assert num_speech_starts == num_speech_ends and num_speech_ends == len(speech_starts), \
                    f"Error: starts {num_speech_starts} not match with ends {num_speech_ends}, text: {text}"
            if self.add_spk_emb and "<S>" in text and "<E>" in text:
                text = text.replace("<S>", "<S><extra_id_2>")
            if self.use_speech_special_token:
                text = text.replace("<S>", "<extra_id_0>")
                text = text.replace("<E>", "<extra_id_1>")

            data_path = sample_video_info["data_path"]
            if self.split_wav_mode:
                if "audio_info" not in obj:
                    print("audio_info not found in obj for split_wav_mode")
                    continue
                audio_data_path = obj["audio_info"][s_idx]["data_path"]
            else:
                audio_data_path = data_path

            video_height = sample_video_info["image_height"]
            video_width = sample_video_info["image_width"]
            video_frames = int(sample_video_info["duration"] * self.video_fps)
            is_valid = sample_video_info.get("is_valid", is_valid)
            if not is_valid:
                print(f"invalid video sample: {obj}")
                continue
            if video_frames < self.video_min_frames or video_frames > self.video_max_frames:
                print(f"skip video due to frame number {video_frames}, not in [{self.video_min_frames}, {self.video_max_frames}]")
                continue
            if self.use_aspect_ratio_buckets:
                aspect_ratio = compute_tgt_ratio(video_height, video_width)

            assert self.audio_vae_server and self.video_vae_server, "video server or audio server is not ready."
            _t0_av_vid_enc = time.monotonic()
            video_latents = self.video_vae_server.encode(
                data_path,
                rank=self.dist_info.world_rank,
                frame_length=self.video_tgt_frames,
                fps=self.video_fps
            ).latent_dist.sample()
            if self._perf:
                self._perf.record("vae_encode_video_av", time.monotonic() - _t0_av_vid_enc)

            # Align audio to the encoded video, assuming both start at t=0.
            # NOTE(review): (T - 1) * 4 + 1 assumes 4x temporal compression in the
            # video VAE with dim 0 of the latent as time — confirm against the VAE.
            clip_duration = ((video_latents.shape[0] - 1) * 4 + 1) / self.video_fps
            audio_length = math.ceil(clip_duration * self.audio_tokens_per_sec)
            query = {
                "data_path": audio_data_path,
                "add_spk_emb": False,
                # The speaker embedding of the full-clip encode is never used
                # here; state that with the key the other encode calls use.
                "use_spk_emb": False,
                "target_length": audio_length,
            }
            _t0_av_aud_enc = time.monotonic()
            audio_result = self.audio_vae_server.encode(
                query, rank=self.dist_info.world_rank,
            ).latent_dist.sample()
            if self._perf:
                self._perf.record("vae_encode_audio_av", time.monotonic() - _t0_av_aud_enc)
            audio_latents = audio_result["audio_latents"][0].permute(1, 0)  # [L, C]

            # One speaker-embedding encode per annotated speech span.
            if self.add_spk_emb and speech_starts and speech_ends:
                for speech_start, speech_end in zip(speech_starts, speech_ends):
                    query.pop("target_length", None)
                    query.update({
                        "start": speech_start,
                        "duration": speech_end - speech_start,
                        "use_spk_emb": self.add_spk_emb and random.random() < self.spk_emb_prob
                    })
                    _t0_av_spk_enc = time.monotonic()
                    result = self.audio_vae_server.encode(
                        query, rank=self.dist_info.world_rank,
                    ).latent_dist.sample()
                    if self._perf:
                        self._perf.record("vae_encode_audio_spk", time.monotonic() - _t0_av_spk_enc)
                    sample_spk_embs.append(result["spk_embs"])

            if audio_length > audio_latents.shape[0]:
                print(f"{audio_length}, {audio_latents.shape}, not equal !!!!, {data_path}")
                # new_zeros keeps the latents' dtype (e.g. bf16) and device.
                pad = audio_latents.new_zeros((audio_length - audio_latents.shape[0], audio_latents.shape[1]))
                audio_latents = torch.cat([audio_latents, pad], dim=0)
            else:
                audio_latents = audio_latents[:audio_length]

            samples.append({
                "captions": text,
                "audio_latents": audio_latents,
                "spk_embs": sample_spk_embs,
                "video_latents": video_latents,
                "aspect_ratio": aspect_ratio,
            })
        return samples
    except Exception as e:
        print(f"Error in building av sample: {e}")
        traceback.print_exc()
        return []
def _build_out(self, obj: dict, chosen_modality: str):
    """Yield the samples that the modality-specific builder produces for ``obj``.

    Builders may return either a list (each element is yielded) or a single
    sample (yielded as-is). Because this is a generator, nothing runs, and an
    unsupported ``chosen_modality`` is not reported, until the first ``next()``.
    An unsupported modality prints a message and raises NotImplementedError.
    """
    builder_names = {
        "text_to_audio": "_build_out_aud",
        "text_to_image": "_build_out_img",
        "text_to_video": "_build_out_vid",
        "text_to_av": "_build_out_av",
    }
    builder_name = builder_names.get(chosen_modality)
    if builder_name is None:
        print(f"not support {chosen_modality}")
        raise NotImplementedError
    built = getattr(self, builder_name)(obj)
    if isinstance(built, list):
        yield from built
    else:
        yield built
def collate_fn(batch):
    """Collate a list of sample dicts into a dict of per-key lists.

    Each key below maps to the list of per-sample values (``None`` for samples
    lacking it), or to ``None`` when no sample provides the key. Derived fields:

    - ``audio_seq_len``: per-sample audio latent length. The builders store audio
      latents as ``[L, C]``, so this is dim 0. Samples without audio get 0.
    - ``t_h_w_list``: first three dims of the image or video latents. Video
      takes precedence when both are present.
    - ``data_state``: stacked along a new dim 0 when present.
    """
    process_keys = (
        "captions",
        "audio_latents",
        "image_latents",
        "video_latents",
        "spk_embs",
        "data_state",
        "audio_seq_len",
        "t_h_w_list",
    )
    out = {}
    for k in process_keys:
        vals = [b.get(k, None) for b in batch]
        out[k] = None if all(v is None for v in vals) else vals

    if out["audio_latents"]:
        # dim 0 is the time axis ([L, C]); shape[-1] would be the channel count.
        out["audio_seq_len"] = [
            lat.shape[0] if lat is not None else 0 for lat in out["audio_latents"]
        ]
    # Image first, then video, so video latents win when both exist.
    for key in ("image_latents", "video_latents"):
        if out[key]:
            out["t_h_w_list"] = [
                tuple(lat.shape[:3]) if lat is not None else (0, 0, 0) for lat in out[key]
            ]
    # spk_embs stays a per-sample list (each entry is itself a list).
    if out["data_state"] is not None:
        out["data_state"] = torch.stack(out["data_state"], dim=0)
    return out
def collate_fn_batch(batchs):
    """Collate each group in ``batchs`` independently with ``collate_fn``.

    Returns a list with one collated dict per input group, in input order.
    """
    return [collate_fn(group) for group in batchs]
def main():
    """Smoke test: build the VAEs and dataset from ``configs/nava.yaml`` and print batches.

    Runs in a single process (``num_workers=0``) on CUDA with a rank-0,
    world-size-1 ``DistInfo``.

    Raises:
        ValueError: if a line of the data file list does not have 2 or 3
            tab-separated fields.
    """
    import yaml

    from nava_src.models.nava.utils.model_loading_utils import init_wan_vae_2_2
    from nava_src.vae.local_audio_vae import LocalAudioVAEAdapter, init_ltx_vae
    from nava_src.vae.local_video_vae import LocalVideoVAEAdapter

    with open("configs/nava.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    data_cfg = cfg["data"]
    device = "cuda"

    # Video VAE: frozen, eval mode, cast to bfloat16.
    wan_vae = init_wan_vae_2_2(cfg["model"]["ckpt_dir"], rank=device)
    wan_vae.model.requires_grad_(False).eval()
    wan_vae.model = wan_vae.model.to(torch.bfloat16)
    video_vae_server = LocalVideoVAEAdapter(wan_vae, resolution=cfg["image_size"])

    ltx_vae = init_ltx_vae(cfg["model"]["audio_vae_ckpt_dir"], device=device)
    audio_vae_server = LocalAudioVAEAdapter(ltx_vae, spk_model=None, sample_rate=16000)
    image_vae_server = None

    # File list: each non-empty line is "<idx>\t<name>\t<path>" or "<idx>\t<path>".
    filelist_path = data_cfg["data_filelist"]
    data = []
    with open(filelist_path, "r") as f:
        for line_no, line in enumerate(f.read().split("\n"), start=1):
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) == 3:
                _, name, path = fields
                data.append([name, path])
            elif len(fields) == 2:
                _, path = fields
                data.append([path])
            else:
                raise ValueError(
                    f"{filelist_path}:{line_no}: expected 2 or 3 tab-separated fields, got {len(fields)}"
                )

    # Source weights: each non-empty line is "<src_id>\t<weight>\t<modality>".
    src_id2ratios = {}
    with open(data_cfg["data_weights"], "r") as f:
        for line in f.read().split("\n"):
            if not line:
                continue
            key, value, modal = line.split("\t")
            src_id2ratios[key] = [float(value), modal]

    ds = AudioVideoDataset(
        jsonl_or_src_list=data,
        dist_info=DistInfo(world_rank=0, world_size=1),
        batch_size=cfg['batch_size'],
        queue_size=data_cfg.get("queue_size", 5),
        io_workers=data_cfg.get("io_workers", 16),
        src_id2ratios=src_id2ratios,
        modal_prob=data_cfg["modal_prob"],
        num_shards=1,
        use_aspect_ratio_buckets=data_cfg.get("use_aspect_ratio_buckets", False),
        use_length_buckets=data_cfg.get("use_length_buckets", False),
        num_length_buckets=data_cfg.get("num_length_buckets", 10),
        enable_ddp_bucket_sync=data_cfg.get("enable_ddp_bucket_sync", False),
        audio_vae_server=audio_vae_server,
        image_vae_server=image_vae_server,
        video_vae_server=video_vae_server,
        min_audio_duration=data_cfg.get("min_audio_duration", 0.0),
        max_audio_duration=data_cfg.get("max_audio_duration", 15.0),
        video_min_frames=data_cfg.get("video_min_frames", 1),
        video_max_frames=data_cfg.get("video_max_frames", 5),
        video_tgt_frames=data_cfg.get("video_tgt_frames", -1),
        video_fps=data_cfg.get("video_fps", 8),
        add_spk_emb=data_cfg.get("add_spk_emb", False),
        spk_emb_prob=data_cfg.get("spk_emb_prob", 0.9),
        use_speech_special_token=data_cfg.get("use_speech_special_token", False),
        data_file_divisor=data_cfg.get("data_file_divisor", 1),
    )
    # The dataset already yields whole batches, so the DataLoader batch is a
    # list of one dataset batch.
    dl = DataLoader(
        ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,  # single process until the logic is confirmed; then enable workers
        collate_fn=collate_fn_batch,
    )
    for step, batch in enumerate(dl):
        batch = batch[0]
        print(f"step {step}")
        print(" data_state:", batch["data_state"])
        print("-" * 40)


if __name__ == "__main__":
    main()