# linoy — initial commit (ebfc6b3)
from pathlib import Path
import torch
from einops import rearrange
from torch import Tensor
from torch.utils.data import Dataset
from ltx_trainer import logger
# Constants for precomputed data directories
PRECOMPUTED_DIR_NAME = ".precomputed"
class DummyDataset(Dataset):
"""Produce random latents and prompt embeddings. For minimal demonstration and benchmarking purposes"""
def __init__(
self,
width: int = 1024,
height: int = 1024,
num_frames: int = 25,
fps: int = 24,
dataset_length: int = 200,
latent_dim: int = 128,
latent_spatial_compression_ratio: int = 32,
latent_temporal_compression_ratio: int = 8,
prompt_embed_dim: int = 4096,
prompt_sequence_length: int = 256,
) -> None:
if width % 32 != 0:
raise ValueError(f"Width must be divisible by 32, got {width=}")
if height % 32 != 0:
raise ValueError(f"Height must be divisible by 32, got {height=}")
if num_frames % 8 != 1:
raise ValueError(f"Number of frames must have a remainder of 1 when divided by 8, got {num_frames=}")
self.width = width
self.height = height
self.num_frames = num_frames
self.fps = fps
self.dataset_length = dataset_length
self.latent_dim = latent_dim
self.num_latent_frames = (num_frames - 1) // latent_temporal_compression_ratio + 1
self.latent_height = height // latent_spatial_compression_ratio
self.latent_width = width // latent_spatial_compression_ratio
self.latent_sequence_length = self.num_latent_frames * self.latent_height * self.latent_width
self.prompt_embed_dim = prompt_embed_dim
self.prompt_sequence_length = prompt_sequence_length
def __len__(self) -> int:
return self.dataset_length
def __getitem__(self, idx: int) -> dict[str, dict[str, Tensor]]:
return {
"latent_conditions": {
"latents": torch.randn(
self.latent_dim,
self.num_latent_frames,
self.latent_height,
self.latent_width,
),
"num_frames": self.num_latent_frames,
"height": self.latent_height,
"width": self.latent_width,
"fps": self.fps,
},
"text_conditions": {
"prompt_embeds": torch.randn(
self.prompt_sequence_length,
self.prompt_embed_dim,
), # random text embeddings
"prompt_attention_mask": torch.ones(
self.prompt_sequence_length,
dtype=torch.bool,
), # random attention mask
},
}
class PrecomputedDataset(Dataset):
def __init__(self, data_root: str, data_sources: dict[str, str] | list[str] | None = None) -> None:
"""
Generic dataset for loading precomputed data from multiple sources.
Args:
data_root: Root directory containing preprocessed data
data_sources: Either:
- Dict mapping directory names to output keys
- List of directory names (keys will equal values)
- None (defaults to ["latents", "conditions"])
Example:
# Standard mode (list)
dataset = PrecomputedDataset("data/", ["latents", "conditions"])
# Standard mode (dict)
dataset = PrecomputedDataset("data/", {"latents": "latent_conditions", "conditions": "text_conditions"})
# IC-LoRA mode
dataset = PrecomputedDataset("data/", ["latents", "conditions", "reference_latents"])
Note:
Latents are always returned in non-patchified format [C, F, H, W].
Legacy patchified format [seq_len, C] is automatically converted.
"""
super().__init__()
self.data_root = self._setup_data_root(data_root)
self.data_sources = self._normalize_data_sources(data_sources)
self.source_paths = self._setup_source_paths()
self.sample_files = self._discover_samples()
self._validate_setup()
@staticmethod
def _setup_data_root(data_root: str) -> Path:
"""Setup and validate the data root directory."""
data_root = Path(data_root).expanduser().resolve()
if not data_root.exists():
raise FileNotFoundError(f"Data root directory does not exist: {data_root}")
# If the given path is the dataset root, use the precomputed subdirectory
if (data_root / PRECOMPUTED_DIR_NAME).exists():
data_root = data_root / PRECOMPUTED_DIR_NAME
return data_root
@staticmethod
def _normalize_data_sources(data_sources: dict[str, str] | list[str] | None) -> dict[str, str]:
"""Normalize data_sources input to a consistent dict format."""
if data_sources is None:
# Default sources
return {"latents": "latent_conditions", "conditions": "text_conditions"}
elif isinstance(data_sources, list):
# Convert list to dict where keys equal values
return {source: source for source in data_sources}
elif isinstance(data_sources, dict):
return data_sources.copy()
else:
raise TypeError(f"data_sources must be dict, list, or None, got {type(data_sources)}")
def _setup_source_paths(self) -> dict[str, Path]:
"""Map data source names to their actual directory paths."""
source_paths = {}
for dir_name in self.data_sources:
source_path = self.data_root / dir_name
source_paths[dir_name] = source_path
# Check that all sources exist.
if not source_path.exists():
raise FileNotFoundError(f"Required {dir_name} directory does not exist: {source_path}")
return source_paths
def _discover_samples(self) -> dict[str, list[Path]]:
"""Discover all valid sample files across all data sources."""
# Use first data source as the reference to discover samples
data_key = "latents" if "latents" in self.data_sources else next(iter(self.data_sources.keys()))
data_path = self.source_paths[data_key]
data_files = list(data_path.glob("**/*.pt"))
if not data_files:
raise ValueError(f"No data files found in {data_path}")
# Initialize sample files dict
sample_files = {output_key: [] for output_key in self.data_sources.values()}
# For each data file, find corresponding files in other sources
for data_file in data_files:
rel_path = data_file.relative_to(data_path)
# Check if corresponding files exist in ALL sources
if self._all_source_files_exist(data_file, rel_path):
self._fill_sample_data_files(data_file, rel_path, sample_files)
return sample_files
def _all_source_files_exist(self, data_file: Path, rel_path: Path) -> bool:
"""Check if corresponding files exist in all data sources."""
for dir_name in self.data_sources:
expected_path = self._get_expected_file_path(dir_name, data_file, rel_path)
if not expected_path.exists():
logger.warning(
f"No matching {dir_name} file found for: {data_file.name} (expected in: {expected_path})"
)
return False
return True
def _get_expected_file_path(self, dir_name: str, data_file: Path, rel_path: Path) -> Path:
"""Get the expected file path for a given data source."""
source_path = self.source_paths[dir_name]
# For conditions, handle legacy naming where latent_X.pt maps to condition_X.pt
if dir_name == "conditions" and data_file.name.startswith("latent_"):
return source_path / f"condition_{data_file.stem[7:]}.pt"
return source_path / rel_path
def _fill_sample_data_files(self, data_file: Path, rel_path: Path, sample_files: dict[str, list[Path]]) -> None:
"""Add a valid sample to the sample_files tracking."""
for dir_name, output_key in self.data_sources.items():
expected_path = self._get_expected_file_path(dir_name, data_file, rel_path)
sample_files[output_key].append(expected_path.relative_to(self.source_paths[dir_name]))
def _validate_setup(self) -> None:
"""Validate that the dataset setup is correct."""
if not self.sample_files:
raise ValueError("No valid samples found - all data sources must have matching files")
# Verify all output keys have the same number of samples
sample_counts = {key: len(files) for key, files in self.sample_files.items()}
if len(set(sample_counts.values())) > 1:
raise ValueError(f"Mismatched sample counts across sources: {sample_counts}")
def __len__(self) -> int:
# Use the first output key as reference count
first_key = next(iter(self.sample_files.keys()))
return len(self.sample_files[first_key])
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
result = {}
for dir_name, output_key in self.data_sources.items():
source_path = self.source_paths[dir_name]
file_rel_path = self.sample_files[output_key][index]
file_path = source_path / file_rel_path
try:
data = torch.load(file_path, map_location="cpu", weights_only=True)
# Normalize video latent format if this is a latent source
if "latent" in dir_name.lower():
data = self._normalize_video_latents(data)
result[output_key] = data
except Exception as e:
raise RuntimeError(f"Failed to load {output_key} from {file_path}: {e}") from e
# Add index for debugging
result["idx"] = index
return result
@staticmethod
def _normalize_video_latents(data: dict) -> dict:
"""
Normalize video latents to non-patchified format [C, F, H, W].
Used for keeping backward compatibility with legacy datasets.
"""
latents = data["latents"]
# Check if latents are in legacy patchified format [seq_len, C]
if latents.dim() == 2:
# Legacy format: [seq_len, C] where seq_len = F * H * W
num_frames = data["num_frames"]
height = data["height"]
width = data["width"]
# Unpatchify: [seq_len, C] -> [C, F, H, W]
latents = rearrange(
latents,
"(f h w) c -> c f h w",
f=num_frames,
h=height,
w=width,
)
# Update the data dict with unpatchified latents
data = data.copy()
data["latents"] = latents
return data