# Source: sd-flow-alpha / imgsample-hacked.py, uploaded by ppbrown with
# huggingface_hub (commit b063d0f, verified). These lines were Hugging Face
# page text; they are kept as comments so the file parses as Python.
#!/bin/env python
# A hacked front end for generating images with "opendiffusionai/sd-flow-alpha"
# using a released diffusers package, without needing the git (development)
# version of diffusers.
def argproc(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default, as used by the
            script) means ``sys.argv[1:]``; passing an explicit list lets
            callers and tests parse options without the real command line.

    Returns:
        An ``argparse.Namespace`` with ``model`` (str), ``seed`` (int),
        ``steps`` (int) and ``prompt``. ``prompt`` is the plain default
        string when ``--prompt`` is omitted, but a list of strings when it
        is given (``nargs="+"``), so consumers should accept both.
    """
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--model", type=str, default="opendiffusionai/sd-flow-alpha",
                   help="diffusers directory tree or Hugging Face repo id")
    p.add_argument("--seed", type=int, default=10,
                   help="random seed")
    p.add_argument("--steps", type=int, default=30,
                   help="number of inference steps")
    p.add_argument("--prompt", nargs="+", type=str,
                   default="a blonde woman sitting in a cafe",
                   help="one or more prompt strings")
    return p.parse_args(argv)
# Options are parsed before the heavy diffusers/torch imports below, so
# --help and bad-option errors exit without loading those libraries.
args = argproc()

from diffusers import DiffusionPipeline
import torch.nn as nn
import torch
import types
import os
import sys
from PIL import Image, PngImagePlugin

MODEL = args.model

# Announce the scheduler monkey-patch that follows.
print("HAND HACKING FLOWMATCH MODULE")
from diffusers import FlowMatchEulerDiscreteScheduler
def scale_model_input(self, sample, timestep=None):
    """Identity ``scale_model_input`` for ``FlowMatchEulerDiscreteScheduler``.

    This script assigns this function onto the flow-match scheduler class
    (presumably because the diffusers release it targets does not provide
    the method, which the pipeline expects schedulers to have). No input
    scaling is applied.

    Args:
        self: The scheduler instance (unused).
        sample: The model input to scale.
        timestep: Current timestep (unused). Optional to match the
            diffusers scheduler convention ``scale_model_input(sample,
            timestep=None)``, so callers passing only ``sample`` still work.

    Returns:
        ``sample`` itself, unchanged.
    """
    return sample
# Patch the class itself, so every scheduler instance (including the one the
# pipeline loads below) uses the identity scale_model_input.
FlowMatchEulerDiscreteScheduler.scale_model_input = scale_model_input

print(f"Loading from {MODEL}")
# from_pretrained() is given a directory tree or Hub repo id; single-file
# checkpoints are rejected up front with an explicit message.
if MODEL.endswith((".safetensors", ".st")):
    raise ValueError("Cannot accept single-file models. "
                     "Need diffusers directory tree or hf reference")

pipe = DiffusionPipeline.from_pretrained(
    MODEL,
    use_safetensors=True,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.bfloat16,
)
# Clear the attribute again after loading; presumably a safeguard in case the
# resolved pipeline class did not honour the safety_checker=None keyword.
pipe.safety_checker = None
print("model initialized. ")
# Sequential CPU offload manages device placement itself, so the pipeline is
# intentionally not moved with pipe.to("cuda") beforehand (the diffusers
# memory guide advises against doing so when using this offload mode).
pipe.enable_sequential_cpu_offload()

prompt, seed = args.prompt, args.seed
# Explicitly seeded CUDA generator for reproducible sampling.
generator = torch.Generator(device="cuda").manual_seed(seed)

print(f"Trying render of '{prompt}' using seed {seed}...")
images = pipe(
    prompt,
    num_inference_steps=args.steps,
    generator=generator,
).images
# Save next to a local model directory when MODEL is one, else in the cwd.
OUTDIR = MODEL if os.path.isdir(MODEL) else "./"
# The prompt is a plain str for the argparse default but a list when --prompt
# is given (nargs="+"). Normalise it so each PNG records the prompt that
# produced it, rather than the repr of the whole list.
prompts = [prompt] if isinstance(prompt, str) else list(prompt)
for i, image in enumerate(images):
    # Assumes images come back one per prompt, in prompt order
    # (num_images_per_prompt is not set here); if the counts differ, fall
    # back to recording the prompt argument as given.
    image_prompt = prompts[i] if len(images) == len(prompts) else prompt
    meta = PngImagePlugin.PngInfo()
    meta.add_text("Comment", f"prompt={image_prompt}")
    fname = os.path.join(OUTDIR, f"sample{i}_s{seed}.png")
    print(f"saving to {fname}")
    image.save(fname, pnginfo=meta)