"""Generate images from the first N ``.txt`` prompt files in a directory.

For every prompt file ``<name>.txt`` (sorted by filename), this script writes
``<name>.png`` (the generated image) and ``<name>.txt`` (a copy of the
stripped prompt) into ``--outdir``. Components are built from the YAML
training config and the EMA denoiser weights are taken from ``--ckpt_path``.
"""
import argparse
import contextlib
import os
import sys
from pathlib import Path
from typing import List

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Make the project's ``src`` package importable when run as a script.
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import yaml
from PIL import Image

from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.qwen3_text_encoder import Qwen3TextEncoder
from src.models.transformer.JiT_T2I import JiT_T2I
from src.diffusion.flow_matching.adam_sampling import AdamLMSamplerJiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.base.guidance import simple_guidance_fn
from src.models.autoencoder.base import fp2uint8

# Checkpoint key prefix under which the EMA copy of the denoiser is stored.
_EMA_PREFIX = "ema_denoiser."


def read_prompts(txt_dir: str, limit: int) -> List[Path]:
    """Return the first ``limit`` ``*.txt`` files in ``txt_dir``, sorted by name.

    Raises:
        ValueError: if ``limit`` is negative (a negative slice would silently
            drop files from the end instead of capping the count).
        FileNotFoundError: if ``txt_dir`` contains no ``.txt`` files.
    """
    if limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")
    txt_paths = sorted(Path(txt_dir).glob("*.txt"))
    if not txt_paths:
        raise FileNotFoundError(f"No .txt files found in {txt_dir}")
    return txt_paths[:limit]


def load_config(config_path: str) -> dict:
    """Parse and return the YAML config at ``config_path`` (UTF-8)."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _load_ema_denoiser(denoiser, ckpt_path: str) -> None:
    """Load the ``ema_denoiser.*`` weights from ``ckpt_path`` into ``denoiser``.

    Raises:
        RuntimeError: if the checkpoint contains no EMA denoiser weights, which
            would otherwise leave the denoiser randomly initialised.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    denoiser_state = {
        key[len(_EMA_PREFIX):]: value
        for key, value in state_dict.items()
        if key.startswith(_EMA_PREFIX)
    }
    if not denoiser_state:
        raise RuntimeError(
            f"No '{_EMA_PREFIX}*' weights found in checkpoint {ckpt_path}; "
            "refusing to sample with an uninitialised denoiser."
        )
    # strict=False so partially matching checkpoints still load; the counts
    # below make any mismatch visible.
    missing, unexpected = denoiser.load_state_dict(denoiser_state, strict=False)
    print(f"Loaded EMA denoiser. missing={len(missing)} unexpected={len(unexpected)}")


def build_components(config: dict, device: torch.device):
    """Build the autoencoder, conditioner, denoiser and sampler from ``config``.

    Constructor arguments are read from ``config["model"][<part>]["init_args"]``
    and the EMA denoiser weights from ``config["ckpt_path"]``.

    Returns:
        ``(vae, conditioner, denoiser, sampler)``; all moved to ``device``, with
        the vae, conditioner and denoiser switched to eval mode.

    Raises:
        KeyError: if a required config entry is missing.
        RuntimeError: if the checkpoint has no EMA denoiser weights.
    """
    model_cfg = config["model"]

    vae = PixelAE()

    conditioner_cfg = model_cfg["conditioner"]["init_args"]
    conditioner = Qwen3TextEncoder(
        weight_path=conditioner_cfg["weight_path"],
        embed_dim=conditioner_cfg["embed_dim"],
        max_length=conditioner_cfg["max_length"],
    )

    denoiser_cfg = model_cfg["denoiser"]["init_args"]
    denoiser = JiT_T2I(
        patch_size=denoiser_cfg["patch_size"],
        input_size=denoiser_cfg["input_size"],
        in_channels=denoiser_cfg["in_channels"],
        hidden_size=denoiser_cfg["hidden_size"],
        num_blocks=denoiser_cfg["num_blocks"],
        num_groups=denoiser_cfg["num_groups"],
        txt_embed_dim=denoiser_cfg["txt_embed_dim"],
        txt_max_length=denoiser_cfg["txt_max_length"],
        bottleneck_dim=denoiser_cfg["bottleneck_dim"],
    )

    sampler_cfg = model_cfg["diffusion_sampler"]["init_args"]
    scheduler = LinearScheduler()
    sampler = AdamLMSamplerJiT(
        num_steps=sampler_cfg["num_steps"],
        guidance=sampler_cfg["guidance"],
        timeshift=sampler_cfg["timeshift"],
        order=sampler_cfg["order"],
        scheduler=scheduler,
        guidance_fn=simple_guidance_fn,
    )

    _load_ema_denoiser(denoiser, config["ckpt_path"])

    denoiser = denoiser.to(device).eval()
    conditioner = conditioner.to(device).eval()
    vae = vae.to(device).eval()
    sampler = sampler.to(device)
    return vae, conditioner, denoiser, sampler


