import os
import pickle
import math
import shutil
import glob
import json
import time
from collections import defaultdict

import numpy as np
import lmdb
import pandas as pd
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from termcolor import colored
from loguru import logger
import smplx

from dataloaders.build_vocab import Vocab
from .utils.audio_features import process_audio_data
from .data_tools import joints_list
from .utils.other_tools import MultiLMDBManager
from .utils.motion_rep_transfer import process_smplx_motion
from .utils.mis_features import process_semantic_data, process_emotion_data
from .utils.text_features import process_word_data
from .utils.data_sample import sample_from_clip


class CustomDataset(Dataset):
    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
        self.args = args
        self.loader_type = loader_type

        # Set rank safely - handle cases where distributed training is not yet initialized
        try:
            if torch.distributed.is_initialized():
                self.rank = torch.distributed.get_rank()
            else:
                self.rank = 0
        except Exception:
            self.rank = 0

        # Initialize basic parameters
        self.ori_stride = self.args.stride
        self.ori_length = self.args.pose_length
        self.alignment = [0, 0]  # for trinity

        # Initialize the SMPL-X body model
        self.smplx = smplx.create(
            self.args.data_path_1 + "smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).cuda().eval()

        if self.args.word_rep is not None:
            with open(f"{self.args.data_path}weights/vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)

        # Load and process split rules
        self._process_split_rules()

        # Initialize data directories and lengths
        self._init_data_paths()

        if self.args.beat_align:
            if not os.path.exists(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy"):
                self.calculate_mean_velocity(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy")
            self.avg_vel = np.load(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy")

        # Build or load cache
        self._init_cache(build_cache)

    def _process_split_rules(self):
        """Process dataset split rules."""
        split_rule = pd.read_csv(self.args.data_path + "train_test_split.csv")
        self.selected_file = split_rule.loc[
            (split_rule['type'] == self.loader_type)
            & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
        ]
        if self.args.additional_data and self.loader_type == 'train':
            split_b = split_rule.loc[
                (split_rule['type'] == 'additional')
                & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
            ]
            self.selected_file = pd.concat([self.selected_file, split_b])
        if self.selected_file.empty:
            logger.warning(
                f"{self.loader_type} split is empty for speakers {self.args.training_speakers}, "
                f"falling back to the first 8 training files instead"
            )
            self.selected_file = split_rule.loc[
                (split_rule['type'] == 'train')
                & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
            ]
            self.selected_file = self.selected_file.iloc[0:8]

    def _init_data_paths(self):
        """Initialize data directories and lengths."""
        self.data_dir = self.args.data_path
        if self.loader_type == "test":
            self.args.multi_length_training = [1.0]
        self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1])
        self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr)
        if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
            self.max_audio_pre_len = self.args.test_length * self.args.audio_sr
        if self.args.test_clip and self.loader_type == "test":
            self.preloaded_dir = (
                self.args.root_path + self.args.cache_path + self.loader_type
                + "_clip" + f"/{self.args.pose_rep}_cache"
            )
        else:
            self.preloaded_dir = (
                self.args.root_path + self.args.cache_path + self.loader_type
                + f"/{self.args.pose_rep}_cache"
            )

    def _init_cache(self, build_cache):
        """Initialize or build cache."""
        self.lmdb_envs = {}
        self.mapping_data = None
        if build_cache and self.rank == 0:
            self.build_cache(self.preloaded_dir)
        # In DDP mode, ensure all processes wait for cache building to complete
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        # Try to regenerate cache if corrupted (only on rank 0 to avoid race conditions)
        if self.rank == 0:
            self.regenerate_cache_if_corrupted()
        # Wait for cache regeneration to complete
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        self.load_db_mapping()

    def build_cache(self, preloaded_dir):
        """Build the dataset cache."""
        logger.info(f"Audio bit rate: {self.args.audio_fps}")
        logger.info("Reading data '{}'...".format(self.data_dir))
        logger.info("Creating the dataset cache...")
        if self.args.new_cache and os.path.exists(preloaded_dir):
            shutil.rmtree(preloaded_dir)
        if os.path.exists(preloaded_dir):
            # If the directory exists but is empty, the cache still needs to be built
            if not os.listdir(preloaded_dir):
                self.cache_generation(
                    preloaded_dir, self.args.disable_filtering,
                    self.args.clean_first_seconds, self.args.clean_final_seconds,
                    is_test=False
                )
            else:
                logger.info("Found the cache {}".format(preloaded_dir))
        elif self.loader_type == "test":
            self.cache_generation(preloaded_dir, True, 0, 0, is_test=True)
        else:
            self.cache_generation(
                preloaded_dir, self.args.disable_filtering,
                self.args.clean_first_seconds, self.args.clean_final_seconds,
                is_test=False
            )

    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
        """Generate cache for the dataset."""
        if not os.path.exists(out_lmdb_dir):
            os.makedirs(out_lmdb_dir)

        # Initialize the multi-LMDB manager
        lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10 * 1024 * 1024 * 1024)
        self.n_out_samples = 0
        n_filtered_out = defaultdict(int)

        for index, file_name in self.selected_file.iterrows():
            f_name = file_name["id"]
            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
            pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext)

            # Process data
            data = self._process_file_data(f_name, pose_file, ext)
            if data is None:
                continue

            # Sample from clip
            filtered_result, self.n_out_samples = sample_from_clip(
                lmdb_manager=lmdb_manager,
                audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"),
                audio_each_file=data['audio'],
                pose_each_file=data['pose'],
                trans_each_file=data['trans'],
                trans_v_each_file=data['trans_v'],
                shape_each_file=data['shape'],
                facial_each_file=data['facial'],
                word_each_file=data['word'],
                vid_each_file=data['vid'],
                emo_each_file=data['emo'],
                sem_each_file=data['sem'],
                args=self.args,
                ori_stride=self.ori_stride,
                ori_length=self.ori_length,
                disable_filtering=disable_filtering,
                clean_first_seconds=clean_first_seconds,
                clean_final_seconds=clean_final_seconds,
                is_test=is_test,
                n_out_samples=self.n_out_samples
            )
            for type_key in filtered_result:
                n_filtered_out[type_key] += filtered_result[type_key]

        lmdb_manager.close()

    def _process_file_data(self, f_name, pose_file, ext):
        """Process all data for a single file."""
        data = {
            'pose': None, 'trans': None, 'trans_v': None, 'shape': None,
            'audio': None, 'facial': None, 'word': None, 'emo': None,
            'sem': None, 'vid': None
        }

        # Process motion data
        logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue"))
        if "smplx" in self.args.pose_rep:
            motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep)
        else:
            raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.")
        if motion_data is None:
            return None
        data.update(motion_data)

        # Process speaker ID
        if self.args.id_rep is not None:
            speaker_id = int(f_name.split("_")[0]) - 1
            data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0)
        else:
            data['vid'] = np.array([-1])

        # Process audio if needed
        if self.args.audio_rep is not None:
            audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
            data = process_audio_data(audio_file, self.args, data, f_name, self.selected_file)
            if data is None:
                return None

        # Process emotion if needed
        if self.args.emo_rep is not None:
            data = process_emotion_data(f_name, data, self.args)
            if data is None:
                return None

        # Process word data if needed
        if self.args.word_rep is not None:
            word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid"
            data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file, self.lang_model)
            if data is None:
                return None

        # Process semantic data if needed
        if self.args.sem_rep is not None:
            sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt"
            data = process_semantic_data(sem_file, self.args, data, f_name)
            if data is None:
                return None

        return data

    def load_db_mapping(self):
        """Load database mapping from file."""
        mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
        backup_path = os.path.join(self.preloaded_dir, "sample_db_mapping_backup.pkl")

        # Check that the file exists and is not empty
        if not os.path.exists(mapping_path):
            raise FileNotFoundError(f"Mapping file not found: {mapping_path}")
        file_size = os.path.getsize(mapping_path)
        if file_size == 0:
            raise ValueError(f"Mapping file is empty: {mapping_path}")

        print(f"Loading mapping file: {mapping_path} (size: {file_size} bytes)")

        # Error handling and retry logic for pickle loading
        max_retries = 3
        for attempt in range(max_retries):
            try:
                with open(mapping_path, 'rb') as f:
                    self.mapping_data = pickle.load(f)
                print(f"Successfully loaded mapping data with {len(self.mapping_data.get('mapping', []))} samples")
                break
            except (EOFError, pickle.UnpicklingError) as e:
                if attempt < max_retries - 1:
                    print(f"Warning: Failed to load pickle file (attempt {attempt + 1}/{max_retries}): {e}")
                    print(f"File path: {mapping_path}")
                    # Try the backup file if the main file is corrupted
                    if os.path.exists(backup_path) and os.path.getsize(backup_path) > 0:
                        print("Trying backup file...")
                        try:
                            with open(backup_path, 'rb') as f:
                                self.mapping_data = pickle.load(f)
                            print(f"Successfully loaded mapping data from backup with {len(self.mapping_data.get('mapping', []))} samples")
                            break
                        except Exception as backup_e:
                            print(f"Backup file also failed: {backup_e}")
                    print("Retrying...")
                    time.sleep(1)  # Wait a bit before retrying
                else:
                    print(f"Error: Failed to load pickle file after {max_retries} attempts: {e}")
                    print(f"File path: {mapping_path}")
                    print("Please check if the file is corrupted or incomplete.")
                    print("You may need to regenerate the cache files.")
                    raise

        # Update paths from test to test_clip if needed
        if self.loader_type == "test" and self.args.test_clip:
            updated_paths = []
            for path in self.mapping_data['db_paths']:
                updated_path = path.replace("test/", "test_clip/")
                updated_paths.append(updated_path)
            self.mapping_data['db_paths'] = updated_paths
            # In DDP mode, avoid modifying shared files to prevent race conditions;
            # only the in-memory data is updated.
            print("Updated test paths for test_clip mode (avoiding file modification in DDP)")

        self.n_samples = len(self.mapping_data['mapping'])

    def get_lmdb_env(self, db_idx):
        """Get the LMDB environment for a given database index, opening it lazily."""
        if db_idx not in self.lmdb_envs:
            db_path = self.mapping_data['db_paths'][db_idx]
            self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False)
        return self.lmdb_envs[db_idx]

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a single sample from the dataset."""
        db_idx = self.mapping_data['mapping'][idx]
        lmdb_env = self.get_lmdb_env(db_idx)
        with lmdb_env.begin(write=False) as txn:
            key = "{:008d}".format(idx).encode("ascii")
            sample = txn.get(key)
            sample = pickle.loads(sample)
            (tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem,
             vid, trans, trans_v, audio_name) = sample

            # Convert data to tensors with appropriate types
            processed_data = self._convert_to_tensors(
                tar_pose, in_audio, in_facial, in_shape, in_word,
                emo, sem, vid, trans, trans_v
            )
            processed_data['audio_name'] = audio_name
            return processed_data

    def _convert_to_tensors(self, tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans, trans_v):
        """Convert numpy arrays to tensors with appropriate types."""
        data = {
            'emo': torch.from_numpy(emo).int(),
            'sem': torch.from_numpy(sem).float(),
            'audio_onset': torch.from_numpy(in_audio).float(),
            'word': torch.from_numpy(in_word).int()
        }
        if self.loader_type == "test":
            data.update({
                'pose': torch.from_numpy(tar_pose).float(),
                'trans': torch.from_numpy(trans).float(),
                'trans_v': torch.from_numpy(trans_v).float(),
                'facial': torch.from_numpy(in_facial).float(),
                'id': torch.from_numpy(vid).float(),
                'beta': torch.from_numpy(in_shape).float()
            })
        else:
            data.update({
                'pose': torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float(),
                'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(),
                'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(),
                'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(),
                'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(),
                'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
            })
        return data

    def regenerate_cache_if_corrupted(self):
        """Regenerate cache if the pickle mapping file is corrupted."""
        mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
        if os.path.exists(mapping_path):
            try:
                # Try to load the file to check whether it is corrupted
                with open(mapping_path, 'rb') as f:
                    pickle.load(f)
                return False  # File is not corrupted
            except (EOFError, pickle.UnpicklingError):
                print(f"Detected corrupted pickle file: {mapping_path}")
                print("Regenerating cache...")
                # Remove the corrupted file and rebuild the cache
                os.remove(mapping_path)
                self.build_cache(self.preloaded_dir)
                return True
        return False
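

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's tested API).
# It assumes a project-level config helper that returns an `args` namespace
# carrying the fields this dataset reads (data_path, data_path_1, pose_rep,
# pose_length, stride, audio_sr, cache_path, root_path, ...). The
# `utils.config.parse_args` import below is an assumption about the
# surrounding project, not a verified module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from utils.config import parse_args  # hypothetical config helper (assumption)

    args = parse_args()
    train_data = CustomDataset(args, loader_type="train", build_cache=True)
    # num_workers=0 keeps the lazily opened LMDB environments in the main process
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True,
                              num_workers=0, drop_last=True)
    for batch in train_loader:
        # Each batch holds tensors keyed by 'pose', 'audio_onset', 'word',
        # 'facial', 'trans', 'trans_v', 'id', 'beta', 'emo', 'sem', plus
        # 'audio_name' (collated into a list of strings).
        print(batch['pose'].shape, batch['audio_onset'].shape)
        break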