# JoyAI-Image-Edit-Space / inference.py
# Last commit fcfea15 by stevengrove: "Initial commit with Xet-tracked image assets"
"""Local inference entrypoint for the clean JoyAI-Image release."""
from __future__ import annotations
import argparse
import os
import sys
import time
import warnings
from pathlib import Path
import torch
from PIL import Image
# Absolute directory holding this script; project packages live in its
# ``src`` subdirectory.
ROOT_DIR = Path(__file__).resolve().parent
SRC_DIR = ROOT_DIR.joinpath('src')

# Put ``src`` at the front of the import path (highest priority), but only
# once, so repeated imports of this module do not stack duplicate entries.
if str(SRC_DIR) not in sys.path:
    sys.path[:0] = [str(SRC_DIR)]

# Silence every Python warning for the whole process.
warnings.filterwarnings('ignore')
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Only ``--ckpt-root`` and ``--prompt`` are mandatory. ``--image`` is the
    optional input for image editing; per their help text, ``--height`` and
    ``--width`` only apply to text-to-image runs.

    Returns:
        The parsed arguments namespace.
    """
    cli = argparse.ArgumentParser(description='Run local inference without FastAPI.')
    add = cli.add_argument

    # Required inputs.
    add('--ckpt-root', required=True, help='Checkpoint root.')
    add('--prompt', required=True, help='Edit prompt or T2I prompt.')
    # Input / output image paths.
    add('--image', help='Optional input image path for image editing.')
    add('--output', default='example.png', help='Output image path.')
    # Text-to-image canvas size.
    add('--height', type=int, default=1024, help='Only used for text-to-image inference.')
    add('--width', type=int, default=1024, help='Only used for text-to-image inference.')
    # Sampling controls.
    add('--steps', type=int, default=50)
    add('--guidance-scale', type=float, default=5.0)
    add('--seed', type=int, default=42)
    add('--neg-prompt', default='')
    add('--basesize', type=int, default=1024, help='Resize bucket base size for image editing inputs.')
    # Prompt rewriting and configuration.
    add('--rewrite-prompt', action='store_true')
    add('--config', help='Optional config path. Defaults to <ckpt-root>/infer_config.py.')
    add('--rewrite-model', default='gpt-5')
    # Distributed (multi-GPU) inference.
    add('--hsdp-shard-dim', type=int, help='Override config hsdp_shard_dim for multi-GPU FSDP inference.')

    namespace = cli.parse_args()
    return namespace
def load_input_image(image_path: str | None) -> Image.Image | None:
    """Load the optional edit-input image as an RGB PIL image.

    Args:
        image_path: Path to the input image, or ``None`` / empty string when
            no input image is supplied (text-to-image runs).

    Returns:
        An RGB copy of the image, or ``None`` when no path was given.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    if not image_path:
        return None
    # ``Image.open`` is lazy and holds the file handle open; for some formats
    # (e.g. multi-frame images) Pillow does not close it after loading. The
    # context manager releases the handle once ``convert`` has produced a
    # fully decoded, independent RGB copy.
    with Image.open(image_path) as img:
        return img.convert('RGB')
def is_rank0() -> bool:
    """Return True when this process is the global rank-0 worker.

    The rank comes from the ``RANK`` environment variable (typically set by
    distributed launchers such as ``torchrun``). When it is unset, the process
    counts as rank 0, so single-process runs always qualify.
    """
    rank_text = os.environ.get('RANK', '0')
    return not int(rank_text)
def resolve_device() -> torch.device:
    """Pick the torch device for this process.

    Without CUDA, returns the CPU device. With CUDA, reads ``LOCAL_RANK``
    (default ``0``), makes that GPU the current CUDA device via
    ``torch.cuda.set_device``, and returns ``cuda:<LOCAL_RANK>``.
    """
    if torch.cuda.is_available():
        gpu_index = int(os.environ.get('LOCAL_RANK', '0'))
        # Side effect: later CUDA allocations default to this GPU.
        torch.cuda.set_device(gpu_index)
        return torch.device(f'cuda:{gpu_index}')
    return torch.device('cpu')
def main() -> None:
    """Command-line entry point: run one inference pass and save the result.

    The run proceeds in this order:

    1. Parse CLI arguments and load settings for the checkpoint.
    2. Select the device and optionally initialise distributed execution.
    3. Load the optional input image.
    4. Build the model and optionally rewrite the prompt.
    5. Run inference, then on rank 0 save the image and print a summary.

    Distributed state initialised here is torn down in ``finally``, even when
    a later step raises.
    """
    args = parse_args()

    # Project modules resolve via the ``src`` directory that was added to
    # ``sys.path`` when this module was imported.
    from infer_runtime.model import InferenceParams, build_model
    from infer_runtime.settings import load_settings
    from modules.utils import maybe_init_distributed, clean_dist_env
    from modules.models.attention import describe_attention_backend

    dist_initialized = False
    try:
        settings = load_settings(
            ckpt_root=args.ckpt_root,
            config_path=args.config,
            rewrite_model=args.rewrite_model,
            default_seed=args.seed,
        )
        # Pin this process to its GPU before any distributed initialisation.
        device = resolve_device()
        dist_initialized = maybe_init_distributed()
        if is_rank0():
            print(f'Chosen device: {device}')
            print(f'Attention backend: {describe_attention_backend()}')
            print(f'Config path: {settings.config_path}')
            print(f'Checkpoint path: {settings.ckpt_path}')
            if args.hsdp_shard_dim is not None:
                print(f'Override hsdp_shard_dim: {args.hsdp_shard_dim}')

        # Load the input image before the expensive model build, so a bad
        # ``--image`` path fails immediately instead of after loading weights.
        input_image = load_input_image(args.image)

        model = build_model(
            settings,
            device=device,
            hsdp_shard_dim_override=args.hsdp_shard_dim,
        )
        effective_prompt = model.maybe_rewrite_prompt(args.prompt, input_image, args.rewrite_prompt)

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump if the system time is adjusted.
        start_time = time.perf_counter()
        output_image = model.infer(
            InferenceParams(
                prompt=effective_prompt,
                image=input_image,
                height=args.height,
                width=args.width,
                steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                neg_prompt=args.neg_prompt,
                basesize=args.basesize,
            )
        )
        elapsed = time.perf_counter() - start_time

        # Every rank runs inference; only rank 0 writes the file and reports.
        if is_rank0():
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_image.save(output_path)
            print(f'Prompt used: {effective_prompt}')
            print(f'Saved output: {output_path}')
            print(f'Time taken: {elapsed:.2f} seconds')
    finally:
        if dist_initialized:
            clean_dist_env()


if __name__ == '__main__':
    main()