#!/usr/bin/env python3 """Basic inference example.""" import argparse import sys from pathlib import Path from PIL import Image from fashn_vton import TryOnPipeline def main(): parser = argparse.ArgumentParser( description="FASHN VTON v1.5", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Example: python examples/basic_inference.py \\ --weights-dir ./weights \\ --person-image examples/data/model.webp \\ --garment-image examples/data/garment.webp \\ --category tops """, ) parser.add_argument("--weights-dir", type=str, required=True, help="Directory containing model weights") parser.add_argument("--person-image", type=str, required=True, help="Path to person image") parser.add_argument("--garment-image", type=str, required=True, help="Path to garment image") parser.add_argument( "--category", type=str, choices=["tops", "bottoms", "one-pieces"], required=True, help="Garment category to try on", ) parser.add_argument( "--garment-photo-type", type=str, choices=["model", "flat-lay"], default="model", help="'model' if worn by person, 'flat-lay' for product shots", ) parser.add_argument( "--output-dir", type=str, default="outputs", help="Output directory (created if doesn't exist)", ) parser.add_argument( "--num-samples", type=int, default=1, help="Number of output images to generate (1-4)", ) parser.add_argument( "--num-timesteps", type=int, default=30, help="Diffusion steps: 20=fast, 30=balanced, 50=quality", ) parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") parser.add_argument("--guidance-scale", type=float, default=1.5, help="Classifier-free guidance strength") parser.add_argument( "--no-segmentation-free", action="store_false", dest="segmentation_free", default=True, help="Disable segmentation-free mode. Default (enabled) preserves body features and allows unconstrained garment volume", ) parser.add_argument("--device", type=str, default=None, help="Device to use (cuda/cpu)") args = parser.parse_args() # Validate inputs exist person_path = Path(args.person_image) garment_path = Path(args.garment_image) weights_path = Path(args.weights_dir) if not person_path.exists(): print(f"Error: Person image not found: {person_path}") sys.exit(1) if not garment_path.exists(): print(f"Error: Garment image not found: {garment_path}") sys.exit(1) if not weights_path.exists(): print(f"Error: Weights directory not found: {weights_path}") print(f"Run: python scripts/download_weights.py --weights-dir {args.weights_dir}") sys.exit(1) # Load images print("Loading images...") person_image = Image.open(args.person_image).convert("RGB") garment_image = Image.open(args.garment_image).convert("RGB") # Create pipeline (loads all models internally) print(f"Loading pipeline from {args.weights_dir}...") pipeline = TryOnPipeline(weights_dir=args.weights_dir, device=args.device) # Run inference result = pipeline( person_image=person_image, garment_image=garment_image, category=args.category, garment_photo_type=args.garment_photo_type, num_samples=args.num_samples, num_timesteps=args.num_timesteps, guidance_scale=args.guidance_scale, seed=args.seed, segmentation_free=args.segmentation_free, ) # Save outputs output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) for i, output_image in enumerate(result.images): output_path = output_dir / f"output_{i:02d}.png" output_image.save(output_path) print(f"Saved: {output_path}") print(f"\nDone! Generated {len(result.images)} images.") if __name__ == "__main__": main()