def _sample_batch(vae, conditioner, denoiser, sampler, prompts, xT, device):
    """Encode ``prompts``, run the sampler from noise ``xT`` and decode the result.

    On CUDA everything runs under bf16 autocast; elsewhere the conditioner
    outputs are cast to float32 before sampling.
    """
    on_cuda = device.type == "cuda"
    precision = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if on_cuda
        else contextlib.nullcontext()
    )
    with precision:
        condition, uncondition = conditioner(prompts, {})
        if not on_cuda:
            # Cast to fp32 for CPU sampling; presumably the conditioner may
            # emit a lower-precision dtype — TODO confirm.
            condition = condition.float()
            uncondition = uncondition.float()
        samples = sampler(denoiser, xT, condition, uncondition)
        return vae.decode(samples)


@torch.no_grad()
def main():
    """Parse CLI arguments and generate one image per prompt file."""
    parser = argparse.ArgumentParser(description="Generate images from the first N txt prompts in a directory.")
    parser.add_argument("--txt_dir", type=str, required=True, help="Directory containing .txt prompt files")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path")
    parser.add_argument("--config", type=str, default="/media/home/lx/PixelGen-fuxian/universal_pix_t2i_workdirs/exp_sft_pix512_gan/config-fit-2605131512.yaml")
    parser.add_argument("--outdir", type=str, required=True, help="Output directory")
    parser.add_argument("--limit", type=int, default=500, help="Number of txt files to use")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for generation")
    parser.add_argument("--seed_offset", type=int, default=0, help="Base seed offset")
    parser.add_argument("--device", type=str, default="cuda", help="Torch device, e.g. cuda or cuda:0")
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")
    if args.limit < 0:
        parser.error(f"--limit must be >= 0, got {args.limit}")

    os.makedirs(args.outdir, exist_ok=True)

    config = load_config(args.config)
    config["ckpt_path"] = args.ckpt_path

    if torch.cuda.is_available():
        device = torch.device(args.device)
    else:
        if args.device != "cpu":
            print(f"CUDA is not available; ignoring --device {args.device} and using cpu.", file=sys.stderr)
        device = torch.device("cpu")
    torch.set_float32_matmul_precision("high")

    vae, conditioner, denoiser, sampler = build_components(config, device)

    txt_paths = read_prompts(args.txt_dir, args.limit)
    prompts = [path.read_text(encoding="utf-8").strip() for path in txt_paths]
    names = [path.stem for path in txt_paths]

    # NOTE(review): hard-coded 3x512x512 noise; presumably this should match the
    # denoiser's configured input_size / in_channels — confirm.
    latent_shape = (3, 512, 512)
    for start in range(0, len(prompts), args.batch_size):
        end = min(start + args.batch_size, len(prompts))
        batch_prompts = prompts[start:end]
        batch_names = names[start:end]

        # Seed each sample by its global index so results don't depend on batch size.
        latents = [
            torch.randn(
                latent_shape,
                generator=torch.Generator().manual_seed(args.seed_offset + i),
                dtype=torch.float32,
            )
            for i in range(start, end)
        ]
        xT = torch.stack(latents, dim=0).to(device)

        samples = _sample_batch(vae, conditioner, denoiser, sampler, batch_prompts, xT, device)
        images = fp2uint8(samples).permute(0, 2, 3, 1).cpu().numpy()

        for image, name, prompt in zip(images, batch_names, batch_prompts):
            Image.fromarray(image).save(os.path.join(args.outdir, f"{name}.png"))
            with open(os.path.join(args.outdir, f"{name}.txt"), "w", encoding="utf-8") as f:
                f.write(prompt)

        print(f"Generated {end}/{len(prompts)}")


if __name__ == "__main__":
    main()