# linxin02's picture
# Open-source PixelControl code (relative paths, identity scrubbed)
# 497c818 verified
# Raw
# History Blame Contribute Delete
# 12 kB
"""Minimal extracted datasets for single-control and three-control training."""
from __future__ import annotations
import json
import os
import random
from pathlib import Path
from typing import Optional, Sequence
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision.transforms import CenterCrop, Normalize, Resize
from torchvision.transforms.functional import to_tensor
# Explicitly keep Pillow's truncated-image loading switched off.
ImageFile.LOAD_TRUNCATED_IMAGES = False
# Suffix tuples consumed by `find_with_exts`: each suffix is appended to a
# sample stem and tried in the listed order, so earlier entries take priority.
# RGB image suffixes.
IMAGE_EXTS = (".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG")
# Depth-map suffixes.
DEPTH_EXTS = (".depth.npy", ".npy", ".depth.png", ".png", ".depth.jpg", ".jpg", ".depth.jpeg", ".jpeg")
# Segmentation label-map suffixes.
SEG_EXTS = (".sam2_label.npy", ".sam2_label.png", ".png", ".npy")
# Edge-map suffixes.
EDGE_EXTS = (".edge.npy", ".npy", ".edge.png", ".png", ".edge.jpg", ".jpg", ".edge.jpeg", ".jpeg")
def strip_image_ext(filename: str) -> str:
    """Return ``filename`` with its image suffix removed.

    The first entry of ``IMAGE_EXTS`` that the name ends with is cut off. Names
    matching none of those suffixes fall back to ``os.path.splitext``.
    """
    matched = next((suffix for suffix in IMAGE_EXTS if filename.endswith(suffix)), None)
    if matched is None:
        return os.path.splitext(filename)[0]
    return filename[: len(filename) - len(matched)]
