File size: 3,212 Bytes
924a45b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import dataclasses
from typing import Literal
from accelerate import Accelerator
from transformers import HfArgumentParser
from PIL import Image
from dst.flux.pipeline import DSTPipeline
from tqdm import tqdm
@dataclasses.dataclass
class InferenceArgs:
    """Command-line options for the style-transfer inference script.

    Parsed by ``HfArgumentParser`` in the ``__main__`` guard: each field is
    exposed as a command-line flag, and the ``help`` metadata below is what
    ``--help`` prints for it.
    """

    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = dataclasses.field(
        default="flux-dev",
        metadata={"help": "Model variant passed to DSTPipeline."},
    )
    width: int = dataclasses.field(
        default=1024,
        metadata={"help": "Output width in pixels; content and style images are also resized to this width."},
    )
    height: int = dataclasses.field(
        default=1024,
        metadata={"help": "Output height in pixels; content and style images are also resized to this height."},
    )
    ref_size: int = dataclasses.field(
        default=1024,
        metadata={"help": "Reference image size. Currently unused by main(); references are resized to width x height."},
    )
    num_steps: int = dataclasses.field(
        default=25,
        metadata={"help": "Number of sampling steps passed to the pipeline."},
    )
    # Float literal so the default matches the annotation (argparse does not
    # coerce non-string defaults, so `4` would stay an int).
    guidance: float = dataclasses.field(
        default=4.0,
        metadata={"help": "Guidance value passed to the pipeline."},
    )
    seed: int = dataclasses.field(
        default=0,
        metadata={"help": "Random seed passed to the pipeline (the same seed is used for every image pair)."},
    )
    only_lora: bool = dataclasses.field(
        default=True,
        metadata={"help": "Forwarded to DSTPipeline as only_lora."},
    )
    concat_refs: bool = dataclasses.field(
        default=True,
        metadata={"help": "Save content | style | result side by side in a single image."},
    )
    lora_rank: int = dataclasses.field(
        default=512,
        metadata={"help": "LoRA rank forwarded to DSTPipeline."},
    )
    pe: Literal['d', 'h', 'w', 'o'] = dataclasses.field(
        default='d',
        metadata={"help": "Forwarded to the pipeline's pe argument (presumably a positional-encoding mode; see DSTPipeline)."},
    )
def crop_if_not_square(img):
    """Center-crop *img* to a square whose side is its shorter dimension.

    *img* needs a ``size`` attribute of ``(width, height)`` and a
    ``crop((left, top, right, bottom))`` method, as PIL images have. A square
    input is returned as-is (the same object); otherwise the result of
    ``crop`` is returned. When the leftover margin is odd, the extra pixel is
    trimmed from the right/bottom edge.
    """
    width, height = img.size
    if width == height:
        return img
    side = min(width, height)
    # Floor division centers the window; an odd remainder goes to the far edge.
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    return img.crop((x0, y0, x0 + side, y0 + side))
def _list_files(folder):
    """Return the names of regular files in *folder*, sorted for a stable order.

    Subdirectories are skipped so they are never handed to ``Image.open``.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def _load_square_rgb(path, width, height):
    """Load *path* as RGB, center-crop it to a square and resize to (width, height).

    The file is opened in a ``with`` block so its handle is released right
    away; ``convert`` returns an in-memory copy that outlives the handle.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return crop_if_not_square(rgb).resize((width, height))


def _side_by_side(tiles, width, height):
    """Paste *tiles* left to right into a new RGB canvas, one width-wide slot each."""
    canvas = Image.new("RGB", (width * len(tiles), height))
    for index, tile in enumerate(tiles):
        canvas.paste(tile, (width * index, 0))
    return canvas


def main(
    args: InferenceArgs,
    *,
    cnt_folder: str = "./test/cnt/",
    sty_folder: str = "./test/sty/",
    save_folder: str = "./output/",
    skip_existing: bool = False,
):
    """Run style transfer for every (content, style) image pair.

    Each style image in *sty_folder* is paired with each content image in
    *cnt_folder*. Both images are center-cropped to a square, resized to
    ``args.width x args.height`` and passed to ``DSTPipeline`` as
    ``ref_imgs=[style, content]`` with an empty prompt. The result is written
    to ``<save_folder>/<content stem>@<style stem>.jpg``. It is saved as a
    content | style | result strip when ``args.concat_refs`` is true, and as
    the generated image alone otherwise.

    Args:
        args: Parsed command-line options.
        cnt_folder: Folder of content images. The default is the modern-art
            test set; the real-painting set is ``./test/cnt_nga``.
        sty_folder: Folder of style images. The real-painting set is
            ``./test/sty_nga``.
        save_folder: Output folder; created if missing.
        skip_existing: If true, pairs whose output file already exists are
            not regenerated.

    Note:
        Every process launched by ``accelerate`` iterates over all pairs;
        the work is not sharded across processes.
    """
    accelerator = Accelerator()
    os.makedirs(save_folder, exist_ok=True)

    pipeline = DSTPipeline(
        args.model_type,
        accelerator.device,
        # True when accelerate was configured with a DeepSpeed plugin.
        accelerator.state.deepspeed_plugin is not None,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
    )

    sty_names = _list_files(sty_folder)
    cnt_names = _list_files(cnt_folder)
    with tqdm(
        total=len(sty_names) * len(cnt_names),
        disable=not accelerator.is_local_main_process,
    ) as progress:
        for sty_name in sty_names:
            # The style reference is identical for every content image, so
            # load and preprocess it once per style rather than once per pair.
            sty_img = _load_square_rgb(
                os.path.join(sty_folder, sty_name), args.width, args.height
            )
            sty_stem = os.path.splitext(sty_name)[0]
            for cnt_name in cnt_names:
                progress.update(1)
                save_name = os.path.join(
                    save_folder, f"{os.path.splitext(cnt_name)[0]}@{sty_stem}.jpg"
                )
                if skip_existing and os.path.exists(save_name):
                    continue
                cnt_img = _load_square_rgb(
                    os.path.join(cnt_folder, cnt_name), args.width, args.height
                )
                image_gen = pipeline(
                    prompt="",
                    width=args.width,
                    height=args.height,
                    guidance=args.guidance,
                    num_steps=args.num_steps,
                    seed=args.seed,
                    # Reference order: style image first, then content image.
                    ref_imgs=[sty_img, cnt_img],
                    pe=args.pe,
                )
                if args.concat_refs:
                    output = _side_by_side(
                        [cnt_img, sty_img, image_gen], args.width, args.height
                    )
                else:
                    # Save the generated image on its own. Assumes the pipeline
                    # returns a PIL image (the concat branch pastes it as one);
                    # convert("RGB") keeps it JPEG-compatible whatever its mode.
                    output = image_gen.convert("RGB")
                output.save(save_name)
if __name__ == "__main__":
    # Build the command-line interface from the InferenceArgs dataclass,
    # take the single parsed instance, and run inference with it.
    cli_parser = HfArgumentParser([InferenceArgs])
    parsed = cli_parser.parse_args_into_dataclasses()
    inference_args = parsed[0]
    main(inference_args)
|