"""
Multi-model image generation with predictor-selected noise.
Supports: SDXL, DreamShaper, Hunyuan-DiT, PixArt-Sigma, SANA-Sprint.
Usage:
python -m predictor.inference.generate --model_type sdxl --checkpoint experiments/sdxl_best/best_model.pth --prompt "A cat"
python -m predictor.inference.generate --model_type hunyuan_dit --checkpoint experiments/hunyuan_best/best_model.pth --prompt "A sunset"
"""
import argparse
from pathlib import Path
import torch
from predictor.inference.loader import load_predictor
from predictor.inference.noise_selection import generate_noise_candidates, select_top_k_noise
from predictor.configs.model_dims import MODEL_DIMS, get_dims
# Pipeline loading registry
PIPELINE_CONFIG = {
'sdxl': {
'class': 'StableDiffusionXLPipeline',
'pretrained': 'stabilityai/stable-diffusion-xl-base-1.0',
},
'dreamshaper': {
'class': 'StableDiffusionXLPipeline',
'pretrained': 'Lykon/dreamshaper-xl-v2-turbo',
},
'hunyuan_dit': {
'class': 'HunyuanDiTPipeline',
'pretrained': 'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers',
},
'pixart_sigma': {
'class': 'PixArtSigmaPipeline',
'pretrained': 'PixArt-alpha/PixArt-Sigma-XL-2-1024-MS',
},
'sana_sprint': {
'class': 'SanaSprintPipeline',
'pretrained': 'Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers',
'dtype': torch.bfloat16,
},
}
def load_pipeline(model_type: str, device: str = 'cuda', dtype=torch.float16):
import diffusers
config = PIPELINE_CONFIG[model_type]
pipe_class = getattr(diffusers, config['class'])
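    # Some models override the default dtype via PIPELINE_CONFIG
    # (e.g. SANA-Sprint ships bfloat16 weights).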
pipe_dtype = config.get('dtype', dtype)
pipe = pipe_class.from_pretrained(
config['pretrained'],
torch_dtype=pipe_dtype,
).to(device)
return pipe
def encode_prompt_for_model(pipe, prompt: str, model_type: str, device: str = 'cuda'):
if model_type in ('sdxl', 'dreamshaper'):
# SDXL-family: encode_prompt returns (embeds, neg_embeds, pooled, neg_pooled)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
prompt=prompt,
prompt_2=prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
)
# For predictor: use positive embeddings only
pred_embeds = prompt_embeds # [1, 77, 2048]
pred_mask = torch.ones(pred_embeds.shape[:2], device=device, dtype=torch.long)
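        # SDXL's encode_prompt returns no attention mask; the CLIP tokenizer pads
        # to a fixed 77 tokens, so an all-ones mask is used for the predictor.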
gen_kwargs = {
'prompt_embeds': prompt_embeds,
'negative_prompt_embeds': negative_prompt_embeds,
'pooled_prompt_embeds': pooled_prompt_embeds,
'negative_pooled_prompt_embeds': negative_pooled_prompt_embeds,
}
elif model_type == 'hunyuan_dit':
        # Hunyuan-DiT: encode_prompt's return signature varies across diffusers
        # versions: 8 values (CLIP emb, neg, T5 emb, neg, plus the four attention
        # masks), 4 values (CLIP emb, neg, T5 emb, neg) without masks, or
        # 4 values covering CLIP only (emb, neg_emb, mask, neg_mask) with no T5.
result = pipe.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
)
# Detect version by checking result[2] dtype:
# T5 embeds would be float (3D tensor), CLIP mask would be int/long (2D tensor)
has_t5_from_encode = (len(result) >= 4 and result[2].dtype in (torch.float16, torch.float32, torch.bfloat16))
if has_t5_from_encode:
# Newer diffusers: encode_prompt returns both CLIP + T5
prompt_embeds, negative_prompt_embeds = result[0], result[1]
prompt_embeds_2, negative_prompt_embeds_2 = result[2], result[3]
if len(result) >= 8:
prompt_attention_mask = result[4]
negative_prompt_attention_mask = result[5]
prompt_attention_mask_2 = result[6]
negative_prompt_attention_mask_2 = result[7]
else:
prompt_attention_mask = torch.ones(prompt_embeds.shape[:2], device=device, dtype=torch.long)
negative_prompt_attention_mask = torch.ones(negative_prompt_embeds.shape[:2], device=device, dtype=torch.long)
prompt_attention_mask_2 = torch.ones(prompt_embeds_2.shape[:2], device=device, dtype=torch.long)
negative_prompt_attention_mask_2 = torch.ones(negative_prompt_embeds_2.shape[:2], device=device, dtype=torch.long)
# Use T5 embeddings for predictor (prompt_embeds_2)
pred_embeds = prompt_embeds_2 # [1, 256, 2048]
pred_mask = prompt_attention_mask_2
gen_kwargs = {
'prompt_embeds': prompt_embeds,
'negative_prompt_embeds': negative_prompt_embeds,
'prompt_embeds_2': prompt_embeds_2,
'negative_prompt_embeds_2': negative_prompt_embeds_2,
'prompt_attention_mask': prompt_attention_mask,
'negative_prompt_attention_mask': negative_prompt_attention_mask,
'prompt_attention_mask_2': prompt_attention_mask_2,
'negative_prompt_attention_mask_2': negative_prompt_attention_mask_2,
}
else:
            # CLIP-only variant: encode_prompt returned no T5 embeddings,
            # so encode T5 manually for predictor scoring.
max_seq_len = get_dims(model_type)['seq_len'] # 256
tokens = pipe.tokenizer_2(
prompt,
max_length=max_seq_len,
padding='max_length',
truncation=True,
return_tensors='pt',
).to(device)
with torch.no_grad():
t5_output = pipe.text_encoder_2(
tokens.input_ids,
attention_mask=tokens.attention_mask,
)
pred_embeds = t5_output[0].to(dtype=torch.float16) # [1, 256, 2048]
pred_mask = tokens.attention_mask # [1, 256]
            # Leave gen_kwargs empty; main() passes the raw prompt so the
            # pipeline re-encodes it internally during generation.
gen_kwargs = {}
elif model_type == 'pixart_sigma':
# PixArt: encode_prompt returns (embeds, mask, neg_embeds, neg_mask)
max_seq_len = get_dims(model_type)['seq_len']
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(
prompt=prompt,
do_classifier_free_guidance=True,
num_images_per_prompt=1,
device=device,
clean_caption=False,
max_sequence_length=max_seq_len,
)
pred_embeds = prompt_embeds # [1, seq_len, embed_dim]
pred_mask = prompt_attention_mask if prompt_attention_mask is not None else \
torch.ones(pred_embeds.shape[:2], device=device, dtype=torch.long)
gen_kwargs = {
'prompt_embeds': prompt_embeds,
'prompt_attention_mask': prompt_attention_mask,
'negative_prompt_embeds': negative_prompt_embeds,
'negative_prompt_attention_mask': negative_prompt_attention_mask,
}
elif model_type == 'sana_sprint':
# SANA-Sprint: encode_prompt returns (prompt_embeds, prompt_attention_mask)
prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
prompt=prompt,
device=device,
max_sequence_length=get_dims(model_type)['seq_len'],
)
pred_embeds = prompt_embeds # [1, 300, 2304]
pred_mask = prompt_attention_mask # [1, 300]
gen_kwargs = {
'prompt_embeds': prompt_embeds,
'prompt_attention_mask': prompt_attention_mask,
}
else:
raise ValueError(f"Unknown model_type: {model_type}")
return pred_embeds, pred_mask, gen_kwargs
def main():
parser = argparse.ArgumentParser(
description="Generate images with multi-model predictor-selected noise"
)
parser.add_argument("--model_type", type=str, required=True,
choices=list(MODEL_DIMS.keys()), help="Diffusion model type")
parser.add_argument("--checkpoint", type=str, required=True,
help="Path to predictor checkpoint (.pth)")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--N", type=int, default=100, help="Number of noise candidates")
parser.add_argument("--B", type=int, default=4, help="Number of images to generate")
parser.add_argument("--head", type=int, default=0,
help="Prediction head (0=hpsv2, 1=image_reward, 2=clip_score)")
parser.add_argument("--steps", type=int, default=20, help="Inference steps")
parser.add_argument("--guidance-scale", type=float, default=4.5, help="CFG scale")
parser.add_argument("--seed", type=int, default=None, help="Random seed")
parser.add_argument("--output-dir", type=str, default="output", help="Output directory")
parser.add_argument("--compare", action="store_true",
help="Also generate B random baseline images")
parser.add_argument("--device", type=str, default="cuda", help="Device")
args = parser.parse_args()
dims = get_dims(args.model_type)
latent_shape = dims['latent_shape']
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"{'='*60}")
print(f"PNM Multi-Model Image Generation")
print(f"{'='*60}")
print(f" Model type: {args.model_type}")
print(f" Latent: {latent_shape}")
print(f" Prompt: {args.prompt}")
print(f" N: {args.N} noise candidates")
print(f" B: {args.B} images")
print(f" Checkpoint: {args.checkpoint}")
print(f" Output: {output_dir}")
print(f"{'='*60}")
# Load pipeline
print(f"\nLoading {args.model_type} pipeline...")
pipe = load_pipeline(args.model_type, device=args.device)
# Load predictor
print(f"Loading predictor from {args.checkpoint}...")
predictor, norm_info = load_predictor(args.checkpoint, device=args.device)
print(f" num_heads={predictor.num_heads}")
# Encode prompt
pred_embeds, pred_mask, gen_kwargs = encode_prompt_for_model(
pipe, args.prompt, args.model_type, args.device
)
# Generate and select noise
    generator = torch.Generator(device=args.device).manual_seed(args.seed) if args.seed is not None else None
noises = generate_noise_candidates(
num_candidates=args.N,
latent_shape=latent_shape,
device=args.device,
dtype=pipe.unet.dtype if hasattr(pipe, 'unet') else pipe.transformer.dtype,
generator=generator,
)
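    # noises is assumed to be a stacked [N, *latent_shape] tensor in the
    # denoiser's dtype (see generate_noise_candidates).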
selected = select_top_k_noise(
predictor=predictor,
noises=noises,
prompt_embeds=pred_embeds,
prompt_mask=pred_mask,
num_select=args.B,
head_index=args.head,
)
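    # selected: the B candidates with the highest predicted score on the chosen head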
# Don't pre-scale — pipeline's prepare_latents applies init_noise_sigma internally
latents = selected
# Expand embeddings for B images
B = args.B
expanded_kwargs = {}
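    # expand() broadcasts the batch dim as a view without copying memory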
for k, v in gen_kwargs.items():
if isinstance(v, torch.Tensor) and v.dim() >= 2:
expanded_kwargs[k] = v.expand(B, *[-1] * (v.dim() - 1))
else:
expanded_kwargs[k] = v
# Generate images
print(f"\nGenerating {B} images from top-{B} of {args.N} candidates...")
    result = pipe(
        # Pass the raw prompt only when encode_prompt_for_model returned no
        # embeddings (the CLIP-only Hunyuan path); batch it to match the latents.
        prompt=[args.prompt] * B if not expanded_kwargs else None,
**expanded_kwargs,
latents=latents,
num_images_per_prompt=1,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
)
for i, img in enumerate(result.images):
path = output_dir / f"{args.model_type}_predictor_{i:02d}.png"
img.save(path)
print(f" Saved: {path}")
if args.compare:
print(f"\nGenerating {B} baseline images (random noise)...")
        gen_random = torch.Generator(device=args.device).manual_seed(args.seed + 999) if args.seed is not None else None
result_rand = pipe(
prompt=args.prompt,
num_images_per_prompt=B,
num_inference_steps=args.steps,
guidance_scale=args.guidance_scale,
generator=gen_random,
)
for i, img in enumerate(result_rand.images):
path = output_dir / f"{args.model_type}_random_{i:02d}.png"
img.save(path)
print(f" Saved: {path}")
print(f"\nDone! Images saved to {output_dir}/")
if __name__ == "__main__":
main()