| import os |
| import argparse |
|
|
| import torch |
| from PIL import Image |
|
|
| from config import CONFIG |
| from model import load_from_checkpoint |
| from inference_utils import full_inference, make_side_by_side |
|
|
|
|
def run_inference(
    input_image: str,
    output_dir: str = "./output",
    ckpt_path: "str | None" = None,
    device: "str | None" = None,
    image_size: "int | None" = None,
    width_mult: "float | None" = None,
):
    """Run the model on one image and save a side-by-side result PNG.

    Any argument left as ``None`` falls back to the matching ``CONFIG`` entry
    (``"device"``, ``"image_size"``, ``"width_mult"``, ``"checkpoint_path"``).

    Args:
        input_image: Path to the input image; it is loaded and converted to RGB.
        output_dir: Directory for the result image; created if missing.
        ckpt_path: Checkpoint path handed to ``load_from_checkpoint``.
        device: Device the model is moved to and passed to ``full_inference``.
        image_size: Size argument forwarded to ``full_inference``.
        width_mult: Width multiplier forwarded to ``load_from_checkpoint``.

    Returns:
        The ``outputs`` value returned by ``full_inference``. Its structure is
        defined in ``inference_utils`` (not visible here).

    Side effects:
        Creates ``output_dir`` if needed, writes
        ``<output_dir>/<input file stem>_result.png``, and prints progress.
    """
    if device is None:
        device = CONFIG["device"]
    if image_size is None:
        image_size = CONFIG["image_size"]
    if width_mult is None:
        width_mult = CONFIG["width_mult"]
    if ckpt_path is None:
        ckpt_path = CONFIG["checkpoint_path"]

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model from {ckpt_path}...")
    model = load_from_checkpoint(ckpt_path, device=device, width_mult=width_mult)
    model.to(device)
    model.eval()

    # Image.open is lazy and may keep the file handle open (notably for
    # multi-frame formats). convert() returns a new, fully loaded image, so
    # the source can be closed immediately via the context manager.
    with Image.open(input_image) as src:
        img = src.convert("RGB")
    base = os.path.splitext(os.path.basename(input_image))[0]

    # Only the forward pass runs without autograd tracking.
    with torch.no_grad():
        inp_img, outputs = full_inference(model, img, image_size, device)

    side_by_side = make_side_by_side(inp_img, outputs)
    out_path = os.path.join(output_dir, f"{base}_result.png")
    side_by_side.save(out_path)
    print(f"Saved result to {out_path}")

    return outputs
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run a single-image
    # inference. Options left unset (None) fall back to CONFIG inside
    # run_inference.
    parser = argparse.ArgumentParser(description="Diffusion Detailer Inference")
    parser.add_argument("input", type=str, help="Path to input image")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint")
    parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
    parser.add_argument("--image_size", type=int, default=None, help="Image size")
    # run_inference accepts width_mult, but the CLI previously had no way to
    # set it. Any value other than CONFIG["width_mult"] required editing
    # CONFIG. Default None keeps the old CONFIG fallback.
    parser.add_argument(
        "--width_mult",
        type=float,
        default=None,
        help="Model width multiplier (defaults to CONFIG value)",
    )
    args = parser.parse_args()

    run_inference(
        input_image=args.input,
        output_dir=args.output_dir,
        ckpt_path=args.ckpt_path,
        device=args.device,
        image_size=args.image_size,
        width_mult=args.width_mult,
    )
|
|