Spaces: Running on Zero
| from pathlib import Path | |
| import torch | |
| from einops import rearrange | |
| from torch import Tensor | |
| from torch.utils.data import Dataset | |
| from ltx_trainer import logger | |
| # Constants for precomputed data directories | |
| PRECOMPUTED_DIR_NAME = ".precomputed" | |
class DummyDataset(Dataset):
    """Produce random latents and prompt embeddings.

    For minimal demonstration and benchmarking purposes: every item is freshly
    sampled Gaussian noise shaped exactly like real precomputed training data
    (video latents plus text-encoder outputs).
    """

    def __init__(
        self,
        width: int = 1024,
        height: int = 1024,
        num_frames: int = 25,
        fps: int = 24,
        dataset_length: int = 200,
        latent_dim: int = 128,
        latent_spatial_compression_ratio: int = 32,
        latent_temporal_compression_ratio: int = 8,
        prompt_embed_dim: int = 4096,
        prompt_sequence_length: int = 256,
    ) -> None:
        """Initialize the dummy dataset.

        Args:
            width: Pixel width of the simulated video; must be divisible by
                ``latent_spatial_compression_ratio``.
            height: Pixel height; same divisibility constraint as ``width``.
            num_frames: Number of video frames; must satisfy
                ``num_frames % latent_temporal_compression_ratio == 1``
                (causal VAE keeps the first frame separate).
            fps: Frame rate reported alongside each sample.
            dataset_length: Number of items the dataset pretends to contain.
            latent_dim: Channel dimension of the latents.
            latent_spatial_compression_ratio: VAE spatial downscale factor.
            latent_temporal_compression_ratio: VAE temporal downscale factor.
            prompt_embed_dim: Embedding dimension of the text encoder output.
            prompt_sequence_length: Token length of the text embeddings.

        Raises:
            ValueError: If width/height/num_frames violate the constraints.
        """
        # Validate against the actual compression ratios rather than the
        # hard-coded 32/8 the defaults happen to use, so non-default ratios
        # are checked correctly. Messages are unchanged for the defaults.
        if width % latent_spatial_compression_ratio != 0:
            raise ValueError(f"Width must be divisible by {latent_spatial_compression_ratio}, got {width=}")
        if height % latent_spatial_compression_ratio != 0:
            raise ValueError(f"Height must be divisible by {latent_spatial_compression_ratio}, got {height=}")
        if (num_frames - 1) % latent_temporal_compression_ratio != 0:
            raise ValueError(
                f"Number of frames must have a remainder of 1 when divided by "
                f"{latent_temporal_compression_ratio}, got {num_frames=}"
            )

        self.width = width
        self.height = height
        self.num_frames = num_frames
        self.fps = fps
        self.dataset_length = dataset_length
        self.latent_dim = latent_dim
        # First frame is encoded on its own; the rest compress temporally.
        self.num_latent_frames = (num_frames - 1) // latent_temporal_compression_ratio + 1
        self.latent_height = height // latent_spatial_compression_ratio
        self.latent_width = width // latent_spatial_compression_ratio
        self.latent_sequence_length = self.num_latent_frames * self.latent_height * self.latent_width
        self.prompt_embed_dim = prompt_embed_dim
        self.prompt_sequence_length = prompt_sequence_length

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> dict[str, dict[str, Tensor]]:
        """Return one random sample shaped like real precomputed data."""
        return {
            "latent_conditions": {
                "latents": torch.randn(
                    self.latent_dim,
                    self.num_latent_frames,
                    self.latent_height,
                    self.latent_width,
                ),
                "num_frames": self.num_latent_frames,
                "height": self.latent_height,
                "width": self.latent_width,
                "fps": self.fps,
            },
            "text_conditions": {
                "prompt_embeds": torch.randn(
                    self.prompt_sequence_length,
                    self.prompt_embed_dim,
                ),  # random text embeddings
                "prompt_attention_mask": torch.ones(
                    self.prompt_sequence_length,
                    dtype=torch.bool,
                ),  # random attention mask
            },
        }
