Spaces:
Running
Running
File size: 4,148 Bytes
756b108 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | #!/usr/bin/env python3
"""Basic inference example: run the FASHN VTON v1.5 try-on pipeline on one person/garment image pair and save the results."""
import argparse
import sys
from pathlib import Path
from PIL import Image
from fashn_vton import TryOnPipeline
def _build_parser():
    """Return the argument parser for this example's command-line interface."""
    parser = argparse.ArgumentParser(
        description="FASHN VTON v1.5",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
python examples/basic_inference.py \\
--weights-dir ./weights \\
--person-image examples/data/model.webp \\
--garment-image examples/data/garment.webp \\
--category tops
""",
    )
    parser.add_argument("--weights-dir", type=str, required=True, help="Directory containing model weights")
    parser.add_argument("--person-image", type=str, required=True, help="Path to person image")
    parser.add_argument("--garment-image", type=str, required=True, help="Path to garment image")
    parser.add_argument(
        "--category",
        type=str,
        choices=["tops", "bottoms", "one-pieces"],
        required=True,
        help="Garment category to try on",
    )
    parser.add_argument(
        "--garment-photo-type",
        type=str,
        choices=["model", "flat-lay"],
        default="model",
        help="'model' if worn by person, 'flat-lay' for product shots",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs",
        help="Output directory (created if doesn't exist)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of output images to generate (1-4)",
    )
    parser.add_argument(
        "--num-timesteps",
        type=int,
        default=30,
        help="Diffusion steps: 20=fast, 30=balanced, 50=quality",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--guidance-scale", type=float, default=1.5, help="Classifier-free guidance strength")
    parser.add_argument(
        "--no-segmentation-free",
        action="store_false",
        dest="segmentation_free",
        default=True,
        help="Disable segmentation-free mode. Default (enabled) preserves body features and allows unconstrained garment volume",
    )
    parser.add_argument("--device", type=str, default=None, help="Device to use (cuda/cpu)")
    return parser


def _fail(*lines):
    """Print each message line to stderr and exit with status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def _load_rgb(path):
    """Open the image at *path* and return an RGB copy, closing the file handle afterwards."""
    # convert() loads the pixel data and returns a new image, so the source
    # file can be closed as soon as the conversion is done.
    with Image.open(path) as img:
        return img.convert("RGB")


def main(argv=None):
    """Parse CLI arguments, run the try-on pipeline, and save the generated images.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so calling ``main()`` behaves as before.

    Exits with status 2 on invalid arguments (argparse's ``parser.error``) and
    with status 1 when an input image or the weights directory is missing.
    Generated images are written to ``--output-dir`` as ``output_NN.png``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Enforce the ranges the help text documents, before any expensive work.
    if not 1 <= args.num_samples <= 4:
        parser.error("--num-samples must be between 1 and 4")
    if args.num_timesteps < 1:
        parser.error("--num-timesteps must be at least 1")

    # Validate inputs exist. is_file()/is_dir() rather than exists(): a directory
    # given as an image (or a file given as the weights dir) would otherwise pass
    # this check and only fail later, deeper in image loading or model setup.
    person_path = Path(args.person_image)
    garment_path = Path(args.garment_image)
    weights_path = Path(args.weights_dir)
    if not person_path.is_file():
        _fail(f"Error: Person image not found: {person_path}")
    if not garment_path.is_file():
        _fail(f"Error: Garment image not found: {garment_path}")
    if not weights_path.is_dir():
        _fail(
            f"Error: Weights directory not found: {weights_path}",
            f"Run: python scripts/download_weights.py --weights-dir {args.weights_dir}",
        )

    # Load images
    print("Loading images...")
    person_image = _load_rgb(person_path)
    garment_image = _load_rgb(garment_path)

    # Create pipeline (loads all models internally)
    print(f"Loading pipeline from {args.weights_dir}...")
    pipeline = TryOnPipeline(weights_dir=args.weights_dir, device=args.device)

    # Run inference
    result = pipeline(
        person_image=person_image,
        garment_image=garment_image,
        category=args.category,
        garment_photo_type=args.garment_photo_type,
        num_samples=args.num_samples,
        num_timesteps=args.num_timesteps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        segmentation_free=args.segmentation_free,
    )

    # Save outputs
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, output_image in enumerate(result.images):
        output_path = output_dir / f"output_{i:02d}.png"
        output_image.save(output_path)
        print(f"Saved: {output_path}")
    print(f"\nDone! Generated {len(result.images)} images.")


if __name__ == "__main__":
    main()
|