"""
Multi-model image generation with predictor-selected noise.

Supports the models registered in PIPELINE_CONFIG: SDXL, DreamShaper,
Hunyuan-DiT, PixArt-Sigma, SANA-Sprint.

Usage:
    python -m predictor.inference.generate --model_type sdxl --checkpoint experiments/sdxl_best/best_model.pth --prompt "A cat"
    python -m predictor.inference.generate --model_type hunyuan_dit --checkpoint experiments/hunyuan_best/best_model.pth --prompt "A sunset"
"""

import argparse
import sys
from pathlib import Path

import torch

from predictor.inference.loader import load_predictor
from predictor.inference.noise_selection import generate_noise_candidates, select_top_k_noise
from predictor.configs.model_dims import MODEL_DIMS, get_dims


# Pipeline loading registry: model_type -> diffusers pipeline class name,
# Hugging Face repo id, and an optional per-model dtype override.
PIPELINE_CONFIG = {
    'sdxl': {
        'class': 'StableDiffusionXLPipeline',
        'pretrained': 'stabilityai/stable-diffusion-xl-base-1.0',
    },
    'dreamshaper': {
        'class': 'StableDiffusionXLPipeline',
        'pretrained': 'Lykon/dreamshaper-xl-v2-turbo',
    },
    'hunyuan_dit': {
        'class': 'HunyuanDiTPipeline',
        'pretrained': 'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers',
    },
    'pixart_sigma': {
        'class': 'PixArtSigmaPipeline',
        'pretrained': 'PixArt-alpha/PixArt-Sigma-XL-2-1024-MS',
    },
    'sana_sprint': {
        'class': 'SanaSprintPipeline',
        'pretrained': 'Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers',
        'dtype': torch.bfloat16,
    },
}


def load_pipeline(model_type: str, device: str = 'cuda', dtype=torch.float16):
    """Instantiate the diffusers pipeline registered for ``model_type``.

    Args:
        model_type: Key of ``PIPELINE_CONFIG``.
        device: Torch device string the pipeline is moved to.
        dtype: Default weight dtype. A per-model ``'dtype'`` entry in
            ``PIPELINE_CONFIG`` (e.g. bfloat16 for SANA-Sprint) takes precedence.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        KeyError: If ``model_type`` has no entry in ``PIPELINE_CONFIG``.
    """
    # Imported here rather than at module level, so diffusers is only
    # required once a pipeline is actually loaded.
    import diffusers

    try:
        config = PIPELINE_CONFIG[model_type]
    except KeyError:
        raise KeyError(
            f"No pipeline registered for model_type={model_type!r}; "
            f"supported: {sorted(PIPELINE_CONFIG)}"
        ) from None
    pipe_class = getattr(diffusers, config['class'])
    pipe_dtype = config.get('dtype', dtype)
    return pipe_class.from_pretrained(
        config['pretrained'],
        torch_dtype=pipe_dtype,
    ).to(device)


def _ones_mask(embeds: torch.Tensor, device) -> torch.Tensor:
    """Return an all-ones long mask shaped like the first two dims of ``embeds``."""
    return torch.ones(embeds.shape[:2], device=device, dtype=torch.long)


def _encode_sdxl(pipe, prompt: str, device):
    """Encode for SDXL-family pipelines (SDXL, DreamShaper)."""
    # encode_prompt returns (embeds, neg_embeds, pooled, neg_pooled).
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    gen_kwargs = {
        'prompt_embeds': prompt_embeds,
        'negative_prompt_embeds': negative_prompt_embeds,
        'pooled_prompt_embeds': pooled_prompt_embeds,
        'negative_pooled_prompt_embeds': negative_pooled_prompt_embeds,
    }
    # The predictor scores with the positive embeddings only (expected
    # [1, 77, 2048]); no text mask is produced here, so an all-ones mask is used.
    return prompt_embeds, _ones_mask(prompt_embeds, device), gen_kwargs


