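# Source: pixel_gen/scripts/generate_from_txt_dir.py
# Uploaded by linxin02 in commit cef8b68 ("Upload lx_gan project", verified); 6.02 kB.
# (The original page header also contained the hosting-site navigation links
# "Raw", "History", "Blame", "Contribute" and "Delete"; they are not part of the script.)
"""Generate images from the first N ``.txt`` prompt files in a directory.

Builds the PixelAE / Qwen3TextEncoder / JiT_T2I / AdamLMSamplerJiT stack from a
training config, loads the EMA denoiser weights from a checkpoint, and writes
``<stem>.png`` plus a copy of the prompt as ``<stem>.txt`` for every prompt file
into the output directory.
"""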
import argparse
import os
import sys
from pathlib import Path
from typing import List
# Repository root: two levels above this file, i.e. the directory that contains
# the `src` package imported below. It is prepended to sys.path (once) so those
# imports resolve, and take priority, when this file is run directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_root = str(PROJECT_ROOT)
if _root not in sys.path:
    sys.path.insert(0, _root)
del _root  # temporary only; keep the module namespace unchanged
import torch
import yaml
from PIL import Image
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.qwen3_text_encoder import Qwen3TextEncoder
from src.models.transformer.JiT_T2I import JiT_T2I
from src.diffusion.flow_matching.adam_sampling import AdamLMSamplerJiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.base.guidance import simple_guidance_fn
from src.models.autoencoder.base import fp2uint8
def read_prompts(txt_dir: str, limit: int) -> List[Path]:
txt_paths = sorted(Path(txt_dir).glob("*.txt"))
if not txt_paths:
raise FileNotFoundError(f"No .txt files found in {txt_dir}")
return txt_paths[:limit]
def load_config(config_path: str) -> dict:
    """Load a YAML config file and return its top-level mapping.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed config as a dict (callers index it like ``config["model"]``).

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file (or a scalar/list for other
    # documents); fail here with a clear message rather than on a later lookup.
    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {config_path}, "
            f"got {type(config).__name__}"
        )
    return config
def _load_ema_denoiser_weights(denoiser: torch.nn.Module, ckpt_path: str) -> None:
    """Copy the ``ema_denoiser.*`` entries of a checkpoint into ``denoiser`` in place.

    Args:
        denoiser: Module receiving the weights (loaded with ``strict=False``).
        ckpt_path: Checkpoint file whose ``"state_dict"`` holds the EMA weights
            under the ``ema_denoiser.`` prefix.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If no ``ema_denoiser.`` weights are present.
    """
    prefix = "ema_denoiser."
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. From PyTorch 2.6 it defaults to weights_only=True; a
    # Lightning checkpoint carrying non-tensor objects may then need
    # weights_only=False. Confirm against the installed version.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ema_state = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }
    if not ema_state:
        # With strict=False an empty dict "loads" without error. The denoiser
        # would then keep its constructor-initialised weights and the script
        # would quietly generate garbage.
        raise RuntimeError(f"No '{prefix}*' weights found in checkpoint {ckpt_path}")
    missing, unexpected = denoiser.load_state_dict(ema_state, strict=False)
    print(f"Loaded EMA denoiser. missing={len(missing)} unexpected={len(unexpected)}")


def build_components(config: dict, device: torch.device):
    """Instantiate the inference stack described by ``config`` on ``device``.

    Constructor arguments come from ``config["model"][<part>]["init_args"]``
    for the conditioner, denoiser and diffusion sampler. The EMA denoiser
    weights come from ``config["ckpt_path"]``. All modules are moved to
    ``device``, and the VAE, conditioner and denoiser are put in eval mode.

    Args:
        config: Parsed training config with an added ``"ckpt_path"`` key.
        device: Target device.

    Returns:
        ``(vae, conditioner, denoiser, sampler)``.

    Raises:
        KeyError: If a required config or checkpoint entry is missing.
        RuntimeError: If the checkpoint contains no EMA denoiser weights.
    """
    model_cfg = config["model"]

    vae = PixelAE()

    conditioner_cfg = model_cfg["conditioner"]["init_args"]
    conditioner = Qwen3TextEncoder(
        weight_path=conditioner_cfg["weight_path"],
        embed_dim=conditioner_cfg["embed_dim"],
        max_length=conditioner_cfg["max_length"],
    )

    denoiser_cfg = model_cfg["denoiser"]["init_args"]
    denoiser = JiT_T2I(
        patch_size=denoiser_cfg["patch_size"],
        input_size=denoiser_cfg["input_size"],
        in_channels=denoiser_cfg["in_channels"],
        hidden_size=denoiser_cfg["hidden_size"],
        num_blocks=denoiser_cfg["num_blocks"],
        num_groups=denoiser_cfg["num_groups"],
        txt_embed_dim=denoiser_cfg["txt_embed_dim"],
        txt_max_length=denoiser_cfg["txt_max_length"],
        bottleneck_dim=denoiser_cfg["bottleneck_dim"],
    )

    sampler_cfg = model_cfg["diffusion_sampler"]["init_args"]
    sampler = AdamLMSamplerJiT(
        num_steps=sampler_cfg["num_steps"],
        guidance=sampler_cfg["guidance"],
        timeshift=sampler_cfg["timeshift"],
        order=sampler_cfg["order"],
        scheduler=LinearScheduler(),
        guidance_fn=simple_guidance_fn,
    )

    # Only the EMA copy of the denoiser is used for generation.
    _load_ema_denoiser_weights(denoiser, config["ckpt_path"])

    denoiser = denoiser.to(device).eval()
    conditioner = conditioner.to(device).eval()
    vae = vae.to(device).eval()
    sampler = sampler.to(device)
    return vae, conditioner, denoiser, sampler
def _initial_noise(indices, shape, seed_offset: int) -> torch.Tensor:
    """Return a CPU float32 batch of Gaussian noise, one sample per index in ``indices``.

    Sample ``i`` comes from its own generator seeded with ``seed_offset + i``.
    A prompt's starting noise therefore depends only on its index, not on the
    batch size or where the batch boundaries fall.
    """
    return torch.stack(
        [
            torch.randn(
                shape,
                generator=torch.Generator().manual_seed(seed_offset + i),
                dtype=torch.float32,
            )
            for i in indices
        ],
        dim=0,
    )


def _sample_and_decode(vae, conditioner, denoiser, sampler, batch_prompts, xT, device):
    """Encode ``batch_prompts``, run ``sampler`` from noise ``xT`` and decode with ``vae``.

    On CUDA the whole pipeline runs under bfloat16 autocast. On any other
    device it runs without autocast and the text embeddings are cast to float32.
    """
    import contextlib

    use_cuda = device.type == "cuda"
    precision = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if use_cuda
        else contextlib.nullcontext()
    )
    with precision:
        condition, uncondition = conditioner(batch_prompts, {})
        if not use_cuda:
            # NOTE(review): presumably the encoder may emit a reduced-precision
            # dtype that would clash with the float32 noise without autocast. Confirm.
            condition = condition.float()
            uncondition = uncondition.float()
        samples = sampler(denoiser, xT, condition, uncondition)
        return vae.decode(samples)


@torch.no_grad()
def main():
    """CLI entry point: generate one image per prompt file.

    For each of the first ``--limit`` ``*.txt`` files in ``--txt_dir`` (sorted
    path order), this writes ``<stem>.png`` and a UTF-8 copy of the prompt as
    ``<stem>.txt`` into ``--outdir``. The initial noise for the i-th selected
    prompt is seeded with ``--seed_offset + i``. When CUDA is unavailable, it
    runs on CPU regardless of ``--device``.
    """
    parser = argparse.ArgumentParser(description="Generate images from the first N txt prompts in a directory.")
    parser.add_argument("--txt_dir", type=str, required=True, help="Directory containing .txt prompt files")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path")
    parser.add_argument("--config", type=str, default="/media/home/lx/PixelGen-fuxian/universal_pix_t2i_workdirs/exp_sft_pix512_gan/config-fit-2605131512.yaml")
    parser.add_argument("--outdir", type=str, required=True, help="Output directory")
    parser.add_argument("--limit", type=int, default=500, help="Number of txt files to use")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for generation")
    parser.add_argument("--seed_offset", type=int, default=0, help="Base seed offset")
    parser.add_argument("--device", type=str, default="cuda", help="Torch device, e.g. cuda or cuda:0")
    args = parser.parse_args()
    # Validate before any expensive work. A zero step makes range() raise, a
    # negative step would silently generate nothing, and a negative limit
    # would slice files off the end of the list.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if args.limit < 0:
        parser.error("--limit must be non-negative")

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Read prompts before loading models so a bad --txt_dir fails fast.
    # Explicit UTF-8: prompts may be non-ASCII, and the platform default varies.
    txt_paths = read_prompts(args.txt_dir, args.limit)
    prompts = [path.read_text(encoding="utf-8").strip() for path in txt_paths]
    names = [path.stem for path in txt_paths]

    config = load_config(args.config)
    config["ckpt_path"] = args.ckpt_path
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.set_float32_matmul_precision("high")
    vae, conditioner, denoiser, sampler = build_components(config, device)

    # NOTE(review): presumably (channels, height, width) in pixel space, given
    # PixelAE. The 512x512 size is hard-coded and should match the trained model.
    latent_shape = (3, 512, 512)
    total = len(prompts)
    for start in range(0, total, args.batch_size):
        end = min(start + args.batch_size, total)
        batch_prompts = prompts[start:end]
        xT = _initial_noise(range(start, end), latent_shape, args.seed_offset).to(device)
        samples = _sample_and_decode(vae, conditioner, denoiser, sampler, batch_prompts, xT, device)
        images = fp2uint8(samples).permute(0, 2, 3, 1).cpu().numpy()
        for image, name, prompt in zip(images, names[start:end], batch_prompts):
            Image.fromarray(image).save(outdir / f"{name}.png")
            # Store the prompt next to its image; UTF-8 so non-ASCII text round-trips.
            (outdir / f"{name}.txt").write_text(prompt, encoding="utf-8")
        print(f"Generated {end}/{total}")
# Script entry point: run generation only when executed directly, never on import.
if __name__ == "__main__":
    main()