| class PrecomputedDataset(Dataset): | |
| def __init__(self, data_root: str, data_sources: dict[str, str] | list[str] | None = None) -> None: | |
| """ | |
| Generic dataset for loading precomputed data from multiple sources. | |
| Args: | |
| data_root: Root directory containing preprocessed data | |
| data_sources: Either: | |
| - Dict mapping directory names to output keys | |
| - List of directory names (keys will equal values) | |
| - None (defaults to ["latents", "conditions"]) | |
| Example: | |
| # Standard mode (list) | |
| dataset = PrecomputedDataset("data/", ["latents", "conditions"]) | |
| # Standard mode (dict) | |
| dataset = PrecomputedDataset("data/", {"latents": "latent_conditions", "conditions": "text_conditions"}) | |
| # IC-LoRA mode | |
| dataset = PrecomputedDataset("data/", ["latents", "conditions", "reference_latents"]) | |
| Note: | |
| Latents are always returned in non-patchified format [C, F, H, W]. | |
| Legacy patchified format [seq_len, C] is automatically converted. | |
| """ | |
| super().__init__() | |
| self.data_root = self._setup_data_root(data_root) | |
| self.data_sources = self._normalize_data_sources(data_sources) | |
| self.source_paths = self._setup_source_paths() | |
| self.sample_files = self._discover_samples() | |
| self._validate_setup() | |
| def _setup_data_root(data_root: str) -> Path: | |
| """Setup and validate the data root directory.""" | |
| data_root = Path(data_root).expanduser().resolve() | |
| if not data_root.exists(): | |
| raise FileNotFoundError(f"Data root directory does not exist: {data_root}") | |
| # If the given path is the dataset root, use the precomputed subdirectory | |
| if (data_root / PRECOMPUTED_DIR_NAME).exists(): | |
| data_root = data_root / PRECOMPUTED_DIR_NAME | |
| return data_root | |
| def _normalize_data_sources(data_sources: dict[str, str] | list[str] | None) -> dict[str, str]: | |
| """Normalize data_sources input to a consistent dict format.""" | |
| if data_sources is None: | |
| # Default sources | |
| return {"latents": "latent_conditions", "conditions": "text_conditions"} | |
| elif isinstance(data_sources, list): | |
| # Convert list to dict where keys equal values | |
| return {source: source for source in data_sources} | |
| elif isinstance(data_sources, dict): | |
| return data_sources.copy() | |
| else: | |
| raise TypeError(f"data_sources must be dict, list, or None, got {type(data_sources)}") | |
| def _setup_source_paths(self) -> dict[str, Path]: | |
| """Map data source names to their actual directory paths.""" | |
| source_paths = {} | |
| for dir_name in self.data_sources: | |
| source_path = self.data_root / dir_name | |
| source_paths[dir_name] = source_path | |
| # Check that all sources exist. | |
| if not source_path.exists(): | |
| raise FileNotFoundError(f"Required {dir_name} directory does not exist: {source_path}") | |
| return source_paths | |
| def _discover_samples(self) -> dict[str, list[Path]]: | |
| """Discover all valid sample files across all data sources.""" | |
| # Use first data source as the reference to discover samples | |
| data_key = "latents" if "latents" in self.data_sources else next(iter(self.data_sources.keys())) | |
| data_path = self.source_paths[data_key] | |
| data_files = list(data_path.glob("**/*.pt")) | |
| if not data_files: | |
| raise ValueError(f"No data files found in {data_path}") | |
| # Initialize sample files dict | |
| sample_files = {output_key: [] for output_key in self.data_sources.values()} | |
| # For each data file, find corresponding files in other sources | |
| for data_file in data_files: | |
| rel_path = data_file.relative_to(data_path) | |
| # Check if corresponding files exist in ALL sources | |
| if self._all_source_files_exist(data_file, rel_path): | |
| self._fill_sample_data_files(data_file, rel_path, sample_files) | |
| return sample_files | |
| def _all_source_files_exist(self, data_file: Path, rel_path: Path) -> bool: | |
| """Check if corresponding files exist in all data sources.""" | |
| for dir_name in self.data_sources: | |
| expected_path = self._get_expected_file_path(dir_name, data_file, rel_path) | |
| if not expected_path.exists(): | |
| logger.warning( | |
| f"No matching {dir_name} file found for: {data_file.name} (expected in: {expected_path})" | |
| ) | |
| return False | |
| return True | |
| def _get_expected_file_path(self, dir_name: str, data_file: Path, rel_path: Path) -> Path: | |
| """Get the expected file path for a given data source.""" | |
| source_path = self.source_paths[dir_name] | |
| # For conditions, handle legacy naming where latent_X.pt maps to condition_X.pt | |
| if dir_name == "conditions" and data_file.name.startswith("latent_"): | |
| return source_path / f"condition_{data_file.stem[7:]}.pt" | |
| return source_path / rel_path | |
| def _fill_sample_data_files(self, data_file: Path, rel_path: Path, sample_files: dict[str, list[Path]]) -> None: | |
| """Add a valid sample to the sample_files tracking.""" | |
| for dir_name, output_key in self.data_sources.items(): | |
| expected_path = self._get_expected_file_path(dir_name, data_file, rel_path) | |
| sample_files[output_key].append(expected_path.relative_to(self.source_paths[dir_name])) | |
| def _validate_setup(self) -> None: | |
| """Validate that the dataset setup is correct.""" | |
| if not self.sample_files: | |
| raise ValueError("No valid samples found - all data sources must have matching files") | |
| # Verify all output keys have the same number of samples | |
| sample_counts = {key: len(files) for key, files in self.sample_files.items()} | |
| if len(set(sample_counts.values())) > 1: | |
| raise ValueError(f"Mismatched sample counts across sources: {sample_counts}") | |
| def __len__(self) -> int: | |
| # Use the first output key as reference count | |
| first_key = next(iter(self.sample_files.keys())) | |
| return len(self.sample_files[first_key]) | |
| def __getitem__(self, index: int) -> dict[str, torch.Tensor]: | |
| result = {} | |
| for dir_name, output_key in self.data_sources.items(): | |
| source_path = self.source_paths[dir_name] | |
| file_rel_path = self.sample_files[output_key][index] | |
| file_path = source_path / file_rel_path | |
| try: | |
| data = torch.load(file_path, map_location="cpu", weights_only=True) | |
| # Normalize video latent format if this is a latent source | |
| if "latent" in dir_name.lower(): | |
| data = self._normalize_video_latents(data) | |
| result[output_key] = data | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load {output_key} from {file_path}: {e}") from e | |
| # Add index for debugging | |
| result["idx"] = index | |
| return result | |
| def _normalize_video_latents(data: dict) -> dict: | |
| """ | |
| Normalize video latents to non-patchified format [C, F, H, W]. | |
| Used for keeping backward compatibility with legacy datasets. | |
| """ | |
| latents = data["latents"] | |
| # Check if latents are in legacy patchified format [seq_len, C] | |
| if latents.dim() == 2: | |
| # Legacy format: [seq_len, C] where seq_len = F * H * W | |
| num_frames = data["num_frames"] | |
| height = data["height"] | |
| width = data["width"] | |
| # Unpatchify: [seq_len, C] -> [C, F, H, W] | |
| latents = rearrange( | |
| latents, | |
| "(f h w) c -> c f h w", | |
| f=num_frames, | |
| h=height, | |
| w=width, | |
| ) | |
| # Update the data dict with unpatchified latents | |
| data = data.copy() | |
| data["latents"] = latents | |
| return data | |