def find_with_exts(root: str | Path, stem: str, exts: Sequence[str]) -> str | None:
    """Locate ``stem`` inside ``root`` under the first suffix that exists.

    Args:
        root: Directory to look in.
        stem: File name without suffix.
        exts: Candidate suffixes, tried in the given order.

    Returns:
        The joined path (as ``str``) of the first existing candidate, or
        ``None`` when no candidate exists.
    """
    base = str(root)
    candidates = (os.path.join(base, stem + suffix) for suffix in exts)
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
def read_caption(path: str, default: str = "") -> str:
    """Read a caption file and return its stripped text.

    Args:
        path: Caption file path; may be empty or point to a missing file.
        default: Value returned when the path is empty, the file does not
            exist, or the file holds only whitespace.

    Returns:
        The stripped file contents, or ``default``. Undecodable bytes are
        dropped because the file is opened with ``errors="ignore"``.
    """
    if path and os.path.exists(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            caption = handle.read().strip()
        return caption if caption else default
    return default
def _resize_crop_1ch(x: np.ndarray, target_size: int, mode: str) -> torch.Tensor:
x_t = torch.from_numpy(x.astype(np.float32)).unsqueeze(0).unsqueeze(0)
h, w = x_t.shape[-2:]
short = min(h, w)
scale = float(target_size) / float(short)
new_h, new_w = int(round(h * scale)), int(round(w * scale))
x_t = F.interpolate(x_t, size=(new_h, new_w), mode=mode, align_corners=False if mode == "bilinear" else None)
top = (new_h - target_size) // 2
left = (new_w - target_size) // 2
return x_t[:, :, top:top + target_size, left:left + target_size].squeeze(0)
def load_depth_to_tensor(path: str, target_size: int, normalize: bool = True, invert_depth: bool = False) -> torch.Tensor:
    """Load a depth map from disk as a single-channel square tensor.

    Args:
        path: ``.npy`` or ``.npz`` array file, or any image Pillow can open.
            For ``.npz`` the first stored array is used.
        target_size: Side length of the square output.
        normalize: When true, min-max rescale the resized map to ``[0, 1]``.
        invert_depth: When true, replace the map with ``1 - map``.

    Returns:
        The resized, center-cropped depth map, clamped to ``[0, 1]``.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        depth = np.load(path).astype(np.float32)
    elif ext == ".npz":
        # np.load returns an NpzFile that holds the archive open; close it as
        # soon as the first array has been copied out via astype().
        with np.load(path) as archive:
            depth = archive[archive.files[0]].astype(np.float32)
    else:
        with Image.open(path) as im:
            # Keep integer precision for 16/32-bit images; everything else
            # is reduced to 8-bit grayscale.
            im = im.convert("I") if im.mode in ("I", "I;16") else im.convert("L")
            depth = np.asarray(im, dtype=np.float32)
    if depth.ndim == 3:
        # Collapse a trailing channel axis by averaging.
        depth = depth.mean(axis=-1)
    out = _resize_crop_1ch(depth, target_size, mode="bilinear")
    if normalize:
        lo, hi = out.min(), out.max()
        # clamp_min guards against division by zero on constant maps.
        out = (out - lo) / (hi - lo).clamp_min(1e-6)
    if invert_depth:
        out = 1.0 - out
    return out.clamp_(0.0, 1.0)
def load_seg_to_tensor(path: str, target_size: int, normalize: bool = True) -> torch.Tensor:
    """Load a segmentation label map as a single-channel square tensor.

    Args:
        path: ``.npy`` array file, or an image that is read as 8-bit grayscale.
        target_size: Side length of the square output.
        normalize: When true, divide by the maximum value (if positive) and
            clamp to ``[0, 1]``; when false, the resized values are returned
            unchanged.

    Returns:
        The label map resized with nearest-neighbour interpolation and
        center-cropped.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix != ".npy":
        with Image.open(path) as im:
            labels = np.asarray(im.convert("L"))
    else:
        labels = np.load(path)
    if labels.ndim == 3:
        # Only the first channel of a multi-channel map is used.
        labels = labels[..., 0]
    out = _resize_crop_1ch(labels.astype(np.float32), target_size, mode="nearest")
    if not normalize:
        return out
    peak = out.max()
    if peak.item() > 0:
        out = out / peak
    return out.clamp_(0.0, 1.0)
def load_edge_to_tensor(path: str, target_size: int) -> torch.Tensor:
    """Load an edge map as a single-channel square tensor in ``[0, 1]``.

    Args:
        path: ``.npy`` array file, or an image that is read as 8-bit grayscale.
        target_size: Side length of the square output.

    Returns:
        The bilinearly resized, center-cropped map after min-max rescaling.
    """
    is_array_file = os.path.splitext(path)[1].lower() == ".npy"
    if is_array_file:
        raw = np.load(path).astype(np.float32)
    else:
        with Image.open(path) as im:
            raw = np.asarray(im.convert("L"), dtype=np.float32)
    # Average away a trailing channel axis when one is present.
    flat = raw.mean(axis=-1) if raw.ndim == 3 else raw
    resized = _resize_crop_1ch(flat, target_size, mode="bilinear")
    floor = resized.min()
    ceil = resized.max()
    # clamp_min keeps the denominator non-zero for constant maps.
    span = (ceil - floor).clamp_min(1e-6)
    scaled = (resized - floor) / span
    return scaled.clamp_(0.0, 1.0)
def subdir_range(start: int, end: int) -> list[str]:
    """Return ``sa_NNNNNN`` names for every index from ``start`` to ``end`` inclusive.

    Both bounds are passed through ``int()``; an empty list results when
    ``end`` is below ``start``.
    """
    first, last = int(start), int(end)
    names = []
    for index in range(first, last + 1):
        names.append(f"sa_{index:06d}")
    return names
class PixelThreeControlDataset(Dataset):
    """Paired RGB/caption/depth/seg/edge dataset.

    Returns a dict ready for a PixelDiT-like training loop. The loop can sample
    active modes and zero inactive channels using `apply_multi_control_mode`.

    Samples are indexed by caption file: for each ``*.txt`` in an image
    subdirectory, the image, depth, seg and edge files sharing its stem must
    all exist, otherwise the stem is skipped. Subdirectories missing from any
    of the four roots are skipped entirely.
    """

    def __init__(
        self,
        image_root: str,
        depth_root: str,
        seg_root: str,
        edge_root: str,
        resolution: int = 512,
        subdirs: Optional[Sequence[str]] = None,
        cache_index_path: str | None = None,
        max_retries: int = 20,
        seg_normalize: bool = True,
        require_caption: bool = True,
    ):
        """Build the sample index, or load it from a JSON cache.

        Args:
            image_root: Root holding per-subdirectory images and captions.
            depth_root: Root holding per-subdirectory depth maps.
            seg_root: Root holding per-subdirectory segmentation maps.
            edge_root: Root holding per-subdirectory edge maps.
            resolution: Side length of the square outputs.
            subdirs: Subdirectory names to scan; ``None`` scans every
                directory found under ``image_root``.
            cache_index_path: Optional JSON file. If it exists the index is
                loaded from it; otherwise the freshly built index is written
                to it.
            max_retries: Number of load attempts made by ``__getitem__``.
            seg_normalize: Forwarded to ``load_seg_to_tensor``.
            require_caption: Stored on the instance; not read by any method
                of this class.
        """
        self.image_root = image_root
        self.depth_root = depth_root
        self.seg_root = seg_root
        self.edge_root = edge_root
        self.resolution = int(resolution)
        self.subdirs = list(subdirs) if subdirs is not None else None
        self.max_retries = int(max_retries)
        self.seg_normalize = bool(seg_normalize)
        self.require_caption = bool(require_caption)
        self.samples: list[dict] = []
        if cache_index_path and os.path.exists(cache_index_path):
            # Context manager so the cache handle is closed deterministically.
            with open(cache_index_path, "r", encoding="utf-8") as f:
                self.samples = json.load(f)
        else:
            self._build_index()
            if cache_index_path:
                Path(cache_index_path).parent.mkdir(parents=True, exist_ok=True)
                # Closing the handle guarantees the index is flushed to disk
                # before any other process tries to read the cache.
                with open(cache_index_path, "w", encoding="utf-8") as f:
                    json.dump(self.samples, f)
        self.resize = Resize(self.resolution)
        self.center_crop = CenterCrop(self.resolution)
        self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def _iter_subdirs(self):
        """Return the configured subdirs, or all directories under image_root, sorted."""
        if self.subdirs is not None:
            return self.subdirs
        return sorted(p.name for p in Path(self.image_root).iterdir() if p.is_dir())

    def _build_index(self):
        """Scan the roots and append one record per fully paired sample."""
        for sub in self._iter_subdirs():
            image_dir = Path(self.image_root) / sub
            depth_dir = Path(self.depth_root) / sub
            seg_dir = Path(self.seg_root) / sub
            edge_dir = Path(self.edge_root) / sub
            if not image_dir.is_dir() or not depth_dir.is_dir() or not seg_dir.is_dir() or not edge_dir.is_dir():
                continue
            # Caption files drive the scan; sorting keeps the index order stable.
            for cap in sorted(image_dir.glob("*.txt")):
                stem = cap.stem
                image_path = find_with_exts(image_dir, stem, IMAGE_EXTS)
                depth_path = find_with_exts(depth_dir, stem, DEPTH_EXTS)
                seg_path = find_with_exts(seg_dir, stem, SEG_EXTS)
                edge_path = find_with_exts(edge_dir, stem, EDGE_EXTS)
                if image_path and depth_path and seg_path and edge_path:
                    self.samples.append(
                        {
                            "stem": stem,
                            "image_path": image_path,
                            "caption_path": str(cap),
                            "depth_path": depth_path,
                            "seg_path": seg_path,
                            "edge_path": edge_path,
                        }
                    )

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def _build_item(self, idx: int):
        """Load and assemble the training dict for sample ``idx``.

        The image is resized, center-cropped and normalised with mean and std
        0.5. ``control`` concatenates depth, seg and edge along dim 0, in that
        order. The sample's index record is merged into the returned dict.
        """
        sample = self.samples[idx]
        # Close the file handle even when decoding fails; __getitem__ retries
        # on errors, so unclosed handles would otherwise pile up.
        with Image.open(sample["image_path"]) as im:
            pil = im.convert("RGB")
        pil = self.center_crop(self.resize(pil))
        image_01 = to_tensor(pil)
        image_m11 = self.normalize(image_01)
        depth = load_depth_to_tensor(sample["depth_path"], self.resolution)
        seg = load_seg_to_tensor(sample["seg_path"], self.resolution, normalize=self.seg_normalize)
        edge = load_edge_to_tensor(sample["edge_path"], self.resolution)
        control = torch.cat([depth, seg, edge], dim=0)
        return {
            "image": image_m11,
            "caption": read_caption(sample["caption_path"]),
            "control": control,
            "control_keep": torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32),
            "control_mode": "depth_seg_edge",
            "depth": depth,
            "seg": seg,
            "edge": edge,
            **sample,
        }

    def __getitem__(self, idx: int):
        """Return the item at ``idx``, substituting random samples on failure.

        Any exception while building an item is printed and a random index is
        tried instead, for at most ``max_retries`` attempts in total.

        Raises:
            RuntimeError: If every attempt fails.
        """
        cur = int(idx)
        for _ in range(self.max_retries):
            try:
                return self._build_item(cur)
            except Exception as exc:
                # Deliberately broad: one corrupt file must not stop training.
                nxt = random.randint(0, len(self.samples) - 1)
                print(f"[PixelThreeControlDataset] bad sample idx={cur}: {exc!r}; retry idx={nxt}")
                cur = nxt
        raise RuntimeError(f"failed to load valid sample after {self.max_retries} retries")
