File size: 2,041 Bytes
12510fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import argparse

import torch
from PIL import Image

from config import CONFIG
from model import load_from_checkpoint
from inference_utils import full_inference, make_side_by_side


def run_inference(
    input_image: str,
    output_dir: str = "./output",
    ckpt_path: str = None,
    device: str = None,
    image_size: int = None,
    width_mult: float = None,
):
    """Run the model on one image and save a side-by-side result.

    Any of ``ckpt_path``, ``device``, ``image_size`` and ``width_mult`` left
    as ``None`` is filled in from ``CONFIG`` (keys ``"checkpoint_path"``,
    ``"device"``, ``"image_size"`` and ``"width_mult"`` respectively).

    Args:
        input_image: Path of the image file to process.
        output_dir: Directory the result is written to; created if missing.
        ckpt_path: Checkpoint path handed to ``load_from_checkpoint``.
        device: Device the model is moved to and inference is run on.
        image_size: Image size forwarded to ``full_inference``.
        width_mult: Width multiplier forwarded to ``load_from_checkpoint``.

    Returns:
        The ``outputs`` value produced by ``full_inference``.

    Side effects:
        Writes ``<output_dir>/<input basename>_result.png`` and prints
        progress messages to stdout.
    """
    if device is None:
        device = CONFIG["device"]
    if image_size is None:
        image_size = CONFIG["image_size"]
    if width_mult is None:
        width_mult = CONFIG["width_mult"]
    if ckpt_path is None:
        ckpt_path = CONFIG["checkpoint_path"]

    # Decode the input before loading the checkpoint so an unreadable or
    # missing image fails before the model is loaded. The context manager
    # closes the underlying file handle; ``convert`` returns an independent
    # in-memory copy, so ``img`` stays usable after the file is closed.
    with Image.open(input_image) as source:
        img = source.convert("RGB")
    base = os.path.splitext(os.path.basename(input_image))[0]

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model from {ckpt_path}...")
    model = load_from_checkpoint(ckpt_path, device=device, width_mult=width_mult)
    model.to(device)
    model.eval()

    # Gradients are not needed for inference.
    with torch.no_grad():
        inp_img, outputs = full_inference(model, img, image_size, device)

    side_by_side = make_side_by_side(inp_img, outputs)
    out_path = os.path.join(output_dir, f"{base}_result.png")
    side_by_side.save(out_path)
    print(f"Saved result to {out_path}")

    return outputs


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to
    # run_inference. Options left at their None default are resolved from
    # CONFIG inside run_inference.
    parser = argparse.ArgumentParser(description="Diffusion Detailer Inference")
    parser.add_argument("input", type=str, help="Path to input image")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint")
    parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
    parser.add_argument("--image_size", type=int, default=None, help="Image size")
    # run_inference accepts width_mult, so expose it on the CLI as well;
    # omitting the flag keeps the previous behaviour (CONFIG value is used).
    parser.add_argument(
        "--width_mult",
        type=float,
        default=None,
        help="Width multiplier passed to the checkpoint loader",
    )
    args = parser.parse_args()

    run_inference(
        input_image=args.input,
        output_dir=args.output_dir,
        ckpt_path=args.ckpt_path,
        device=args.device,
        image_size=args.image_size,
        width_mult=args.width_mult,
    )