salu_Image_Editter / scripts /inference.py
Raghava Pulugu
Clean deployment
cad10d9
Raw
History Blame Contribute Delete
2.73 kB
"""
inference.py - command-line script that edits an image according to a text prompt.
Usage:
python scripts/inference.py \
--image input.jpg \
--prompt "make it look like sunset" \
--checkpoint checkpoints/diffusion_final.pt \
--vae-checkpoint checkpoints/vae_final.pt \
--output outputs/edited.png
"""
import argparse
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from PIL import Image
from omegaconf import OmegaConf
from model.pipeline import EditPipeline
_DEFAULT_CONFIG = "config/default.yaml"


def _resolve_path(path):
    """Return *path*, falling back to the project root for missing relative paths.

    The script's default paths (e.g. ``config/default.yaml``) are relative to
    the project root, so running from another working directory used to fail.
    A path that exists as given, an absolute path, or a path found in neither
    location is returned unchanged, so any later error shows what the user typed.
    """
    if os.path.isfile(path) or os.path.isabs(path):
        return path
    # Same root that is put on sys.path at the top of this script.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    candidate = os.path.join(project_root, path)
    return candidate if os.path.isfile(candidate) else path


def _build_parser():
    """Build the argument parser for the image-edit CLI."""
    parser = argparse.ArgumentParser(description="Edit an image with a text prompt")
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--prompt", type=str, required=True, help="Edit instruction")
    parser.add_argument("--checkpoint", type=str, required=True, help="Diffusion model checkpoint")
    parser.add_argument("--vae-checkpoint", type=str, default=None, help="VAE checkpoint (if separate)")
    parser.add_argument(
        "--config",
        type=str,
        default=_DEFAULT_CONFIG,
        help=f"YAML config merged on top of {_DEFAULT_CONFIG}",
    )
    parser.add_argument("--output", type=str, default="outputs/edited.png", help="Where to save the edited image")
    parser.add_argument("--steps", type=int, default=50, help="DDIM sampling steps")
    parser.add_argument("--text-scale", type=float, default=7.5, help="Text guidance scale")
    parser.add_argument("--image-scale", type=float, default=1.5, help="Image guidance scale")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device (defaults to cuda when available, else cpu)",
    )
    return parser


def _load_config(config_path):
    """Load the default config, merge *config_path* over it, return a plain dict.

    Interpolations are resolved (``resolve=True``) before returning.
    """
    base_path = _resolve_path(_DEFAULT_CONFIG)
    override_path = _resolve_path(config_path)
    merged = OmegaConf.load(base_path)
    # Merging the base file onto itself is a no-op; skip the second read.
    if os.path.realpath(override_path) != os.path.realpath(base_path):
        merged = OmegaConf.merge(merged, OmegaConf.load(override_path))
    return OmegaConf.to_container(merged, resolve=True)


def main(argv=None):
    """Edit one image with a text prompt and save the result.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line.

    Exits with status 2 (via ``parser.error``) when the input image does not
    exist or ``--steps`` is not positive. Both are checked before the model loads.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Fail fast on bad input instead of after the (slow) checkpoint load.
    if not os.path.isfile(args.image):
        parser.error(f"input image not found: {args.image}")
    if args.steps < 1:
        parser.error(f"--steps must be a positive integer, got {args.steps}")

    config = _load_config(args.config)

    print(f"Loading model from {args.checkpoint}...")
    pipeline = EditPipeline.from_checkpoint(
        args.checkpoint,
        vae_checkpoint_path=args.vae_checkpoint,
        config=config,
        device=args.device,
    )

    # convert() returns a new image, so the source file can be closed here.
    with Image.open(args.image) as src:
        image = src.convert("RGB")
    print(f"Input: {args.image} ({image.width}x{image.height})")
    print(f"Prompt: {args.prompt}")
    print(f"Steps: {args.steps}, Text scale: {args.text_scale}, Image scale: {args.image_scale}")

    edited = pipeline.edit(
        image=image,
        prompt=args.prompt,
        num_steps=args.steps,
        text_guidance_scale=args.text_scale,
        image_guidance_scale=args.image_scale,
        seed=args.seed,
    )

    out_dir = os.path.dirname(args.output)
    if out_dir:  # a bare filename saves into the current directory
        os.makedirs(out_dir, exist_ok=True)
    edited.save(args.output)
    print(f"Saved edited image: {args.output}")


if __name__ == "__main__":
    main()