| import argparse |
| import os |
| import sys |
| from pathlib import Path |
| from typing import List |
|
|
# Put the parent of this script's directory at the front of sys.path so the
# project-local `src.*` imports further down resolve when the script is run
# directly (e.g. `python <subdir>/<script>.py`) rather than as a module.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| import torch |
| import yaml |
| from PIL import Image |
|
|
| from src.models.autoencoder.pixel import PixelAE |
| from src.models.conditioner.qwen3_text_encoder import Qwen3TextEncoder |
| from src.models.transformer.JiT_T2I import JiT_T2I |
| from src.diffusion.flow_matching.adam_sampling import AdamLMSamplerJiT |
| from src.diffusion.flow_matching.scheduling import LinearScheduler |
| from src.diffusion.base.guidance import simple_guidance_fn |
| from src.models.autoencoder.base import fp2uint8 |
|
|
|
|
def read_prompts(txt_dir: str, limit: int) -> List[Path]:
    """Collect the prompt files to generate from.

    Args:
        txt_dir: Directory scanned (non-recursively) for ``*.txt`` entries.
        limit: Upper bound on how many paths to return; applied with ordinary
            slice semantics after sorting.

    Returns:
        The first ``limit`` matching paths in sorted (lexicographic) order.

    Raises:
        FileNotFoundError: If no ``*.txt`` entry exists in ``txt_dir``.
    """
    candidates = list(Path(txt_dir).glob("*.txt"))
    if not candidates:
        raise FileNotFoundError(f"No .txt files found in {txt_dir}")
    candidates.sort()
    return candidates[:limit]
|
|
|
|
def load_config(config_path: str) -> dict:
    """Parse a YAML config file into a dict.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: If the document's top level is not a mapping (e.g. the file
            is empty and ``yaml.safe_load`` returns ``None``). Callers index the
            result by key, so failing here gives a clearer error than a later
            ``TypeError``.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {config_path}, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
def build_components(config: dict, device: torch.device):
    """Instantiate the generation pipeline and load EMA denoiser weights.

    Reads these config entries:
      - ``model.conditioner.init_args``: ``weight_path``, ``embed_dim``, ``max_length``
      - ``model.denoiser.init_args``: ``patch_size``, ``input_size``, ``in_channels``,
        ``hidden_size``, ``num_blocks``, ``num_groups``, ``txt_embed_dim``,
        ``txt_max_length``, ``bottleneck_dim``
      - ``model.diffusion_sampler.init_args``: ``num_steps``, ``guidance``,
        ``timeshift``, ``order``
      - ``ckpt_path``: checkpoint whose ``state_dict`` holds ``ema_denoiser.*`` keys

    Args:
        config: Parsed config dict (see keys above).
        device: Device that all returned modules are moved to.

    Returns:
        Tuple ``(vae, conditioner, denoiser, sampler)``. The three models are
        in eval mode.

    Raises:
        KeyError: If a required config or checkpoint key is missing.
        RuntimeError: If the checkpoint contains no ``ema_denoiser.*`` weights.
            Without this check, ``strict=False`` loading would silently leave
            the denoiser randomly initialised.
    """
    model_cfg = config["model"]

    vae = PixelAE()

    conditioner_cfg = model_cfg["conditioner"]["init_args"]
    conditioner = Qwen3TextEncoder(
        weight_path=conditioner_cfg["weight_path"],
        embed_dim=conditioner_cfg["embed_dim"],
        max_length=conditioner_cfg["max_length"],
    )

    denoiser_cfg = model_cfg["denoiser"]["init_args"]
    denoiser = JiT_T2I(
        patch_size=denoiser_cfg["patch_size"],
        input_size=denoiser_cfg["input_size"],
        in_channels=denoiser_cfg["in_channels"],
        hidden_size=denoiser_cfg["hidden_size"],
        num_blocks=denoiser_cfg["num_blocks"],
        num_groups=denoiser_cfg["num_groups"],
        txt_embed_dim=denoiser_cfg["txt_embed_dim"],
        txt_max_length=denoiser_cfg["txt_max_length"],
        bottleneck_dim=denoiser_cfg["bottleneck_dim"],
    )

    sampler_cfg = model_cfg["diffusion_sampler"]["init_args"]
    scheduler = LinearScheduler()
    sampler = AdamLMSamplerJiT(
        num_steps=sampler_cfg["num_steps"],
        guidance=sampler_cfg["guidance"],
        timeshift=sampler_cfg["timeshift"],
        order=sampler_cfg["order"],
        scheduler=scheduler,
        guidance_fn=simple_guidance_fn,
    )

    # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True;
    # if this checkpoint carries non-tensor objects and fails to load, consider
    # weights_only=False — but only for checkpoints from a trusted source.
    ckpt_path = config["ckpt_path"]
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Keep only the EMA copy of the denoiser, stripping its key prefix.
    prefix = "ema_denoiser."
    denoiser_state = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }
    # Drop references to the rest of the checkpoint (e.g. other model copies or
    # training state) before the models are moved to the device.
    del checkpoint

    if not denoiser_state:
        raise RuntimeError(
            f"No '{prefix}*' weights found in checkpoint {ckpt_path}; "
            "refusing to sample with a randomly initialised denoiser."
        )
    missing, unexpected = denoiser.load_state_dict(denoiser_state, strict=False)
    print(f"Loaded EMA denoiser. missing={len(missing)} unexpected={len(unexpected)}")
    del denoiser_state

    denoiser = denoiser.to(device).eval()
    conditioner = conditioner.to(device).eval()
    vae = vae.to(device).eval()
    sampler = sampler.to(device)

    return vae, conditioner, denoiser, sampler
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and reject nonsensical numeric values."""
    parser = argparse.ArgumentParser(description="Generate images from the first N txt prompts in a directory.")
    parser.add_argument("--txt_dir", type=str, required=True, help="Directory containing .txt prompt files")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path")
    parser.add_argument("--config", type=str, default="/media/home/lx/PixelGen-fuxian/universal_pix_t2i_workdirs/exp_sft_pix512_gan/config-fit-2605131512.yaml")
    parser.add_argument("--outdir", type=str, required=True, help="Output directory")
    parser.add_argument("--limit", type=int, default=500, help="Number of txt files to use")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for generation")
    parser.add_argument("--seed_offset", type=int, default=0, help="Base seed offset")
    parser.add_argument("--device", type=str, default="cuda", help="Torch device, e.g. cuda or cuda:0")
    args = parser.parse_args()
    # A zero step makes range() raise an opaque ValueError, and a negative
    # limit would silently drop files from the end of the sorted prompt list.
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.limit < 0:
        parser.error("--limit must be >= 0")
    return args