class PixelSingleControlDataset(Dataset):
    """Single-control depth/seg/edge dataset for baseline training.

    Samples are indexed by caption file: for each ``*.txt`` in an image
    subdirectory, an image and a control map sharing its stem must both exist.
    """

    def __init__(
        self,
        image_root: str,
        control_root: str,
        control_type: str,
        resolution: int = 512,
        subdirs: Optional[Sequence[str]] = None,
        seg_normalize: bool = True,
    ):
        """Validate the control type and build the sample index.

        Args:
            image_root: Root holding per-subdirectory images and captions.
            control_root: Root holding per-subdirectory control maps.
            control_type: One of ``"depth"``, ``"seg"`` or ``"edge"``.
            resolution: Side length of the square outputs.
            subdirs: Subdirectory names to scan; ``None`` scans every
                directory found under ``image_root``.
            seg_normalize: Forwarded to ``load_seg_to_tensor`` for seg maps.

        Raises:
            ValueError: If ``control_type`` is not a supported value.
        """
        if control_type not in {"depth", "seg", "edge"}:
            raise ValueError("control_type must be depth, seg, or edge")
        self.image_root = image_root
        self.control_root = control_root
        self.control_type = control_type
        self.resolution = int(resolution)
        self.subdirs = list(subdirs) if subdirs is not None else None
        self.seg_normalize = bool(seg_normalize)
        self.samples: list[dict] = []
        self._build_index()
        self.resize = Resize(self.resolution)
        self.center_crop = CenterCrop(self.resolution)
        self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def _iter_subdirs(self):
        """Return the configured subdirs, or all directories under image_root, sorted."""
        if self.subdirs is not None:
            return self.subdirs
        return sorted(p.name for p in Path(self.image_root).iterdir() if p.is_dir())

    def _find_control(self, control_dir: Path, stem: str):
        """Locate the control file for ``stem`` using the suffix list of the control type."""
        if self.control_type == "depth":
            return find_with_exts(control_dir, stem, DEPTH_EXTS)
        if self.control_type == "seg":
            return find_with_exts(control_dir, stem, SEG_EXTS)
        return find_with_exts(control_dir, stem, EDGE_EXTS)

    def _build_index(self):
        """Scan the roots and append one record per image/control pair."""
        for sub in self._iter_subdirs():
            image_dir = Path(self.image_root) / sub
            control_dir = Path(self.control_root) / sub
            if not image_dir.is_dir() or not control_dir.is_dir():
                continue
            # Caption files drive the scan; sorting keeps the index order stable.
            for cap in sorted(image_dir.glob("*.txt")):
                stem = cap.stem
                image_path = find_with_exts(image_dir, stem, IMAGE_EXTS)
                control_path = self._find_control(control_dir, stem)
                if image_path and control_path:
                    self.samples.append(
                        {"stem": stem, "image_path": image_path, "caption_path": str(cap), "control_path": control_path}
                    )

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def _load_control(self, path: str):
        """Load the control map at ``path`` with the loader for the control type."""
        if self.control_type == "depth":
            return load_depth_to_tensor(path, self.resolution)
        if self.control_type == "seg":
            return load_seg_to_tensor(path, self.resolution, normalize=self.seg_normalize)
        return load_edge_to_tensor(path, self.resolution)

    def __getitem__(self, idx: int):
        """Load and assemble the training dict for sample ``idx``.

        The control tensor is returned under ``"control"`` and again under a
        key named after the control type. The sample's index record is merged
        into the returned dict. Loading errors propagate to the caller.
        """
        sample = self.samples[int(idx)]
        # Close the file handle deterministically, including when decoding raises.
        with Image.open(sample["image_path"]) as im:
            pil = im.convert("RGB")
        pil = self.center_crop(self.resize(pil))
        image_m11 = self.normalize(to_tensor(pil))
        control = self._load_control(sample["control_path"])
        return {
            "image": image_m11,
            "caption": read_caption(sample["caption_path"]),
            "control": control,
            "control_keep": torch.tensor([1.0], dtype=torch.float32),
            "control_mode": self.control_type,
            self.control_type: control,
            **sample,
        }