Maarij-Aqeel commited on
Commit
bd3e71e
·
0 Parent(s):

Added files

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +208 -0
  3. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .gradio
2
+ .venv
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import rembg
7
+ from diffusers import AutoPipelineForInpainting
8
+
9
+ # ==============================================================================
10
+ # Virtual Try-On (VTON) Application
11
+ # Designed using PyTorch, Diffusers, and Gradio
12
+ #
13
+ # This script serves as the complete application. It handles model loading,
14
+ # image preprocessing, mask generation (via rembg/OpenCV), and inference.
15
+ # ==============================================================================
16
+
17
# Hardware setup - preferring CUDA if available with fp16 to optimize VRAM
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # inference device for the whole app
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32  # fp16 halves VRAM on GPU; CPU needs fp32
20
+
21
class VTONPipeline:
    """
    Backend for the virtual try-on app.

    Lazily loads an SDXL inpainting pipeline plus a rembg segmentation
    session, builds an inpaint mask from the person silhouette, and runs
    the diffusion inference.
    """

    def __init__(self):
        # Both are populated lazily by load_models() so app startup stays fast.
        self.pipe = None
        self.rembg_session = None

    def load_models(self):
        """
        Loads the necessary generative pipelines and segmentation tools.
        Loads using torch.float16 for significant VRAM savings.

        Raises:
            Exception: re-raises any failure from the diffusers load, with
                the original traceback intact.
        """
        print(f"Loading Diffusers Model on {DEVICE}...")
        try:
            # NOTE: For state-of-the-art IDM-VTON or OOTDiffusion, you would usually load
            # their specialized UNets and IP-Adapters here.
            # We are using an SDXL Inpainting baseline to demonstrate the full logic.
            self.pipe = AutoPipelineForInpainting.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=DTYPE,
                variant="fp16" if DTYPE == torch.float16 else None,
            )

            # Accelerate and VRAM optimizations.
            if DEVICE == "cuda":
                # FIX: enable_model_cpu_offload() manages device placement itself;
                # the previous code first moved the whole pipeline to CUDA with
                # .to(DEVICE), which defeats the offload and keeps full VRAM usage.
                self.pipe.enable_model_cpu_offload()
                # Uncomment the below if xformers is successfully installed
                # self.pipe.enable_xformers_memory_efficient_attention()
            else:
                self.pipe = self.pipe.to(DEVICE)

        except Exception as e:
            print(f"Failed to load the pipeline: {e}")
            # FIX: bare `raise` preserves the original traceback (was `raise e`).
            raise

        # Initialize rembg session for background/body mask generation
        self.rembg_session = rembg.new_session()
        print("Models loaded successfully.")

    def generate_mask(self, person_image: Image.Image) -> Image.Image:
        """
        [CRUCIAL] Generates a mask indicating where the new garment will be drawn.

        Uses rembg to isolate the person and a morphological dilation so the
        inpainting region fully covers the garment bounds.

        Raises:
            ValueError: if background removal does not yield an alpha channel.
        """
        # 1. Use rembg to extract the person from the background
        person_no_bg = rembg.remove(person_image, session=self.rembg_session)
        np_img = np.array(person_no_bg)

        # FIX: guard the alpha-channel access; a non-RGBA result previously
        # crashed with an opaque IndexError instead of a clear message.
        if np_img.ndim != 3 or np_img.shape[2] < 4:
            raise ValueError("Background removal did not produce an alpha channel.")

        # 2. Extract the alpha channel (0 = background, 255 = person)
        alpha = np_img[:, :, 3]

        # 3. Approximate the clothing area. In a pure VTON application,
        # a human parser (like Graphonomy or MediaPipe pose) is used to pinpoint
        # the shirt; here we capture the body silhouette and dilate it slightly
        # to ensure the garment bounds are fully covered for inpainting.
        kernel = np.ones((25, 25), np.uint8)
        dilated_mask = cv2.dilate(alpha, kernel, iterations=1)

        # 4. Convert back to PIL Image, grayscale ('L')
        return Image.fromarray(dilated_mask).convert("L")

    def preprocess_image(self, image: Image.Image, target_size=(768, 1024)) -> Image.Image:
        """
        Resize-then-center-crop *image* to exactly *target_size* (w, h),
        preserving aspect ratio — the resolution diffusion VTON models expect
        (usually 768x1024 for VTON like IDM-VTON).
        """
        w, h = image.size
        target_w, target_h = target_size

        # Scale so BOTH dimensions cover the target, then crop the excess.
        ratio = max(target_w / w, target_h / h)
        # FIX: int() truncation could undershoot the target by one pixel
        # (e.g. 767 instead of 768), making the crop box spill outside the
        # image; clamping guarantees full coverage.
        new_w = max(int(w * ratio), target_w)
        new_h = max(int(h * ratio), target_h)

        # Resize using high-quality Lanczos filter
        img_resized = image.resize((new_w, new_h), Image.LANCZOS)

        # FIX: integer crop coordinates (the old float box could be fractional);
        # deriving right/bottom from left/top guarantees the exact output size.
        left = (new_w - target_w) // 2
        top = (new_h - target_h) // 2
        return img_resized.crop((left, top, left + target_w, top + target_h))

    def try_on(self, person_img: Image.Image, garment_img: Image.Image) -> Image.Image:
        """
        Core inference pipeline.
        Executes the heavy lifting: preprocessing -> masking -> diffusion loop.

        Raises:
            gr.Error: if either input image is missing.
        """
        if person_img is None or garment_img is None:
            raise gr.Error("Both 'Target Person' and 'Garment' images are required.")

        # Lazy load models to speed up initial app startup
        if self.pipe is None:
            self.load_models()

        print("Preprocessing Inputs...")
        # Most VTON networks operate optimally at 768 Width x 1024 Height
        target_resolution = (768, 1024)
        person_prepared = self.preprocess_image(person_img.convert("RGB"), target_resolution)
        garment_prepared = self.preprocess_image(garment_img.convert("RGB"), target_resolution)

        print("Generating Mask...")
        mask_prepared = self.generate_mask(person_prepared)

        print("Running Inpainting Inference...")

        # Prompting: Describe exactly what we want.
        # Note: Advanced architectures like IDM-VTON will use IP-Adapter
        # to embed the garment_img directly as semantic features.
        # NOTE(review): garment_prepared is not consumed by this SDXL baseline —
        # the garment only influences the result via the text prompt. A real
        # VTON model would inject it through an IP-Adapter or specialized UNet.
        prompt = "photorealistic, a person wearing the provided garment, perfect fit, detailed fabric texture, high quality, 8k"
        negative_prompt = "deformed, ugly, bad anatomy, bad lighting, blurry, low resolution, artifacts, extra limbs"

        # Perform Inference
        result_img = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=person_prepared,
            mask_image=mask_prepared,
            num_inference_steps=30,  # Balance between speed and quality
            guidance_scale=7.5,
        ).images[0]

        return result_img