def _resolve_device(requested: str) -> torch.device:
    """Return ``requested`` as a device, or CPU when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.device(requested)
    if requested != "cpu":
        print(f"CUDA is not available; falling back to CPU instead of '{requested}'.")
    return torch.device("cpu")


def _initial_noise(indices: range, seed_offset: int, shape) -> torch.Tensor:
    """Stack one float32 Gaussian draw per global prompt index into a batch.

    Each sample has its own generator seeded with ``seed_offset + index``, so an
    image is reproducible regardless of ``--batch_size``.
    """
    return torch.stack(
        [
            torch.randn(shape, generator=torch.Generator().manual_seed(seed_offset + i), dtype=torch.float32)
            for i in indices
        ],
        dim=0,
    )


def _generate_batch(vae, conditioner, denoiser, sampler, prompts, xT, device):
    """Encode ``prompts``, run the sampler from ``xT`` and decode.

    Returns a channels-last array from ``fp2uint8`` (presumably uint8 HWC
    images — TODO confirm against fp2uint8).
    """
    if device.type == "cuda":
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            condition, uncondition = conditioner(prompts, {})
            samples = sampler(denoiser, xT, condition, uncondition)
            samples = vae.decode(samples)
    else:
        condition, uncondition = conditioner(prompts, {})
        # Up-cast the text embeddings so the CPU path runs entirely in float32.
        condition = condition.float()
        uncondition = uncondition.float()
        samples = sampler(denoiser, xT, condition, uncondition)
        samples = vae.decode(samples)
    # NCHW -> NHWC for PIL.
    return fp2uint8(samples).permute(0, 2, 3, 1).cpu().numpy()


def _save_batch(outdir: str, samples, names, prompts) -> None:
    """Write ``<name>.png`` and a ``<name>.txt`` copy of its prompt into ``outdir``."""
    for sample, name, prompt in zip(samples, names, prompts):
        Image.fromarray(sample).save(os.path.join(outdir, f"{name}.png"))
        with open(os.path.join(outdir, f"{name}.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)


@torch.no_grad()
def main():
    """Generate one image per prompt file and save image + prompt pairs.

    Reads the first ``--limit`` ``*.txt`` files (sorted) from ``--txt_dir``,
    samples in batches of ``--batch_size``, and writes ``<stem>.png`` and
    ``<stem>.txt`` to ``--outdir``. Gradient tracking is disabled for the whole
    run, including the helper calls.
    """
    args = _parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    config = load_config(args.config)
    config["ckpt_path"] = args.ckpt_path
    device = _resolve_device(args.device)
    torch.set_float32_matmul_precision("high")

    vae, conditioner, denoiser, sampler = build_components(config, device)

    txt_paths = read_prompts(args.txt_dir, args.limit)
    # Read as UTF-8 explicitly so non-ASCII prompts do not depend on the locale.
    prompts = [path.read_text(encoding="utf-8").strip() for path in txt_paths]
    names = [path.stem for path in txt_paths]

    # (channels, height, width) of the initial noise. Hard-coded; presumably it
    # must match the denoiser's in_channels/input_size in the config — TODO confirm.
    latent_shape = (3, 512, 512)

    for start in range(0, len(prompts), args.batch_size):
        end = min(start + args.batch_size, len(prompts))
        batch_prompts = prompts[start:end]
        batch_names = names[start:end]

        xT = _initial_noise(range(start, end), args.seed_offset, latent_shape).to(device)
        samples = _generate_batch(vae, conditioner, denoiser, sampler, batch_prompts, xT, device)
        _save_batch(args.outdir, samples, batch_names, batch_prompts)

        print(f"Generated {end}/{len(prompts)}")
|
|
|
|
# Script entry point: run generation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|