Spaces:
Sleeping
Sleeping
r"""
inference.py - CLI inference script.

Usage:
    python scripts/inference.py \
        --image input.jpg \
        --prompt "make it look like sunset" \
        --checkpoint checkpoints/diffusion_final.pt \
        --vae-checkpoint checkpoints/vae_final.pt \
        --output outputs/edited.png
"""
| import argparse | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| from PIL import Image | |
| from omegaconf import OmegaConf | |
| from model.pipeline import EditPipeline | |
# Project root: the parent of the directory holding this script (same
# computation as the sys.path setup at the top of the file).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_DEFAULT_CONFIG = "config/default.yaml"


def _positive_int(value):
    """argparse ``type=`` callback: parse *value* as an int that is >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _resolve_path(path):
    """Return *path* if it exists as given, else try it under the project root.

    The working directory is checked first, so existing invocations from the
    project root behave exactly as before. The project-root fallback lets the
    script be launched from any directory. If neither location exists, the
    path is returned unchanged so the resulting error names what the user
    supplied.
    """
    if os.path.isabs(path) or os.path.exists(path):
        return path
    candidate = os.path.join(_PROJECT_ROOT, path)
    return candidate if os.path.exists(candidate) else path


def _build_parser():
    """Create the argument parser for the image-editing CLI."""
    parser = argparse.ArgumentParser(description="Edit an image with a text prompt")
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--prompt", type=str, required=True, help="Edit instruction")
    parser.add_argument("--checkpoint", type=str, required=True, help="Diffusion model checkpoint")
    parser.add_argument("--vae-checkpoint", type=str, default=None, help="VAE checkpoint (if separate)")
    parser.add_argument("--config", type=str, default=_DEFAULT_CONFIG)
    parser.add_argument("--output", type=str, default="outputs/edited.png")
    parser.add_argument("--steps", type=_positive_int, default=50, help="DDIM sampling steps")
    parser.add_argument("--text-scale", type=float, default=7.5, help="Text guidance scale")
    parser.add_argument("--image-scale", type=float, default=1.5, help="Image guidance scale")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return parser


def _load_config(config_path):
    """Load the default config and merge *config_path* over it.

    Returns:
        A plain container (``OmegaConf.to_container(..., resolve=True)``) with
        interpolations resolved, as passed to ``EditPipeline.from_checkpoint``.
    """
    base_config = OmegaConf.load(_resolve_path(_DEFAULT_CONFIG))
    override_config = OmegaConf.load(_resolve_path(config_path))
    return OmegaConf.to_container(OmegaConf.merge(base_config, override_config), resolve=True)


def main(argv=None):
    """Edit an image with a text prompt and save the result.

    Parses the command line, validates and loads the input image, loads the
    merged config and the ``EditPipeline`` checkpoint, runs
    ``pipeline.edit`` and writes the returned image to ``--output``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            ``None`` keeps normal command-line behaviour.

    Raises:
        SystemExit: (exit code 2) on invalid arguments or a missing input image.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Check the input before the (potentially slow) model load so that a
    # mistyped path fails immediately with a usage error.
    if not os.path.isfile(args.image):
        parser.error(f"input image not found: {args.image}")

    config = _load_config(args.config)

    # Image.open is lazy and keeps the file open; convert() returns a new,
    # fully loaded image, so the file handle can be released right away.
    with Image.open(args.image) as source:
        image = source.convert("RGB")

    print(f"Loading model from {args.checkpoint}...")
    pipeline = EditPipeline.from_checkpoint(
        args.checkpoint,
        vae_checkpoint_path=args.vae_checkpoint,
        config=config,
        device=args.device,
    )

    print(f"Input: {args.image} ({image.size})")
    print(f"Prompt: {args.prompt}")
    print(f"Steps: {args.steps}, Text scale: {args.text_scale}, Image scale: {args.image_scale}")
    edited = pipeline.edit(
        image=image,
        prompt=args.prompt,
        num_steps=args.steps,
        text_guidance_scale=args.text_scale,
        image_guidance_scale=args.image_scale,
        seed=args.seed,
    )

    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    edited.save(args.output)
    print(f"Saved edited image: {args.output}")


if __name__ == "__main__":
    main()