"""
JoyAI-Image Edit Plus — Diffusers Inference Script

Multi-image instruction-guided editing: provide reference images and a text
instruction to generate a new image that combines elements from the references.
"""

import argparse

import torch
from PIL import Image

from diffusers import JoyImageEditPlusPipeline


def parse_args():
    parser = argparse.ArgumentParser(description="JoyAI-Image Edit Plus inference")
    parser.add_argument(
        "--model_path",
        type=str,
        default="jdopensource/JoyAI-Image-Edit-Plus-Diffusers",
        help="Path or HuggingFace repo ID of the model",
    )
    parser.add_argument(
        "--images",
        type=str,
        nargs="+",
        required=True,
        help="Paths to reference images (1-6 images)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text instruction for editing",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, blurry, deformed",
        help="Negative prompt",
    )
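    parser.add_argument("--num_inference_steps", type=int, default=30, help="Number of denoising steps")
    parser.add_argument("--guidance_scale", type=float, default=4.0, help="Guidance scale passed to the pipeline")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducible sampling")
    parser.add_argument("--output", type=str, default="output.png", help="Path of the saved output image")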
    return parser.parse_args()
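

# Illustrative sketch (not part of the original script): the --images help text
# mentions 1-6 reference images, but argparse's nargs="+" only enforces the
# lower bound. A hypothetical helper such as check_image_count could reject
# out-of-range inputs before the model is loaded. The helper name, the default
# limits, and the choice of ValueError are assumptions; main() does not call it.
def check_image_count(paths, min_images=1, max_images=6):
    """Return paths as a list, or raise ValueError if the count is out of range."""
    paths = list(paths)
    if not min_images <= len(paths) <= max_images:
        raise ValueError(
            f"expected {min_images}-{max_images} reference images, got {len(paths)}"
        )
    return paths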


def main():
    args = parse_args()

    # Load the pipeline weights in bfloat16 and move them to the GPU (requires CUDA).
    pipe = JoyImageEditPlusPipeline.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

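    # Load every reference image and normalize it to 3-channel RGB.
    images = [Image.open(p).convert("RGB") for p in args.images]

    # Derive the output resolution from the last reference image.
    # Note: _get_bucket_size is a private pipeline method and may change between releases.
    target_h, target_w = pipe._get_bucket_size(images[-1])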

    result = pipe(
        images=images,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=target_h,
        width=target_w,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        # A seeded generator makes repeated runs with the same inputs reproducible.
        generator=torch.Generator(device="cuda").manual_seed(args.seed),
    )

    result.images[0].save(args.output)
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()