def _encode_hunyuan_dit(pipe, prompt: str, model_type: str, device):
    """Encode for Hunyuan-DiT, handling both encode_prompt return layouts.

    Depending on the installed diffusers version, ``encode_prompt`` returns either
    CLIP + T5 embeddings (4 values, or 8 with attention masks) or CLIP-only
    (emb, neg_emb, mask, neg_mask). The predictor always scores on T5 embeddings,
    so in the CLIP-only case T5 is encoded manually and the returned
    ``gen_kwargs`` is empty (the pipeline then encodes the prompt itself).
    """
    result = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    # Disambiguate the layout via result[2]: a floating-point tensor means T5
    # embeddings; an integer tensor means a CLIP attention mask.
    has_t5_from_encode = (
        len(result) >= 4
        and torch.is_tensor(result[2])
        and result[2].is_floating_point()
    )

    if has_t5_from_encode:
        prompt_embeds, negative_prompt_embeds = result[0], result[1]
        prompt_embeds_2, negative_prompt_embeds_2 = result[2], result[3]
        if len(result) >= 8:
            (
                prompt_attention_mask,
                negative_prompt_attention_mask,
                prompt_attention_mask_2,
                negative_prompt_attention_mask_2,
            ) = result[4:8]
        else:
            prompt_attention_mask = _ones_mask(prompt_embeds, device)
            negative_prompt_attention_mask = _ones_mask(negative_prompt_embeds, device)
            prompt_attention_mask_2 = _ones_mask(prompt_embeds_2, device)
            negative_prompt_attention_mask_2 = _ones_mask(negative_prompt_embeds_2, device)
        gen_kwargs = {
            'prompt_embeds': prompt_embeds,
            'negative_prompt_embeds': negative_prompt_embeds,
            'prompt_embeds_2': prompt_embeds_2,
            'negative_prompt_embeds_2': negative_prompt_embeds_2,
            'prompt_attention_mask': prompt_attention_mask,
            'negative_prompt_attention_mask': negative_prompt_attention_mask,
            'prompt_attention_mask_2': prompt_attention_mask_2,
            'negative_prompt_attention_mask_2': negative_prompt_attention_mask_2,
        }
        # Predictor uses the T5 branch (expected [1, 256, 2048]).
        return prompt_embeds_2, prompt_attention_mask_2, gen_kwargs

    # CLIP-only layout: encode T5 manually for predictor scoring.
    max_seq_len = get_dims(model_type)['seq_len']
    tokens = pipe.tokenizer_2(
        prompt,
        max_length=max_seq_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    ).to(device)
    with torch.no_grad():
        t5_output = pipe.text_encoder_2(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
        )
    pred_embeds = t5_output[0].to(dtype=torch.float16)
    pred_mask = tokens.attention_mask
    # Empty kwargs: the caller must pass the raw prompt to the pipeline.
    return pred_embeds, pred_mask, {}


def _encode_pixart_sigma(pipe, prompt: str, model_type: str, device):
    """Encode for PixArt-Sigma."""
    max_seq_len = get_dims(model_type)['seq_len']
    # encode_prompt returns (embeds, mask, neg_embeds, neg_mask).
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt=prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        device=device,
        clean_caption=False,
        max_sequence_length=max_seq_len,
    )
    pred_mask = (
        prompt_attention_mask
        if prompt_attention_mask is not None
        else _ones_mask(prompt_embeds, device)
    )
    gen_kwargs = {
        'prompt_embeds': prompt_embeds,
        'prompt_attention_mask': prompt_attention_mask,
        'negative_prompt_embeds': negative_prompt_embeds,
        'negative_prompt_attention_mask': negative_prompt_attention_mask,
    }
    return prompt_embeds, pred_mask, gen_kwargs


