# Source: uploaded by SiniShell1 to the Hugging Face Hub
# ("Upload folder using huggingface_hub", commit f574a90, verified).
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import argparse
import torch
from einops import rearrange
# from fire import Fire
from PIL import ExifTags, Image
from sampling import denoise, get_noise, get_schedule, prepare, unpack
from util import (configs, load_ae, load_clip,
load_flow_model, load_t5)
from transformers import pipeline
from PIL import Image
import numpy as np
import os
# Default checkpoint locations. `setdefault` keeps any value already exported
# in the environment instead of overwriting it with these cluster-specific paths.
# NOTE(review): these are assigned after `from util import ...`; if util reads
# the variables when it is imported (upstream FLUX builds `configs` from
# os.getenv at import time), these assignments may have no effect — confirm.
os.environ.setdefault("FLUX_DEV", "/group/40034/hilljswang/flux/ckpt/flux1-dev.safetensors")
os.environ.setdefault("FLUX_SCHNELL", "/group/40034/leizizhang/pretrained/FLUX.1-schnell/flux1-schnell.safetensors")
os.environ.setdefault("AE", "/group/40034/hilljswang/flux/ckpt/ae.safetensors")
# Score threshold for NSFW filtering; not used by any active code in this file.
NSFW_THRESHOLD = 0.85
@dataclass
class SamplingOptions:
source_prompt: str
target_prompt: str
# prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
@torch.inference_mode()
def encode(init_image, torch_device, ae):
    """Encode an image array into autoencoder latents.

    Args:
        init_image: numpy array with values in [0, 255], laid out as
            (H, W, 3) RGB. (H, W) grayscale input is replicated to three
            channels and (H, W, 4) RGBA input has its alpha channel dropped.
        torch_device: device the pixel tensor is moved to before encoding.
        ae: autoencoder exposing an ``encode(tensor)`` method.

    Returns:
        The ``ae.encode`` output for a batch of one, cast to bfloat16.
    """
    if init_image.ndim == 2:
        # Grayscale: replicate to RGB so the HWC -> CHW permute below works.
        init_image = np.stack([init_image] * 3, axis=-1)
    elif init_image.ndim == 3 and init_image.shape[-1] == 4:
        # RGBA: the autoencoder input is 3-channel, so drop alpha.
        init_image = init_image[..., :3]

    # Map [0, 255] -> [-1, 1] and HWC -> CHW, then add a batch dimension.
    pixels = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
    pixels = pixels.unsqueeze(0).to(torch_device)
    return ae.encode(pixels).to(torch.bfloat16)
def _load_rgb_image(path):
    """Load ``path`` as an RGB uint8 array cropped to multiples of 16.

    The crop keeps the top-left region so both spatial sizes are divisible by
    16, the sample-size constraint documented for ``main``. ``convert("RGB")``
    normalises grayscale, palette and RGBA files to the three channels that
    ``encode`` permutes into CHW.

    Raises:
        ValueError: if the image is smaller than 16 pixels on either side.
    """
    with Image.open(path) as img:
        pixels = np.array(img.convert("RGB"))
    rows, cols = pixels.shape[0], pixels.shape[1]
    rows -= rows % 16
    cols -= cols % 16
    if rows == 0 or cols == 0:
        raise ValueError(f"Image {path!r} is smaller than 16x16 pixels.")
    return pixels[:rows, :cols, :]


def _load_source_latents(source, torch_device, ae):
    """Encode one image file, or every file in a directory, into AE latents.

    Directory entries are processed in sorted filename order and stacked along
    the batch dimension, so every image must have the same size after cropping.

    Returns:
        (latents, height, width): the concatenated ``encode`` outputs and the
        pixel rows/columns of the cropped images.

    Raises:
        ValueError: empty directory, or images of differing sizes.
    """
    if os.path.isdir(source):
        paths = [os.path.join(source, fn) for fn in sorted(os.listdir(source))]
        paths = [p for p in paths if os.path.isfile(p)]  # skip sub-directories
        if not paths:
            raise ValueError(f"No files found in source directory {source!r}.")
    else:
        paths = [source]

    latents = []
    height = width = None
    for path in paths:
        pixels = _load_rgb_image(path)
        if height is None:
            height, width = pixels.shape[0], pixels.shape[1]
        elif pixels.shape[:2] != (height, width):
            raise ValueError(
                f"All source images must share one size; {path!r} is "
                f"{pixels.shape[0]}x{pixels.shape[1]}, expected {height}x{width}."
            )
        latents.append(encode(pixels, torch_device, ae))
    return torch.cat(latents, dim=0), height, width


def _next_output_index(output_dir):
    """Return one past the largest N among existing ``img_N.jpg`` files (0 if none)."""
    pattern = re.compile(r"img_([0-9]+)\.jpg$")
    indices = []
    for fn in iglob(os.path.join(output_dir, "img_*.jpg")):
        match = pattern.search(fn)
        if match:
            indices.append(int(match.group(1)))
    return max(indices) + 1 if indices else 0


