| """Minimal extracted datasets for single-control and three-control training.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| from pathlib import Path |
| from typing import Optional, Sequence |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image, ImageFile |
| from torch.utils.data import Dataset |
| from torchvision.transforms import CenterCrop, Normalize, Resize |
| from torchvision.transforms.functional import to_tensor |
|
|
# Keep Pillow strict: with LOAD_TRUNCATED_IMAGES set to False, decoding a
# truncated image file raises instead of yielding a partially decoded image,
# so corrupt files surface as exceptions (PixelThreeControlDataset.__getitem__
# catches them and retries another index; PixelSingleControlDataset does not).
ImageFile.LOAD_TRUNCATED_IMAGES = False
|
|
# Candidate file suffixes, tried in order by `find_with_exts`; the first file
# that exists wins. Where a modality-specific suffix exists (".depth.npy",
# ".sam2_label.png", ".edge.jpg", ...) it is listed before the plain one.

# RGB images (both lower- and upper-case spellings are probed explicitly).
IMAGE_EXTS = (
    ".jpg", ".jpeg", ".JPG", ".JPEG",
    ".png", ".PNG",
)

# Depth maps: raw arrays first, then 8/16-bit images.
DEPTH_EXTS = (
    ".depth.npy", ".npy",
    ".depth.png", ".png",
    ".depth.jpg", ".jpg",
    ".depth.jpeg", ".jpeg",
)

# Segmentation labels. Unlike depth/edge, a plain ".png" is tried before a
# plain ".npy".
SEG_EXTS = (
    ".sam2_label.npy", ".sam2_label.png",
    ".png", ".npy",
)

# Edge maps: same layout as DEPTH_EXTS with an ".edge" prefix.
EDGE_EXTS = (
    ".edge.npy", ".npy",
    ".edge.png", ".png",
    ".edge.jpg", ".jpg",
    ".edge.jpeg", ".jpeg",
)
|
|
|
|
def strip_image_ext(filename: str) -> str:
    """Return *filename* with its image extension removed.

    Suffixes from ``IMAGE_EXTS`` are checked first, in order. If none matches,
    the last extension as reported by ``os.path.splitext`` is dropped instead.
    """
    matched = next((suffix for suffix in IMAGE_EXTS if filename.endswith(suffix)), None)
    if matched is None:
        stem, _ = os.path.splitext(filename)
        return stem
    return filename[: -len(matched)]
|
|
|
|
def find_with_exts(root: str | Path, stem: str, exts: Sequence[str]) -> str | None:
    """Locate ``<root>/<stem><ext>`` for the first *ext* in *exts* that exists.

    Candidates are probed lazily in the given order, so callers list their
    preferred suffixes first. The hit is returned as the ``os.path.join``
    string (not a ``Path``); ``None`` means no candidate exists.
    """
    base = str(root)
    hits = (
        candidate
        for candidate in (os.path.join(base, stem + ext) for ext in exts)
        if os.path.exists(candidate)
    )
    return next(hits, None)
|
|
|
|
def read_caption(path: str, default: str = "") -> str:
    """Return the whitespace-stripped text of the caption file at *path*.

    *default* is returned when *path* is empty/None, the file does not exist,
    or its content is blank after stripping. The file is decoded as UTF-8 and
    undecodable bytes are silently dropped.
    """
    if path and os.path.exists(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            caption = handle.read().strip()
        if caption:
            return caption
    return default
|
|
|
|
| def _resize_crop_1ch(x: np.ndarray, target_size: int, mode: str) -> torch.Tensor: |
| x_t = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).unsqueeze(0) |
| h, w = x_t.shape[-2:] |
| short = min(h, w) |
| scale = float(target_size) / float(short) |
| new_h, new_w = int(round(h * scale)), int(round(w * scale)) |
| x_t = F.interpolate(x_t, size=(new_h, new_w), mode=mode, align_corners=False if mode == "bilinear" else None) |
| top = (new_h - target_size) // 2 |
| left = (new_w - target_size) // 2 |
| return x_t[:, :, top:top + target_size, left:left + target_size].squeeze(0) |
|
|
|
|
def load_depth_to_tensor(path: str, target_size: int, normalize: bool = True, invert_depth: bool = False) -> torch.Tensor:
    """Load a depth map as a (1, target_size, target_size) float tensor in [0, 1].

    Supported inputs are ``.npy`` arrays, ``.npz`` archives (the first stored
    array is used) and any image Pillow can open. 32-bit float ("F") and
    integer ("I", "I;16*") images keep their full precision; every other mode
    is read as 8-bit grayscale. 3-D arrays are averaged over the last axis.
    The map is resized/center-cropped bilinearly via ``_resize_crop_1ch``.

    Args:
        path: depth file.
        target_size: output side length.
        normalize: min-max rescale this map to [0, 1].
        invert_depth: return ``1 - value`` (applied after optional normalization).

    Returns:
        Tensor clamped to [0, 1]. With ``normalize=False``, raw values outside
        [0, 1] are clipped rather than rescaled.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        depth = np.load(path).astype(np.float32)
    elif ext == ".npz":
        # NpzFile holds the archive open until closed; the context manager releases it.
        with np.load(path) as archive:
            depth = archive[archive.files[0]].astype(np.float32)
    else:
        with Image.open(path) as im:
            mode = im.mode
            if mode == "F":
                # Float depth: converting to "L" would clip to 0..255 and quantize.
                depth = np.asarray(im, dtype=np.float32)
            elif mode == "I" or mode.startswith("I;16"):
                depth = np.asarray(im.convert("I"), dtype=np.float32)
            else:
                depth = np.asarray(im.convert("L"), dtype=np.float32)
    if depth.ndim == 3:
        depth = depth.mean(axis=-1)
    out = _resize_crop_1ch(depth, target_size, mode="bilinear")
    if normalize:
        lo, hi = out.min(), out.max()
        # clamp_min guards against division by zero for a constant map.
        out = (out - lo) / (hi - lo).clamp_min(1e-6)
    if invert_depth:
        out = 1.0 - out
    return out.clamp_(0.0, 1.0)
