File size: 3,212 Bytes
924a45b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105


import os
import dataclasses
from typing import Literal
from accelerate import Accelerator
from transformers import HfArgumentParser
from PIL import Image
from dst.flux.pipeline import DSTPipeline
from tqdm import tqdm

@dataclasses.dataclass
class InferenceArgs:
    """Command-line options for the DST style-transfer inference script.

    Parsed with ``HfArgumentParser``, so each field becomes a ``--<name>``
    flag. Field order is kept stable because it defines the dataclass's
    positional constructor.
    """

    # Backbone variant; passed as the first argument to DSTPipeline.
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    # Output size in pixels; main() also resizes both reference images to this.
    width: int = 1024
    height: int = 1024
    # NOTE(review): not read by main() in this file; presumably meant as the
    # reference-image resolution — confirm before relying on it.
    ref_size: int = 1024
    # Number of sampling steps, forwarded to the pipeline as ``num_steps``.
    num_steps: int = 25
    # Guidance value forwarded to the pipeline. The float literal keeps the
    # default's type consistent with the annotation (argparse does not
    # convert non-string defaults, so ``4`` would stay an int).
    guidance: float = 4.0
    # Seed forwarded to the pipeline; the same seed is used for every pair.
    seed: int = 0
    # Forwarded to the DSTPipeline constructor.
    only_lora: bool = True
    # If True, main() saves [content | style | generated] side by side.
    concat_refs: bool = True
    # LoRA rank forwarded to the DSTPipeline constructor.
    lora_rank: int = 512
    # Forwarded to the pipeline as ``pe``; the meaning of 'd'/'h'/'w'/'o' is
    # defined by DSTPipeline (presumably a positional-encoding mode — TODO confirm).
    pe: Literal['d', 'h', 'w', 'o'] = 'd'



def crop_if_not_square(img):
    """Center-crop ``img`` to a square whose side is its shorter dimension.

    A square input is returned unchanged (the very same object). Otherwise
    the result of ``img.crop`` on the centered ``(left, top, right, bottom)``
    box is returned; with an odd size difference the extra pixel is dropped
    from the right/bottom edge (floor division).
    """
    width, height = img.size
    if width == height:
        return img

    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    return img.crop((x0, y0, x0 + side, y0 + side))


def _list_images(folder):
    """Return the sorted names of regular, non-hidden files in ``folder``.

    Sorting makes the processing (and output) order deterministic, which
    ``os.listdir`` alone does not guarantee. Hidden files such as
    ``.DS_Store`` and sub-directories are skipped because ``Image.open``
    cannot read them.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
    )


def _load_square_rgb(path, width, height):
    """Open ``path`` as RGB, center-crop it to a square and resize it to ``(width, height)``."""
    # ``with`` closes the underlying file; ``convert`` loads the pixel data
    # into a new in-memory image first, so it stays valid after the close.
    with Image.open(path) as img:
        square = crop_if_not_square(img.convert('RGB'))
    return square.resize((width, height))


def main(args: InferenceArgs):
    """Stylize every (content, style) image pair and save the results.

    For each style image in ``./test/sty/`` and each content image in
    ``./test/cnt/``, both images are center-cropped to a square, resized to
    ``(args.width, args.height)`` and passed to ``DSTPipeline`` as reference
    images (style first, then content) with an empty prompt. Each result is
    written to ``./output/<content-stem>@<style-stem>.jpg``. When
    ``args.concat_refs`` is true, the file is a strip
    ``[content | style | generated]``; otherwise it holds only the generated
    image. Existing output files are overwritten.

    Args:
        args: Parsed ``InferenceArgs``. ``args.ref_size`` is not read here.
    """
    accelerator = Accelerator()
    device = accelerator.device

    # Modern-art test set. For the real-paintings set, use
    # "./test/cnt_nga" and "./test/sty_nga" instead.
    test_cnt_folder = "./test/cnt/"
    test_sty_folder = "./test/sty/"
    save_folder = "./output/"
    os.makedirs(save_folder, exist_ok=True)

    pipeline = DSTPipeline(
        args.model_type,
        device,
        accelerator.state.deepspeed_plugin is not None,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
    )

    sty_names = _list_images(test_sty_folder)
    cnt_names = _list_images(test_cnt_folder)

    with tqdm(total=len(sty_names) * len(cnt_names)) as progress:
        for sty_name in sty_names:
            # The style image is the same for every content image of the
            # inner loop, so it is loaded and preprocessed once.
            sty_img_pil = _load_square_rgb(
                os.path.join(test_sty_folder, sty_name), args.width, args.height
            )
            sty_stem = os.path.splitext(sty_name)[0]

            for cnt_name in cnt_names:
                save_name = os.path.join(
                    save_folder, f"{os.path.splitext(cnt_name)[0]}@{sty_stem}.jpg"
                )
                cnt_img_pil = _load_square_rgb(
                    os.path.join(test_cnt_folder, cnt_name), args.width, args.height
                )

                image_gen = pipeline(
                    prompt="",
                    width=args.width,
                    height=args.height,
                    guidance=args.guidance,
                    num_steps=args.num_steps,
                    seed=args.seed,
                    ref_imgs=[sty_img_pil, cnt_img_pil],
                    pe=args.pe,
                )

                if args.concat_refs:
                    out_img = Image.new('RGB', (args.width * 3, args.height))
                    out_img.paste(cnt_img_pil, (0, 0))
                    out_img.paste(sty_img_pil, (args.width, 0))
                    out_img.paste(image_gen, (args.width * 2, 0))
                else:
                    # Without this branch the save below referenced an
                    # unbound name. ``image_gen`` is pasted as an image above,
                    # so it is presumably a PIL image with ``.save``.
                    out_img = image_gen

                out_img.save(save_name)
                progress.update(1)

if __name__ == "__main__":
    # Parse the CLI into a single InferenceArgs instance and run inference.
    (cli_args,) = HfArgumentParser([InferenceArgs]).parse_args_into_dataclasses()
    main(cli_args)