"""FLUX image inversion / editing script.

Inverts a source image to its initial noise under the source prompt
(``denoise(..., inverse=True)``), then re-denoises that noise under the
target prompt with feature injection, and saves the edited image(s).
"""

import argparse
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import numpy as np
import torch
from einops import rearrange
from PIL import ExifTags, Image
from transformers import pipeline

from sampling import denoise, get_noise, get_schedule, prepare, unpack
from util import configs, load_ae, load_clip, load_flow_model, load_t5

# Checkpoint locations consumed by the util.load_* helpers via os.environ.
os.environ["FLUX_DEV"] = "/group/40034/hilljswang/flux/ckpt/flux1-dev.safetensors"
os.environ["FLUX_SCHNELL"] = "/group/40034/leizizhang/pretrained/FLUX.1-schnell/flux1-schnell.safetensors"
os.environ["AE"] = "/group/40034/hilljswang/flux/ckpt/ae.safetensors"

# Threshold for the (currently disabled) NSFW classifier check.
NSFW_THRESHOLD = 0.85


@dataclass
class SamplingOptions:
    source_prompt: str      # prompt describing the input image (used for inversion)
    target_prompt: str      # prompt describing the desired edited image
    width: int              # sample width in pixels (multiple of 16)
    height: int             # sample height in pixels (multiple of 16)
    num_steps: int          # number of sampling steps
    guidance: float         # guidance scale for the target-prompt denoise pass
    seed: int | None        # RNG seed; None means draw one per generation


@torch.inference_mode()
def encode(init_image, torch_device, ae):
    """Encode one HWC uint8 image array into bf16 latents.

    Pixels are mapped from [0, 255] to [-1, 1], reshaped to a 1xCxHxW
    tensor on *torch_device*, and passed through the autoencoder.
    """
    image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
    image = image.unsqueeze(0).to(torch_device)
    # NOTE: the original called ae.encode(init_image.to()); an argument-less
    # .to() is a no-op, so it is dropped here.
    return ae.encode(image).to(torch.bfloat16)


@torch.inference_mode()
def main(
    args,
    seed: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    offload: bool = False,
    add_sampling_metadata: bool = True,
):
    """Sample the flux model: invert the source image, re-denoise with the target prompt.

    Args:
        args: parsed CLI namespace (name, prompts, guidance, paths, steps, ...)
        seed: seed for sampling; None draws a fresh seed per generation
        device: pytorch device string
        num_steps: sampling steps (overridden by args.num_steps below)
        loop: start an interactive session and sample multiple times
              (NOTE(review): the loop path calls parse_prompt, which is not
              defined in this file — confirm before enabling)
        offload: move idle models to CPU to save GPU memory
        add_sampling_metadata: add the prompt to the image Exif metadata
            (currently unused; the Exif/NSFW path is disabled)
    """
    torch.set_grad_enabled(False)
    name = args.name
    source_prompt = args.source_prompt
    target_prompt = args.target_prompt
    guidance = args.guidance
    output_dir = args.output_dir
    num_steps = args.num_steps
    offload = args.offload

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 25

    # Init all components; flow model and AE start on CPU when offloading.
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.encoder.to(torch_device)

    init_image = None
    if os.path.isdir(args.source_img_dir):
        # Directory mode: encode every image and stack along the batch dim.
        # Assumes all images share the first image's dimensions — TODO confirm.
        for file_name in sorted(os.listdir(args.source_img_dir)):
            path = os.path.join(args.source_img_dir, file_name)
            if init_image is None:
                init_image = np.array(Image.open(path))
                # NOTE(review): shape[0] is pixel *height* and shape[1] is
                # *width*, so these names are swapped — but unpack() below is
                # called as unpack(x, opts.width, opts.height), so the two
                # swaps cancel. Keep consistent; confirm before renaming.
                width, height = init_image.shape[0], init_image.shape[1]
                init_image = encode(init_image, torch_device, ae)
            else:
                init_image = torch.cat(
                    (init_image, encode(np.array(Image.open(path)), torch_device, ae)),
                    dim=0,
                )
    else:
        init_image = np.array(Image.open(args.source_img_dir))
        shape = init_image.shape
        # Crop down to multiples of 16 so the latent grid divides evenly.
        new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
        new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
        init_image = init_image[:new_h, :new_w, :]
        width, height = init_image.shape[0], init_image.shape[1]
        init_image = encode(init_image, torch_device, ae)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        # NOTE(review): parse_prompt is not defined anywhere in this file;
        # this path raises NameError if loop=True. Dead by default.
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
        t0 = time.perf_counter()
        # Inversion derives the noise from the image, so the seed is unused.
        opts.seed = None

        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)

        # ---------------- inversion ----------------
        # Feature-injection settings passed through to denoise().
        info = {
            "feature_path": args.feature_path,
            "inject_type": args.inject_type,
            "inject_step": args.inject,
            "partial": args.partial,
        }
        # makedirs (was mkdir) so nested feature paths work; exist_ok avoids
        # the check-then-create race of the original.
        os.makedirs(args.feature_path, exist_ok=True)

        inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
        inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # Offload text encoders to CPU, load flow model onto the GPU.
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # Invert latents back to initial noise under the source prompt.
        z = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)

        inp_target["img"] = z
        timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))

        # Denoise the inverted noise under the target prompt.
        x = denoise(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)

        # Offload flow model, load autoencoder decoder onto the GPU.
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # Decode latents to pixel space, one image at a time.
        batch_x = unpack(x.float(), opts.width, opts.height)
        for x in batch_x:
            x = x.unsqueeze(0)
            output_name = os.path.join(output_dir, "img_{idx}.jpg")
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
                idx = 0
            else:
                # Re-scan the directory each iteration to pick the next free index.
                fns = [
                    fn
                    for fn in iglob(output_name.format(idx="*"))
                    if re.search(r"img_[0-9]+\.jpg$", fn)
                ]
                if len(fns) > 0:
                    idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
                else:
                    idx = 0

            with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
                x = ae.decode(x)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            fn = output_name.format(idx=idx)
            print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
            # Bring into PIL format ([-1, 1] -> [0, 255] HWC uint8) and save.
            x = x.clamp(-1, 1)
            x = rearrange(x[0], "c h w -> h w c")
            img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
            # NSFW filtering and Exif metadata (ExifTags / pipeline / NSFW_THRESHOLD /
            # add_sampling_metadata) are intentionally disabled; the image is saved as-is.
            img.save(fn)

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FLUX inference")
    parser.add_argument("--name", default="flux-dev", type=str, help="flux model")
    parser.add_argument("--source_img_dir", default="", type=str,
                        help="path to a source image file or a directory of images")
    parser.add_argument("--source_prompt", type=str, help="source prompt")
    parser.add_argument("--target_prompt", type=str, help="target prompt")
    parser.add_argument("--feature_path", type=str, help="feature_path")
    # float (was int) so fractional guidance scales like 3.5 are accepted;
    # integer arguments still parse unchanged.
    parser.add_argument("--guidance", type=float, default=5, help="guidance scale")
    parser.add_argument("--num_steps", type=int, default=25, help="num_steps")
    parser.add_argument("--inject", type=int, default=20, help="inject")
    parser.add_argument("--partial", type=int, default=None, help="partial inject")
    parser.add_argument("--output_dir", default="output", type=str, help="output dir")
    parser.add_argument("--inject_type", type=str, help="feature injection type")
    parser.add_argument("--offload", action="store_true",
                        help="offload idle models to CPU to save GPU memory")
    args = parser.parse_args()
    main(args)