|
|
|
|
def load_seg_to_tensor(path: str, target_size: int, normalize: bool = True) -> torch.Tensor:
    """Load a segmentation label map as a (1, target_size, target_size) float tensor.

    ``.npy`` files are read as-is; for 3-D arrays channel 0 is kept. For image
    files, palette ("P"), 8-bit ("L") and integer ("I", "I;16*") images keep
    their stored pixel values as label ids. Any other mode is converted to 8-bit
    grayscale. Resizing uses nearest-neighbour, so ids are never blended.

    Args:
        path: label file.
        target_size: output side length.
        normalize: divide by the largest id present (when it is > 0) and clamp
            to [0, 1]. With False, the raw ids (as float32) are returned unclamped.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        seg = np.load(path)
    else:
        with Image.open(path) as im:
            if im.mode in ("P", "L", "I") or im.mode.startswith("I;16"):
                # The pixel value is the label id. convert("L") would map palette
                # indices through the palette to luminance (ids can collide or be
                # reordered) and would clip integer ids above 255.
                seg = np.asarray(im)
            else:
                seg = np.asarray(im.convert("L"))
    if seg.ndim == 3:
        seg = seg[..., 0]
    out = _resize_crop_1ch(seg.astype(np.float32), target_size, mode="nearest")
    if normalize:
        max_id = out.max()
        if max_id.item() > 0:
            out = out / max_id
        out = out.clamp_(0.0, 1.0)
    return out
|
|
|
|
def load_edge_to_tensor(path: str, target_size: int) -> torch.Tensor:
    """Load an edge map as a (1, target_size, target_size) tensor rescaled to [0, 1].

    ``.npy`` files are read directly; other files are opened with Pillow as
    8-bit grayscale. A 3-D array is averaged over its last axis. After a
    bilinear resize/center-crop, the map is min-max normalized per image.
    """
    if os.path.splitext(path)[1].lower() == ".npy":
        arr = np.load(path).astype(np.float32)
    else:
        with Image.open(path) as img:
            arr = np.asarray(img.convert("L"), dtype=np.float32)
    if arr.ndim == 3:
        arr = arr.mean(axis=-1)
    resized = _resize_crop_1ch(arr, target_size, mode="bilinear")
    low = resized.min()
    # clamp_min avoids dividing by zero when the map is constant.
    span = (resized.max() - low).clamp_min(1e-6)
    return ((resized - low) / span).clamp_(0.0, 1.0)
|
|
|
|
def subdir_range(start: int, end: int) -> list[str]:
    """Return directory names ``sa_NNNNNN`` (6-digit zero-padded) for *start*..*end* inclusive.

    Both bounds are coerced with ``int()``; an empty list is returned when
    *end* < *start*.
    """
    shard_ids = range(int(start), int(end) + 1)
    return list(map("sa_{:06d}".format, shard_ids))
|
|
|
|
class PixelThreeControlDataset(Dataset):
    """Paired RGB/caption/depth/seg/edge dataset.

    Returns a dict ready for a PixelDiT-like training loop. The loop can sample
    active modes and zero inactive channels using `apply_multi_control_mode`.

    The index is built by scanning ``<image_root>/<subdir>/*.txt`` caption files
    and keeping every stem that also has an RGB image in the same directory and
    a depth, seg and edge file under ``<depth_root|seg_root|edge_root>/<subdir>``.
    With ``cache_index_path`` the index is stored as JSON and reused on later
    runs. NOTE(review): the cache is not keyed on the roots or ``subdirs``;
    delete it after changing either.

    Each item holds:

    * ``image``: (3, R, R) in [-1, 1];
    * ``control``: (3, R, R) with depth/seg/edge channels in [0, 1] (seg is
      unclamped when ``seg_normalize`` is False);
    * the individual ``depth``/``seg``/``edge`` maps and the caption;
    * the raw index entry (stem and file paths).

    A sample that fails to load is replaced by a random other sample, up to
    ``max_retries`` attempts.
    """

    def __init__(
        self,
        image_root: str,
        depth_root: str,
        seg_root: str,
        edge_root: str,
        resolution: int = 512,
        subdirs: Optional[Sequence[str]] = None,
        cache_index_path: str | None = None,
        max_retries: int = 20,
        seg_normalize: bool = True,
        require_caption: bool = True,
    ):
        self.image_root = image_root
        self.depth_root = depth_root
        self.seg_root = seg_root
        self.edge_root = edge_root
        self.resolution = int(resolution)
        self.subdirs = list(subdirs) if subdirs is not None else None
        self.max_retries = int(max_retries)
        self.seg_normalize = bool(seg_normalize)
        # NOTE(review): stored for callers but not consulted anywhere in this class.
        self.require_caption = bool(require_caption)
        self.samples: list[dict] = []
        cached = None
        if cache_index_path and os.path.exists(cache_index_path):
            cached = self._load_cache(cache_index_path)
        if cached is not None:
            self.samples = cached
        else:
            # No cache, or an unreadable one: rebuild the index (and the cache).
            self._build_index()
            if cache_index_path:
                self._save_cache(cache_index_path)
        self.resize = Resize(self.resolution)
        self.center_crop = CenterCrop(self.resolution)
        self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    @staticmethod
    def _load_cache(path: str) -> list[dict] | None:
        """Return the cached sample list, or None if the file cannot be read/parsed."""
        try:
            with open(path, "r", encoding="utf-8") as f:
                samples = json.load(f)
        except (OSError, ValueError) as exc:
            print(f"[PixelThreeControlDataset] ignoring unreadable index cache {path}: {exc!r}")
            return None
        return samples

    def _save_cache(self, path: str) -> None:
        """Write ``self.samples`` to *path* as JSON, atomically.

        The JSON goes to a sibling temp file that is then renamed over *path*,
        so an interrupted write never leaves a truncated cache behind.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        # pid suffix keeps concurrent writers (e.g. several ranks) off one temp file.
        tmp = target.with_name(f"{target.name}.{os.getpid()}.tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(self.samples, f)
            os.replace(tmp, target)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise

    def _iter_subdirs(self):
        """Return the explicit ``subdirs`` or, if unset, every directory under image_root (sorted)."""
        if self.subdirs is not None:
            return self.subdirs
        return sorted(p.name for p in Path(self.image_root).iterdir() if p.is_dir())

    def _build_index(self):
        """Append one entry to ``self.samples`` per caption that has all four companion files."""
        for sub in self._iter_subdirs():
            image_dir = Path(self.image_root) / sub
            depth_dir = Path(self.depth_root) / sub
            seg_dir = Path(self.seg_root) / sub
            edge_dir = Path(self.edge_root) / sub
            if not image_dir.is_dir() or not depth_dir.is_dir() or not seg_dir.is_dir() or not edge_dir.is_dir():
                continue
            for cap in sorted(image_dir.glob("*.txt")):
                stem = cap.stem
                image_path = find_with_exts(image_dir, stem, IMAGE_EXTS)
                depth_path = find_with_exts(depth_dir, stem, DEPTH_EXTS)
                seg_path = find_with_exts(seg_dir, stem, SEG_EXTS)
                edge_path = find_with_exts(edge_dir, stem, EDGE_EXTS)
                if image_path and depth_path and seg_path and edge_path:
                    self.samples.append(
                        {
                            "stem": stem,
                            "image_path": image_path,
                            "caption_path": str(cap),
                            "depth_path": depth_path,
                            "seg_path": seg_path,
                            "edge_path": edge_path,
                        }
                    )

    def __len__(self):
        return len(self.samples)

    def _build_item(self, idx: int):
        """Load and preprocess sample *idx*; raises on any I/O or decode error."""
        sample = self.samples[idx]
        with Image.open(sample["image_path"]) as im:
            pil = im.convert("RGB")
        pil = self.center_crop(self.resize(pil))
        image_01 = to_tensor(pil)
        image_m11 = self.normalize(image_01)
        depth = load_depth_to_tensor(sample["depth_path"], self.resolution)
        seg = load_seg_to_tensor(sample["seg_path"], self.resolution, normalize=self.seg_normalize)
        edge = load_edge_to_tensor(sample["edge_path"], self.resolution)
        control = torch.cat([depth, seg, edge], dim=0)
        return {
            "image": image_m11,
            "caption": read_caption(sample["caption_path"]),
            "control": control,
            "control_keep": torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32),
            "control_mode": "depth_seg_edge",
            "depth": depth,
            "seg": seg,
            "edge": edge,
            **sample,
        }

    def __getitem__(self, idx: int):
        """Return sample *idx*, substituting random samples if it fails to load.

        Raises:
            IndexError: *idx* is outside ``[-len(self), len(self))``. This is
                checked up front so out-of-range access (and plain
                ``for x in dataset`` iteration) is not masked by the retry logic.
            RuntimeError: no loadable sample within ``max_retries`` attempts;
                the last load error is chained as ``__cause__``.
        """
        cur = int(idx)
        n = len(self.samples)
        if not -n <= cur < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        last_exc = None
        for _ in range(self.max_retries):
            try:
                return self._build_item(cur)
            except Exception as exc:
                last_exc = exc
                nxt = random.randint(0, n - 1)
                print(f"[PixelThreeControlDataset] bad sample idx={cur}: {exc!r}; retry idx={nxt}")
                cur = nxt
        raise RuntimeError(f"failed to load valid sample after {self.max_retries} retries") from last_exc
|
|
|
|
class PixelSingleControlDataset(Dataset):
    """Single-control depth/seg/edge dataset for baseline training.

    Pairs each ``<image_root>/<subdir>/<stem>.txt`` caption with the RGB image
    of the same stem and one control map of type ``control_type`` under
    ``<control_root>/<subdir>``. There is no index cache and no retry: a sample
    that fails to load raises from ``__getitem__``.

    Each item holds ``image`` (3, R, R) in [-1, 1], ``control`` (1, R, R), the
    same map again under the key ``control_type``, the caption, and the raw
    index entry (stem and file paths).
    """

    def __init__(
        self,
        image_root: str,
        control_root: str,
        control_type: str,
        resolution: int = 512,
        subdirs: Optional[Sequence[str]] = None,
        seg_normalize: bool = True,
    ):
        if control_type not in {"depth", "seg", "edge"}:
            raise ValueError("control_type must be depth, seg, or edge")
        self.image_root = image_root
        self.control_root = control_root
        self.control_type = control_type
        self.resolution = int(resolution)
        self.subdirs = list(subdirs) if subdirs is not None else None
        self.seg_normalize = bool(seg_normalize)
        self.samples: list[dict] = []
        self._build_index()
        self.resize = Resize(self.resolution)
        self.center_crop = CenterCrop(self.resolution)
        self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def _iter_subdirs(self):
        """Return the explicit ``subdirs`` or, if unset, every directory under image_root (sorted)."""
        if self.subdirs is not None:
            return self.subdirs
        return sorted(p.name for p in Path(self.image_root).iterdir() if p.is_dir())

    def _find_control(self, control_dir: Path, stem: str):
        """Return the control file for *stem*, probing the suffix list of ``control_type``."""
        if self.control_type == "depth":
            return find_with_exts(control_dir, stem, DEPTH_EXTS)
        if self.control_type == "seg":
            return find_with_exts(control_dir, stem, SEG_EXTS)
        return find_with_exts(control_dir, stem, EDGE_EXTS)

    def _build_index(self):
        """Append one entry to ``self.samples`` per caption that has an image and a control file."""
        for sub in self._iter_subdirs():
            image_dir = Path(self.image_root) / sub
            control_dir = Path(self.control_root) / sub
            if not image_dir.is_dir() or not control_dir.is_dir():
                continue
            for cap in sorted(image_dir.glob("*.txt")):
                stem = cap.stem
                image_path = find_with_exts(image_dir, stem, IMAGE_EXTS)
                control_path = self._find_control(control_dir, stem)
                if image_path and control_path:
                    self.samples.append(
                        {"stem": stem, "image_path": image_path, "caption_path": str(cap), "control_path": control_path}
                    )

    def __len__(self):
        return len(self.samples)

    def _load_control(self, path: str):
        """Load the control map at *path* with the loader matching ``control_type``."""
        if self.control_type == "depth":
            return load_depth_to_tensor(path, self.resolution)
        if self.control_type == "seg":
            return load_seg_to_tensor(path, self.resolution, normalize=self.seg_normalize)
        return load_edge_to_tensor(path, self.resolution)

    def __getitem__(self, idx: int):
        sample = self.samples[int(idx)]
        # Context manager releases the source file handle deterministically,
        # matching how the loader functions in this module open images.
        with Image.open(sample["image_path"]) as im:
            pil = im.convert("RGB")
        pil = self.center_crop(self.resize(pil))
        image_m11 = self.normalize(to_tensor(pil))
        control = self._load_control(sample["control_path"])
        return {
            "image": image_m11,
            "caption": read_caption(sample["caption_path"]),
            "control": control,
            "control_keep": torch.tensor([1.0], dtype=torch.float32),
            "control_mode": self.control_type,
            self.control_type: control,
            **sample,
        }
|
|