def _encode_sana_sprint(pipe, prompt: str, model_type: str, device):
    """Encode for SANA-Sprint (no negative embeddings are produced)."""
    # encode_prompt returns (prompt_embeds, prompt_attention_mask).
    prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        max_sequence_length=get_dims(model_type)['seq_len'],
    )
    gen_kwargs = {
        'prompt_embeds': prompt_embeds,
        'prompt_attention_mask': prompt_attention_mask,
    }
    return prompt_embeds, prompt_attention_mask, gen_kwargs


def encode_prompt_for_model(pipe, prompt: str, model_type: str, device: str = 'cuda'):
    """Encode ``prompt`` both for predictor scoring and for pipeline generation.

    Encoding runs under ``torch.no_grad()`` so no autograd graph is built over
    the text encoders.

    Args:
        pipe: A pipeline returned by ``load_pipeline`` for ``model_type``.
        prompt: Text prompt.
        model_type: One of the keys of ``PIPELINE_CONFIG``.
        device: Device the embeddings and masks are placed on.

    Returns:
        ``(pred_embeds, pred_mask, gen_kwargs)``:

        - ``pred_embeds`` / ``pred_mask``: the text embeddings and mask the
          predictor scores against.
        - ``gen_kwargs``: keyword arguments (embeddings/masks) to forward to the
          pipeline call. It may be empty (CLIP-only Hunyuan-DiT path), in which
          case the raw prompt must be given to the pipeline instead.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    with torch.no_grad():
        if model_type in ('sdxl', 'dreamshaper'):
            return _encode_sdxl(pipe, prompt, device)
        if model_type == 'hunyuan_dit':
            return _encode_hunyuan_dit(pipe, prompt, model_type, device)
        if model_type == 'pixart_sigma':
            return _encode_pixart_sigma(pipe, prompt, model_type, device)
        if model_type == 'sana_sprint':
            return _encode_sana_sprint(pipe, prompt, model_type, device)
    raise ValueError(f"Unknown model_type: {model_type}")


def _expand_batch(gen_kwargs: dict, batch_size: int) -> dict:
    """Broadcast tensors of rank >= 2 in ``gen_kwargs`` to ``batch_size`` along dim 0.

    Tensors must have a leading dim of 1 (``expand`` cannot grow other sizes).
    Non-tensors and 1-D tensors are passed through unchanged.
    """
    return {
        key: (
            value.expand(batch_size, *([-1] * (value.dim() - 1)))
            if isinstance(value, torch.Tensor) and value.dim() >= 2
            else value
        )
        for key, value in gen_kwargs.items()
    }


def _save_images(images, output_dir: Path, prefix: str) -> None:
    """Save ``images`` as ``<prefix>_NN.png`` in ``output_dir``, logging each path."""
    for i, img in enumerate(images):
        path = output_dir / f"{prefix}_{i:02d}.png"
        img.save(path)
        print(f" Saved: {path}")


def _make_generator(device: str, seed):
    """Return a seeded ``torch.Generator``, or None when ``seed`` is None (0 is a valid seed)."""
    if seed is None:
        return None
    return torch.Generator(device=device).manual_seed(seed)


def main():
    """CLI entry point: select the top-B of N noise candidates with the predictor and generate."""
    parser = argparse.ArgumentParser(
        description="Generate images with multi-model predictor-selected noise"
    )
    parser.add_argument("--model_type", type=str, required=True,
                        choices=list(MODEL_DIMS.keys()),
                        help="Diffusion model type")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to predictor checkpoint (.pth)")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--N", type=int, default=100, help="Number of noise candidates")
    parser.add_argument("--B", type=int, default=4, help="Number of images to generate")
    parser.add_argument("--head", type=int, default=0,
                        help="Prediction head (0=hpsv2, 1=image_reward, 2=clip_score)")
    parser.add_argument("--steps", type=int, default=20, help="Inference steps")
    parser.add_argument("--guidance-scale", type=float, default=4.5, help="CFG scale")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--output-dir", type=str, default="output", help="Output directory")
    parser.add_argument("--compare", action="store_true",
                        help="Also generate B random baseline images")
    parser.add_argument("--device", type=str, default="cuda", help="Device")
    args = parser.parse_args()

    # MODEL_DIMS may list models that have no registered pipeline; fail early
    # with a clear message instead of a KeyError deep inside load_pipeline.
    if args.model_type not in PIPELINE_CONFIG:
        parser.error(
            f"--model_type {args.model_type!r} has no registered pipeline; "
            f"choose from {sorted(PIPELINE_CONFIG)}"
        )
    # Top-B selection out of N candidates requires 1 <= B <= N.
    if not 1 <= args.B <= args.N:
        parser.error(f"--B must satisfy 1 <= B <= N (got B={args.B}, N={args.N})")

    dims = get_dims(args.model_type)
    latent_shape = dims['latent_shape']
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print('=' * 60)
    print("PNM Multi-Model Image Generation")
    print('=' * 60)
    print(f" Model type: {args.model_type}")
    print(f" Latent: {latent_shape}")
    print(f" Prompt: {args.prompt}")
    print(f" N: {args.N} noise candidates")
    print(f" B: {args.B} images")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Output: {output_dir}")
    print('=' * 60)

    print(f"\nLoading {args.model_type} pipeline...")
    pipe = load_pipeline(args.model_type, device=args.device)

    print(f"Loading predictor from {args.checkpoint}...")
    predictor, norm_info = load_predictor(args.checkpoint, device=args.device)
    print(f" num_heads={predictor.num_heads}")
    if not 0 <= args.head < predictor.num_heads:
        parser.error(f"--head must be in [0, {predictor.num_heads}) (got {args.head})")

    # UNet-based pipelines (SDXL family) expose `unet`; DiT-based ones expose `transformer`.
    denoiser = getattr(pipe, 'unet', None)
    if denoiser is None:
        denoiser = pipe.transformer

    with torch.no_grad():
        pred_embeds, pred_mask, gen_kwargs = encode_prompt_for_model(
            pipe, args.prompt, args.model_type, args.device
        )
        noises = generate_noise_candidates(
            num_candidates=args.N,
            latent_shape=latent_shape,
            device=args.device,
            dtype=denoiser.dtype,
            generator=_make_generator(args.device, args.seed),
        )
        selected = select_top_k_noise(
            predictor=predictor,
            noises=noises,
            prompt_embeds=pred_embeds,
            prompt_mask=pred_mask,
            num_select=args.B,
            head_index=args.head,
        )

    # Don't pre-scale: the pipeline's prepare_latents applies init_noise_sigma itself.
    latents = selected

    B = args.B
    expanded_kwargs = _expand_batch(gen_kwargs, B)
    # Without precomputed embeddings (CLIP-only Hunyuan-DiT path) the pipeline
    # must encode the prompt itself; one copy per latent keeps the prompt batch
    # aligned with the B selected latents.
    prompt_arg = None if 'prompt_embeds' in expanded_kwargs else [args.prompt] * B

    print(f"\nGenerating {B} images from top-{B} of {args.N} candidates...")
    result = pipe(
        prompt=prompt_arg,
        **expanded_kwargs,
        latents=latents,
        num_images_per_prompt=1,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
    )
    _save_images(result.images, output_dir, f"{args.model_type}_predictor")

    if args.compare:
        print(f"\nGenerating {B} baseline images (random noise)...")
        # Offset the seed so baseline noise differs from the candidate pool.
        gen_random = _make_generator(
            args.device, None if args.seed is None else args.seed + 999
        )
        result_rand = pipe(
            prompt=args.prompt,
            num_images_per_prompt=B,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance_scale,
            generator=gen_random,
        )
        _save_images(result_rand.images, output_dir, f"{args.model_type}_random")

    print(f"\nDone! Images saved to {output_dir}/")


if __name__ == "__main__":
    main()