File size: 4,590 Bytes
fcfea15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Local inference entrypoint for the clean JoyAI-Image release."""

from __future__ import annotations

import argparse
import os
import sys
import time
import warnings
from pathlib import Path

import torch
from PIL import Image


# Directory containing this script, and the bundled ``src`` tree next to it
# (presumably where the ``infer_runtime`` / ``modules`` packages imported in
# ``main`` live — verify against repository layout).
ROOT_DIR = Path(__file__).resolve().parent
SRC_DIR = ROOT_DIR.joinpath('src')

# Prepend ``src`` to the import path (once) so the deferred imports in
# ``main`` resolve when the script is run directly.
if str(SRC_DIR) not in sys.path:
    sys.path[0:0] = [str(SRC_DIR)]

# Silence every warning category for the whole process.
warnings.simplefilter('ignore')


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for a single local inference run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the original CLI behavior; passing a
            list allows programmatic invocation and testing.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: On ``--help`` or invalid/missing arguments (argparse
            behavior).
    """
    parser = argparse.ArgumentParser(description='Run local inference without FastAPI.')
    parser.add_argument('--ckpt-root', required=True, help='Checkpoint root.')
    parser.add_argument('--prompt', required=True, help='Edit prompt or T2I prompt.')
    # When --image is omitted the run is text-to-image; otherwise image editing.
    parser.add_argument('--image', help='Optional input image path for image editing.')
    parser.add_argument('--output', default='example.png', help='Output image path.')
    parser.add_argument('--height', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--width', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--guidance-scale', type=float, default=5.0)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--neg-prompt', default='')
    parser.add_argument('--basesize', type=int, default=1024, help='Resize bucket base size for image editing inputs.')
    parser.add_argument('--rewrite-prompt', action='store_true')
    parser.add_argument('--config', help='Optional config path. Defaults to <ckpt-root>/infer_config.py.')
    parser.add_argument('--rewrite-model', default='gpt-5')
    parser.add_argument('--hsdp-shard-dim', type=int, help='Override config hsdp_shard_dim for multi-GPU FSDP inference.')
    # argparse treats ``None`` as "use sys.argv[1:]", so the default path is
    # identical to calling parse_args() with no arguments.
    return parser.parse_args(argv)


def load_input_image(image_path: str | None) -> Image.Image | None:
    """Load the optional edit-source image as an RGB PIL image.

    Args:
        image_path: Path to the input image. ``None`` or an empty string means
            no input image (text-to-image run).

    Returns:
        A fully loaded RGB ``PIL.Image.Image`` that does not depend on the
        source file, or ``None`` when no path was given.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file
        (e.g. ``FileNotFoundError``, ``PIL.UnidentifiedImageError``).
    """
    if not image_path:
        return None
    # ``convert`` yields a new, fully loaded image, so the source file handle
    # can be closed deterministically here rather than relying on Pillow's
    # lazy-close behavior (which keeps the file open for some formats).
    with Image.open(image_path) as source:
        return source.convert('RGB')


def is_rank0() -> bool:
    """Return True when the ``RANK`` environment variable is 0 or unset.

    An unset ``RANK`` is treated as rank 0, so single-process runs count as
    the primary rank.
    """
    rank = int(os.getenv('RANK', '0'))
    return rank == 0


def resolve_device() -> torch.device:
    """Select the torch device for this process.

    If CUDA is available, binds the process to GPU index ``LOCAL_RANK``
    (default 0) via ``torch.cuda.set_device`` and returns that CUDA device.
    Otherwise returns the CPU device.
    """
    if torch.cuda.is_available():
        gpu_index = int(os.environ.get('LOCAL_RANK', '0'))
        torch.cuda.set_device(gpu_index)
        return torch.device('cuda', gpu_index)
    return torch.device('cpu')


def main() -> None:
    """Run one local inference end to end from command-line arguments.

    Steps:
    - Parse the CLI and load the optional input image.
    - Load settings (``--ckpt-root`` / ``--config``), pick the device and
      possibly initialise distributed state.
    - Build the model, optionally rewrite the prompt, and run inference.
    - On rank 0 only, save the result to ``--output``.

    Distributed state is cleaned up in ``finally`` if it was initialised here.
    """
    args = parse_args()

    # Deferred imports: they happen after argument parsing, so --help and
    # argument errors exit without importing these packages. They are expected
    # to resolve via the ``src`` directory added to sys.path at module load.
    from infer_runtime.model import InferenceParams, build_model
    from infer_runtime.settings import load_settings
    from modules.utils import maybe_init_distributed, clean_dist_env
    from modules.models.attention import describe_attention_backend

    dist_initialized = False
    try:
        # Open the input image before any expensive setup so that a bad
        # --image path fails immediately, not after the model has been built.
        input_image = load_input_image(args.image)

        settings = load_settings(
            ckpt_root=args.ckpt_root,
            config_path=args.config,
            rewrite_model=args.rewrite_model,
            default_seed=args.seed,
        )
        device = resolve_device()
        dist_initialized = maybe_init_distributed()

        if is_rank0():
            print(f'Chosen device: {device}')
            print(f'Attention backend: {describe_attention_backend()}')
            print(f'Config path: {settings.config_path}')
            print(f'Checkpoint path: {settings.ckpt_path}')
            if args.hsdp_shard_dim is not None:
                print(f'Override hsdp_shard_dim: {args.hsdp_shard_dim}')

        model = build_model(
            settings,
            device=device,
            hsdp_shard_dim_override=args.hsdp_shard_dim,
        )
        effective_prompt = model.maybe_rewrite_prompt(args.prompt, input_image, args.rewrite_prompt)

        # Time only the inference call (model build and prompt rewrite are
        # excluded). perf_counter is monotonic and high-resolution, unlike
        # time.time(), which can jump with wall-clock adjustments.
        start_time = time.perf_counter()
        output_image = model.infer(
            InferenceParams(
                prompt=effective_prompt,
                image=input_image,
                height=args.height,
                width=args.width,
                steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                neg_prompt=args.neg_prompt,
                basesize=args.basesize,
            )
        )
        elapsed = time.perf_counter() - start_time

        # Only rank 0 writes the file and reports, so multi-process runs
        # produce a single output.
        if is_rank0():
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_image.save(output_path)
            print(f'Prompt used: {effective_prompt}')
            print(f'Saved output: {output_path}')
            print(f'Time taken: {elapsed:.2f} seconds')
    finally:
        if dist_initialized:
            clean_dist_env()


if __name__ == '__main__':
    # Run inference only when executed as a script; importing this module
    # just defines the helpers above (plus the module-level path/warnings setup).
    main()