142
+
143
+
144
# Instantiate the logic runner
# Module-level singleton shared by all Gradio requests; models inside it are
# loaded lazily on the first try_on() call.
vton_worker = VTONPipeline()
146
+
147
def run_vton_interface(person_img, garment_img):
    """
    Thin Gradio adapter around the backend worker.

    Forwards the two input images to ``vton_worker`` and converts any
    unexpected failure into a user-visible ``gr.Error``.
    """
    try:
        return vton_worker.try_on(person_img, garment_img)
    except gr.Error:
        # Deliberate, user-facing errors pass straight through to the UI.
        raise
    except Exception as exc:
        # Anything else is logged server-side and surfaced as a friendly message.
        print(f"Exception during Try-On: {exc}")
        raise gr.Error(f"Model Inference Failed: {str(exc)}")
162
+
163
+
164
+ # ==============================================================================
165
+ # Gradio Web UI
166
+ # ==============================================================================
167
def create_ui():
    """Build and return the Gradio Blocks interface for the VTON app."""
    soft_theme = gr.themes.Soft(primary_hue="indigo", font=[gr.themes.GoogleFont("Inter")])

    with gr.Blocks(theme=soft_theme) as ui:
        gr.Markdown(
            """
            <div style="text-align: center; margin-bottom: 20px;">
            <h1>👗 AI Virtual Try-On (VTON) Studio</h1>
            <p>Upload a photo of a person and an isolated clothing garment to digitally try it on using Generative AI.</p>
            </div>
            """
        )

        with gr.Row():
            # Left column: the two source images plus the trigger button.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Upload Inputs")
                person_input = gr.Image(type="pil", label="👤 Target Person", height=450)
                garment_input = gr.Image(type="pil", label="👕 Garment to Try", height=450)
                generate_btn = gr.Button("✨ Generate Try-On", variant="primary", size="lg")

            # Right column: read-only output preview.
            with gr.Column(scale=1):
                gr.Markdown("### 2. Resulting Image")
                output_image = gr.Image(type="pil", label="Expected Result", interactive=False, height=930)

        # Wire the button to the backend; Gradio shows a spinner on click automatically.
        generate_btn.click(
            fn=run_vton_interface,
            inputs=[person_input, garment_input],
            outputs=[output_image],
            api_name="generate",
        )

    return ui
203
+
204
if __name__ == "__main__":
    demo = create_ui()
    # Queuing is required so concurrent requests are handled safely.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision
3
+ diffusers>=0.28.0
4
+ transformers
5
+ accelerate
6
+ gradio
7
+ rembg
8
+ pillow
9
+ opencv-python
10
+ numpy
11
+ xformers