@torch.inference_mode()
def main(
    args,
    seed: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    offload: bool = False,
    add_sampling_metadata: bool = True,
):
    """Invert source image(s) with FLUX, then re-generate them under a target prompt.

    Most settings come from ``args``: ``name``, ``source_img_dir`` (an image
    file or a directory of images), ``source_prompt``, ``target_prompt``,
    ``guidance``, ``num_steps``, ``offload``, ``output_dir``, ``feature_path``,
    ``inject_type``, ``inject`` and ``partial``.

    The images are cropped to multiples of 16 and encoded. They are then
    inverted with ``denoise(..., inverse=True)`` under the source prompt and
    denoised from that result under the target prompt. Each output is saved
    as ``img_{idx}.jpg`` in ``output_dir``.

    Args:
        args: parsed command-line namespace.
        seed: initial ``SamplingOptions.seed``. It is only printed, because
            the starting noise comes from inversion rather than from a seed.
        device: PyTorch device.
        num_steps: ignored, because ``args.num_steps`` always overrides it.
        loop: interactive mode. NOTE(review): this requires ``parse_prompt``,
            which this file neither defines nor imports, so ``loop=True``
            raises NameError.
        offload: ignored, because ``args.offload`` always overrides it.
        add_sampling_metadata: currently unused (EXIF writing is disabled).

    Raises:
        ValueError: unknown model name, or unusable source images.
    """
    torch.set_grad_enabled(False)
    name = args.name
    source_prompt = args.source_prompt
    target_prompt = args.target_prompt
    guidance = args.guidance
    output_dir = args.output_dir
    # The CLI values take precedence over the keyword arguments of the same name.
    num_steps = args.num_steps
    offload = args.offload

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 25

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    if offload:
        model.cpu()
        torch.cuda.empty_cache()
        ae.encoder.to(torch_device)

    # height/width are the pixel rows/columns of the cropped source image(s).
    init_image, height, width = _load_source_latents(args.source_img_dir, torch_device, ae)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)  # NOTE(review): parse_prompt is not defined/imported here.

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
        t0 = time.perf_counter()
        # The starting noise comes from inverting the source latents, so the
        # seed is not consumed; reset it so a looped run draws a fresh one.
        opts.seed = None

        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)

        # ---- inversion ----
        info = {
            "feature_path": args.feature_path,
            "inject_type": args.inject_type,
            "inject_step": args.inject,
            "partial": args.partial,
        }
        # makedirs also creates missing parents (os.mkdir did not).
        os.makedirs(args.feature_path, exist_ok=True)

        inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
        inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # Run the sampler in inverse mode (guidance 1) from the source latents
        # under the source prompt; the result seeds the target-prompt pass.
        z = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
        inp_target["img"] = z
        timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))
        x = denoise(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        # NOTE(review): assumes sampling.unpack(x, height, width) as in upstream FLUX;
        # the positional values are the same as the original call passed.
        batch_x = unpack(x.float(), opts.height, opts.width)

        os.makedirs(output_dir, exist_ok=True)
        idx = _next_output_index(output_dir)
        for sample in batch_x:
            with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
                decoded = ae.decode(sample.unsqueeze(0))

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t1 = time.perf_counter()

            # Build the name with an f-string so braces in output_dir are not
            # interpreted as format fields.
            fn = os.path.join(output_dir, f"img_{idx}.jpg")
            print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
            # bring into PIL format and save
            decoded = rearrange(decoded.clamp(-1, 1)[0], "c h w -> h w c")
            img = Image.fromarray((127.5 * (decoded + 1.0)).cpu().byte().numpy())
            img.save(fn)
            idx += 1

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None
# def app():
# Fire(main)
def build_parser():
    """Build the command-line parser for FLUX inversion-based image editing.

    ``--guidance`` is parsed as a float, because guidance scales such as 3.5
    are common. The prompts and ``--feature_path`` are required: ``main``
    cannot run without them.
    """
    parser = argparse.ArgumentParser(description='FLUX inference')
    parser.add_argument('--name', default='flux-dev', type=str,
                        help='flux model')
    parser.add_argument('--source_img_dir', default='', type=str,
                        help='source image file, or a directory of images to edit')
    parser.add_argument('--source_prompt', type=str, required=True,
                        help='source prompt')
    parser.add_argument('--target_prompt', type=str, required=True,
                        help='target prompt')
    parser.add_argument('--feature_path', type=str, required=True,
                        help='directory passed to denoise as info["feature_path"] (created if missing)')
    parser.add_argument('--guidance', type=float, default=5.0,
                        help='guidance scale')
    parser.add_argument('--num_steps', type=int, default=25,
                        help='number of sampling steps')
    parser.add_argument('--inject', type=int, default=20,
                        help='injection step, passed to denoise as info["inject_step"]')
    parser.add_argument('--partial', type=int, default=None,
                        help='partial inject')
    parser.add_argument('--output_dir', default='output', type=str,
                        help='output dir')
    parser.add_argument('--inject_type', type=str,
                        help='injection type, passed to denoise as info["inject_type"]')
    parser.add_argument('--offload', action='store_true',
                        help='keep models on CPU and move each to the device only while it is used')
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())