File size: 4,148 Bytes
756b108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""Basic inference example."""

import argparse
import sys
from pathlib import Path

from PIL import Image

from fashn_vton import TryOnPipeline


def _build_parser():
    """Return the argument parser for this example script."""
    parser = argparse.ArgumentParser(
        description="FASHN VTON v1.5",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
    python examples/basic_inference.py \\
        --weights-dir ./weights \\
        --person-image examples/data/model.webp \\
        --garment-image examples/data/garment.webp \\
        --category tops
        """,
    )
    parser.add_argument("--weights-dir", type=str, required=True, help="Directory containing model weights")
    parser.add_argument("--person-image", type=str, required=True, help="Path to person image")
    parser.add_argument("--garment-image", type=str, required=True, help="Path to garment image")
    parser.add_argument(
        "--category",
        type=str,
        choices=["tops", "bottoms", "one-pieces"],
        required=True,
        help="Garment category to try on",
    )
    parser.add_argument(
        "--garment-photo-type",
        type=str,
        choices=["model", "flat-lay"],
        default="model",
        help="'model' if worn by person, 'flat-lay' for product shots",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs",
        help="Output directory (created if doesn't exist)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of output images to generate (1-4)",
    )
    parser.add_argument(
        "--num-timesteps",
        type=int,
        default=30,
        help="Diffusion steps: 20=fast, 30=balanced, 50=quality",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--guidance-scale", type=float, default=1.5, help="Classifier-free guidance strength")
    parser.add_argument(
        "--no-segmentation-free",
        action="store_false",
        dest="segmentation_free",
        default=True,
        help="Disable segmentation-free mode. Default (enabled) preserves body features and allows unconstrained garment volume",
    )
    parser.add_argument("--device", type=str, default=None, help="Device to use (cuda/cpu)")
    return parser


def _fail(*lines):
    """Print each message line to stderr and exit with status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def _load_rgb(path):
    """Open the image at *path*, return an RGB copy and close the file."""
    with Image.open(path) as img:
        # convert() returns a new image, so the result stays valid after the
        # underlying file handle is closed by the context manager.
        return img.convert("RGB")


def main(argv=None):
    """Run virtual try-on for one person/garment pair and save the outputs.

    Args:
        argv: Command-line arguments excluding the program name. When None,
            argparse reads ``sys.argv[1:]``.

    Exits with status 2 on invalid arguments (argparse convention) and with
    status 1 when an input image or the weights directory is missing.
    Output images are written as ``output_NN.png`` into ``--output-dir``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Range-check numeric options before anything expensive happens, so a
    # typo fails immediately instead of after the model weights are loaded.
    if not 1 <= args.num_samples <= 4:
        parser.error(f"--num-samples must be between 1 and 4 (got {args.num_samples})")
    if args.num_timesteps < 1:
        parser.error(f"--num-timesteps must be a positive integer (got {args.num_timesteps})")

    # Validate inputs exist (and are of the expected kind: file vs. directory)
    person_path = Path(args.person_image)
    garment_path = Path(args.garment_image)
    weights_path = Path(args.weights_dir)

    if not person_path.is_file():
        _fail(f"Error: Person image not found: {person_path}")
    if not garment_path.is_file():
        _fail(f"Error: Garment image not found: {garment_path}")
    if not weights_path.is_dir():
        _fail(
            f"Error: Weights directory not found: {weights_path}",
            f"Run: python scripts/download_weights.py --weights-dir {args.weights_dir}",
        )

    # Load images
    print("Loading images...")
    person_image = _load_rgb(person_path)
    garment_image = _load_rgb(garment_path)

    # Create pipeline (loads all models internally)
    print(f"Loading pipeline from {args.weights_dir}...")
    pipeline = TryOnPipeline(weights_dir=args.weights_dir, device=args.device)

    # Run inference
    result = pipeline(
        person_image=person_image,
        garment_image=garment_image,
        category=args.category,
        garment_photo_type=args.garment_photo_type,
        num_samples=args.num_samples,
        num_timesteps=args.num_timesteps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        segmentation_free=args.segmentation_free,
    )

    # Save outputs
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, output_image in enumerate(result.images):
        output_path = output_dir / f"output_{i:02d}.png"
        output_image.save(output_path)
        print(f"Saved: {output_path}")

    count = len(result.images)
    print(f"\nDone! Generated {count} image{'' if count == 1 else 's'}.")


if __name__ == "__main__":
    # Script entry point. main() returns None on success, so sys.exit()
    # terminates with status 0, the same as falling off the end of the script.
    sys.exit(main())