# PAINE — predictor/training/dataloader.py
# first commit (03de09d) by joonghk
import json
import random
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from predictor.configs.model_dims import MODEL_DIMS, get_dims
AVAILABLE_TARGETS = ['hpsv2', 'image_reward', 'pick_score', 'clip_score']
# Model-specific embedding extraction config
# key: which key in embeds/*.pt to load
# mask_key: which key for attention mask (None = all ones)
EMBEDDING_CONFIG = {
'sdxl': {
'key': 'prompt_embeds',
'mask_key': None,
},
'dreamshaper': {
'key': 'prompt_embeds',
'mask_key': None,
},
'hunyuan_dit': {
'key': 'prompt_embeds_2',
'mask_key': 'prompt_attention_mask_2',
},
'pixart_sigma': {
'key': 'prompt_embeds',
'mask_key': 'prompt_attention_mask',
},
'sana_sprint': {
'key': 'prompt_embeds',
'mask_key': 'prompt_attention_mask',
},
}
# Override embedding keys when using a non-default text encoder.
def _extract_embeds(
embeddings: dict,
model_type: str,
embed_dim: int,
seq_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
config = EMBEDDING_CONFIG[model_type]
embeds = embeddings[config['key']]
if embeds.dim() == 3:
embeds = embeds.squeeze(0)
if config['mask_key'] is not None and config['mask_key'] in embeddings:
mask = embeddings[config['mask_key']]
if mask.dim() == 2:
mask = mask.squeeze(0)
else:
mask = torch.ones(embeds.shape[0], dtype=torch.long)
current_len = embeds.shape[0]
if current_len < seq_len:
pad_embeds = torch.zeros(seq_len - current_len, embeds.shape[1])
embeds = torch.cat([embeds, pad_embeds], dim=0)
pad_mask = torch.zeros(seq_len - current_len, dtype=mask.dtype)
mask = torch.cat([mask, pad_mask], dim=0)
elif current_len > seq_len:
embeds = embeds[:seq_len]
mask = mask[:seq_len]
return embeds, mask
def _load_all_metadata(data_dir: str) -> List[dict]:
data_path = Path(data_dir)
records = []
meta_files = sorted(data_path.glob("metadata*.jsonl"))
if not meta_files:
raise FileNotFoundError(f"No metadata*.jsonl found in {data_dir}")
seen = set()
for meta_file in meta_files:
with open(meta_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
record = json.loads(line)
key = (record['prompt_id'], record['sample_idx'])
if key not in seen:
seen.add(key)
records.append(record)
return records
class PNMDataset(Dataset):
    """Fully in-memory dataset pairing cached noise tensors with quality scores.

    Every sample joins a pre-generated noise tensor with its prompt's text
    embeddings (shared across the prompt's samples) and the chosen target
    metric, standardized with the supplied mean/std. All tensors are loaded
    into RAM at construction so __getitem__ never touches disk.
    """

    def __init__(
        self,
        data_dir: str,
        samples: List[dict],
        model_type: str,
        target: str = 'hpsv2',
        y_mean: float = 0.0,
        y_std: float = 1.0,
    ):
        self.data_dir = Path(data_dir)
        self.model_type = model_type
        self.target = target
        self.y_mean = y_mean
        self.y_std = y_std
        self.samples = samples

        model_dims = get_dims(model_type)
        self.embed_dim = model_dims['embed_dim']
        self.seq_len = model_dims['seq_len']

        # One (embeds, mask) pair per unique prompt, shared by all its samples.
        self._embed_cache = {}
        for pid in sorted({rec['prompt_id'] for rec in samples}):
            saved = torch.load(
                self.data_dir / "embeds" / f"p{pid:04d}.pt",
                map_location='cpu', weights_only=False,
            )
            self._embed_cache[pid] = _extract_embeds(
                saved, self.model_type, self.embed_dim, self.seq_len,
            )

        # Preload all noise tensors into RAM to eliminate disk I/O during training.
        self._noise_cache = {}
        for rec in samples:
            pid, sid = rec['prompt_id'], rec['sample_idx']
            loaded = torch.load(
                self.data_dir / "noise" / f"p{pid:04d}_s{sid:02d}.pt",
                map_location='cpu', weights_only=False,
            )
            # Drop a leading batch dim of 1 when present.
            self._noise_cache[(pid, sid)] = (
                loaded.squeeze(0) if loaded.dim() == 4 else loaded
            )

    def __len__(self):
        return len(self.samples)

    def _get_embeddings(self, prompt_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # (embeds, mask) were cached at construction time.
        return self._embed_cache[prompt_id]

    def __getitem__(self, idx):
        rec = self.samples[idx]
        pid, sid = rec['prompt_id'], rec['sample_idx']
        embeds, mask = self._get_embeddings(pid)
        score = float(rec.get(self.target, 0.0))
        # Standardize, flooring std to avoid division by ~0.
        z = (score - self.y_mean) / max(self.y_std, 1e-8)
        return {
            'noise': self._noise_cache[(pid, sid)].float(),
            'prompt_embeds': embeds.float(),
            'prompt_mask': mask.float(),
            'prompt_id': pid,
            'y': torch.tensor(z, dtype=torch.float32),
            'raw_y': torch.tensor(score, dtype=torch.float32),
        }
class PromptGroupedBatchSampler(torch.utils.data.Sampler):
    """Batch sampler that groups whole prompts: each batch contains every
    sample of k_prompts_per_batch prompts, enabling within-prompt comparisons.

    Incomplete trailing batches (fewer than k prompts) are dropped.
    """

    # NOTE: the dataset annotation is a string forward reference so this class
    # does not depend on PNMDataset being defined first; any object with a
    # `.samples` list of dicts containing 'prompt_id' works.
    def __init__(self, dataset: 'PNMDataset', k_prompts_per_batch: int, shuffle: bool = True):
        """Index dataset.samples by prompt_id and report batching stats.

        Raises:
            ValueError: if the dataset has no samples (the original code
            crashed here with ZeroDivisionError/IndexError).
        """
        self.shuffle = shuffle
        self.k = k_prompts_per_batch
        # Build prompt_id -> [dataset indices] mapping (insertion-ordered).
        self.prompt_to_indices: Dict[int, List[int]] = {}
        for idx, record in enumerate(dataset.samples):
            self.prompt_to_indices.setdefault(record['prompt_id'], []).append(idx)
        self.prompt_ids = list(self.prompt_to_indices.keys())
        if not self.prompt_ids:
            raise ValueError("PromptGroupedBatchSampler: dataset has no samples")
        samples_per_prompt = [len(v) for v in self.prompt_to_indices.values()]
        # batch_size estimate assumes a uniform sample count per prompt
        # (uses the first prompt's count, as before).
        print(f" PromptGroupedBatchSampler: {len(self.prompt_ids)} prompts, "
              f"~{sum(samples_per_prompt) / len(samples_per_prompt):.0f} samples/prompt, "
              f"k={self.k}, batch_size={self.k * samples_per_prompt[0]}")

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of dataset indices, one list per batch of k prompts."""
        prompt_ids = self.prompt_ids.copy()
        if self.shuffle:
            random.shuffle(prompt_ids)
        # Yield complete batches only (drop_last behavior).
        for i in range(0, len(prompt_ids) - self.k + 1, self.k):
            batch_prompts = prompt_ids[i:i + self.k]
            batch_indices = []
            for pid in batch_prompts:
                batch_indices.extend(self.prompt_to_indices[pid])
            yield batch_indices

    def __len__(self) -> int:
        """Number of complete batches (floor division, matching __iter__)."""
        return len(self.prompt_ids) // self.k
def prep_dataloaders(
    data_dir: str,
    model_type: str,
    target: str = 'hpsv2',
    split_by: str = 'prompt',
    batch_size: int = 256,
    num_workers: int = 4,
    seed: int = 42,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    k_prompts_per_batch: int = 0,
    max_prompts: int = -1,
) -> Tuple[DataLoader, DataLoader, DataLoader, Dict]:
    """Build train/val/test DataLoaders over the cached noise/score dataset.

    Unique prompts (not individual samples) are shuffled with `seed` and split
    by `train_ratio` / `val_ratio` (remainder -> test) so no prompt leaks
    across splits. Target values are standardized with the train split's
    mean/std; those stats are returned for later denormalization.

    Args:
        data_dir: directory containing metadata*.jsonl, embeds/, noise/.
        model_type: key into the model dimension table (see get_dims).
        target: metric name to regress (see AVAILABLE_TARGETS).
        split_by: only prompt-level splitting is implemented; the parameter
            is kept for interface compatibility.
        batch_size: loader batch size when not using grouped batching.
        num_workers: DataLoader worker processes.
        seed: RNG seed for the prompt-level split.
        train_ratio / val_ratio: split fractions over unique prompts.
        k_prompts_per_batch: if > 0, train with PromptGroupedBatchSampler.
        max_prompts: if > 0, keep only the first max_prompts prompt ids.

    Returns:
        (train_loader, val_loader, test_loader, stats) where stats has keys
        'target', 'y_mean', 'y_std'.

    Raises:
        ValueError: if no samples are found, or the train split is empty.
    """
    all_records = _load_all_metadata(data_dir)
    print(f"Loaded {len(all_records)} samples from metadata")

    records_by_prompt = {}
    for rec in all_records:
        records_by_prompt.setdefault(rec['prompt_id'], []).append(rec)
    all_prompt_ids = sorted(records_by_prompt.keys())
    if not all_prompt_ids:
        raise ValueError(f"No samples found in {data_dir}")
    print(f"Found {len(all_prompt_ids)} unique prompts")

    # Limit number of prompts if requested
    if max_prompts > 0 and max_prompts < len(all_prompt_ids):
        all_prompt_ids = all_prompt_ids[:max_prompts]
        # FIX: build the membership set once — the original evaluated
        # set(all_prompt_ids) inside the comprehension condition for every
        # record, i.e. O(records * prompts).
        kept = set(all_prompt_ids)
        all_records = [r for r in all_records if r['prompt_id'] in kept]
        print(f"Using {len(all_prompt_ids)} prompts ({len(all_records)} samples)")

    # Deterministic prompt-level split.
    rng = random.Random(seed)
    shuffled_ids = all_prompt_ids.copy()
    rng.shuffle(shuffled_ids)
    n = len(shuffled_ids)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    train_ids = set(shuffled_ids[:n_train])
    val_ids = set(shuffled_ids[n_train:n_train + n_val])
    test_ids = set(shuffled_ids[n_train + n_val:])

    train_records = [r for r in all_records if r['prompt_id'] in train_ids]
    val_records = [r for r in all_records if r['prompt_id'] in val_ids]
    test_records = [r for r in all_records if r['prompt_id'] in test_ids]
    print(f"Split: {len(train_ids)} train / {len(val_ids)} val / {len(test_ids)} test prompts")
    print(f"Samples: {len(train_records)} train / {len(val_records)} val / {len(test_records)} test")

    # Normalization stats come from the train split only.
    # FIX: fail loudly instead of producing a NaN mean (with a NumPy warning)
    # when the split leaves no training records.
    if not train_records:
        raise ValueError("Train split is empty; increase train_ratio or dataset size")
    vals = np.array([float(r.get(target, 0.0)) for r in train_records])
    y_mean = float(vals.mean())
    y_std = float(vals.std())
    print(f" {target}: mean={vals.mean():.6f}, std={vals.std():.6f}, n={len(vals)}")
    stats = {
        'target': target,
        'y_mean': y_mean,
        'y_std': y_std,
    }

    common_kwargs = dict(
        data_dir=data_dir,
        model_type=model_type,
        target=target,
        y_mean=y_mean,
        y_std=y_std,
    )
    train_ds = PNMDataset(samples=train_records, **common_kwargs)
    val_ds = PNMDataset(samples=val_records, **common_kwargs)
    test_ds = PNMDataset(samples=test_records, **common_kwargs)

    if k_prompts_per_batch > 0:
        # Grouped batching: K prompts per batch, all their samples together
        grouped_sampler = PromptGroupedBatchSampler(train_ds, k_prompts_per_batch, shuffle=True)
        train_loader = DataLoader(
            train_ds,
            batch_sampler=grouped_sampler,
            num_workers=num_workers,
            pin_memory=True,
        )
    else:
        # Standard global shuffle
        train_loader = DataLoader(
            train_ds, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True, drop_last=True,
        )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )
    return train_loader, val_loader, test_loader, stats
def denormalize(pred: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    """Map a standardized prediction back onto the raw score scale."""
    return mean + std * pred