File size: 12,276 Bytes
5db43ff
 
 
 
 
 
 
 
 
 
 
509815c
5db43ff
 
ec8a465
 
 
5db43ff
17977ea
6b38677
17977ea
 
 
 
 
 
 
 
6b38677
17977ea
6b38677
17977ea
19f603d
17977ea
19f603d
17977ea
6b38677
17977ea
6b38677
17977ea
 
 
 
6b38677
5db43ff
17977ea
6b38677
17977ea
6b38677
5db43ff
17977ea
 
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5371b
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
509815c
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224f0c5
5db43ff
224f0c5
5db43ff
224f0c5
 
 
 
 
5db43ff
224f0c5
 
5db43ff
509815c
7facf2a
5db43ff
 
 
7facf2a
 
 
5db43ff
224f0c5
 
 
 
 
 
 
 
 
 
 
5db43ff
 
 
 
 
 
 
 
224f0c5
 
 
 
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509815c
 
 
 
d207705
 
 
 
5db43ff
509815c
 
 
b514c84
 
 
509815c
5db43ff
 
509815c
 
5db43ff
 
 
 
b514c84
5db43ff
b514c84
 
 
5db43ff
b514c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5db43ff
b514c84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import os
import subprocess
import sys
import shutil
import cv2
import gradio as gr
import torch
import numpy as np
from rembg import remove
from PIL import Image
import threading
import spaces
from glob import glob

# Headless rendering: point PyOpenGL at the EGL backend so no display server is needed.
os.environ.update(PYOPENGL_PLATFORM="egl")

# --- Installation Helper ---
def install_dependencies():
    """Best-effort install of Detectron2, ROMP and the pretrained RTV checkpoints.

    pip is run through the current interpreter and git as subprocesses. A package
    is only installed when importing it raises ImportError, and the checkpoints
    are only cloned when ``./rtv_ckpts`` does not exist, so restarts of an
    already-provisioned container skip the slow steps.

    Any failure is printed and swallowed so the app can still start; code that
    later imports a missing piece will fail at that point instead.
    """
    import importlib

    def _pip_install(*args):
        # Use this interpreter's pip so packages land in the active environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", *args])

    def _is_importable(module_name):
        # Probe by actually importing, mirroring a plain `import <name>` check.
        try:
            importlib.import_module(module_name)
        except ImportError:
            return False
        return True

    print("Checking and installing dependencies...", flush=True)
    try:
        # 1. Upgrade build tools (redundant but safe)
        _pip_install("--upgrade", "pip", "wheel", "setuptools", "ninja")

        # 2./3. Install Detectron2 and ROMP only when they are not importable yet,
        # to avoid reinstalling on container restart.
        git_packages = (
            ("detectron2", "Detectron2",
             "detectron2@git+https://github.com/facebookresearch/detectron2.git"),
            ("romp", "ROMP",
             "git+https://github.com/ZaiqiangWu/ROMP.git#subdirectory=simple_romp"),
        )
        for module_name, label, requirement in git_packages:
            if _is_importable(module_name):
                print(f"{label} already installed.", flush=True)
            else:
                print(f"Installing {label}...", flush=True)
                _pip_install("--no-build-isolation", requirement)
                # The import system caches directory listings; without this a package
                # installed while the interpreter is running may not be found by later imports.
                importlib.invalidate_caches()

        # 4. Download Checkpoints
        if not os.path.exists("./rtv_ckpts"):
            print("Downloading checkpoints...", flush=True)
            subprocess.check_call(["git", "clone", "https://huggingface.co/wuzaiqiang/rtv_ckpts"])

        print("Dependencies installed and checkpoints ready.", flush=True)
    except Exception as e:
        # Deliberate best-effort: report and keep starting the app.
        print(f"Error installing dependencies: {e}", flush=True)

# Run installation on startup. It precedes the VITON import below, presumably because
# VITON needs the packages it installs (TODO confirm).
install_dependencies()

# --- App Logic ---

# Working directories for generated masks, generated datasets and trained checkpoints.
for _required_dir in ("generated_masks", "PerGarmentDatasets", "checkpoints"):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

from VITON.viton_upperbody import FrameProcessor

# Module-level state.
frame_processor = None    # lazily (re)built by process_frame for the selected garment
current_garment_id = -1   # NOTE(review): not read anywhere in this file
garment_list = ["lab_03"]  # Default pretrained; replaced by get_garment_list()

def get_garment_list():
    """Return the sorted, de-duplicated names of all available garment checkpoints.

    Scans the immediate sub-directories of ``checkpoints/`` (user-trained) and
    ``rtv_ckpts/`` (pretrained download), skipping any whose name contains
    ``"label"``. The result is also stored in the module-level ``garment_list``.
    """
    global garment_list
    candidates = glob("checkpoints/*") + glob("rtv_ckpts/*")
    names = {
        os.path.basename(path)
        for path in candidates
        if os.path.isdir(path) and "label" not in os.path.basename(path)
    }
    # Sort so the dropdown order is stable; set iteration order of strings varies per run.
    garment_list = sorted(names)
    return garment_list

def extract_frames_and_masks(video_path, garment_name):
    """Write a foreground mask PNG for every frame of *video_path*.

    Each frame is passed through rembg, and the alpha channel of its RGBA output
    is saved as a single-channel PNG named ``00000.png``, ``00001.png``, ... in
    ``generated_masks/<garment_name>``. Any previous contents of that directory
    are deleted first.

    Args:
        video_path: Path of the video to read with OpenCV.
        garment_name: Directory name for the masks; callers should pass a
            sanitized single path component, because the directory is removed
            with ``shutil.rmtree``.

    Returns:
        The mask directory path.

    Raises:
        OSError: If OpenCV cannot open the video.
    """
    # Local import like the original rembg usage; new_session lets us reuse one model.
    from rembg import new_session

    print(f"Processing video: {video_path}")
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        mask_dir = f"generated_masks/{garment_name}"
        if os.path.exists(mask_dir):
            shutil.rmtree(mask_dir)
        os.makedirs(mask_dir)

        # Calling rembg.remove() without a session builds a new model session on every
        # call; create it once and reuse it for all frames.
        session = new_session("u2net")

        frame_count = 0
        # Limit frames for demo speed if needed, but RTV needs good coverage;
        # currently every frame is processed.
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # rembg works on PIL images; OpenCV frames are BGR.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output = remove(Image.fromarray(frame_rgb), session=session)
            mask = np.array(output)[:, :, 3]  # rembg returns RGBA; alpha is the mask

            mask_path = os.path.join(mask_dir, f"{frame_count:05d}.png")
            cv2.imwrite(mask_path, mask)

            frame_count += 1
            if frame_count % 50 == 0:
                print(f"Processed {frame_count} frames...")
    finally:
        # Release the capture even if masking fails part-way.
        cap.release()
    return mask_dir

def _sanitize_garment_name(garment_name):
    """Turn user input into a name that is safe to use as a single path component.

    Surrounding whitespace is stripped, and every character that is not
    alphanumeric, ``_`` or ``-`` (spaces, dots, path separators, ...) becomes ``_``.
    The name therefore cannot escape ``generated_masks/``, ``PerGarmentDatasets/``
    or ``checkpoints/``. This matters because ``generated_masks/<name>`` is deleted
    with ``shutil.rmtree``.
    """
    text = (garment_name or "").strip()
    return "".join(ch if ch.isalnum() or ch in "_-" else "_" for ch in text)


@spaces.GPU(duration=300)
def train_garment(video_path, garment_name):
    """Mask, build a dataset from, and train a per-garment model on an uploaded video.

    This is a generator that yields progress to the Gradio UI. Gradio discards a
    generator's ``return`` value, so every message is yielded, including
    validation errors, step failures and the final result.

    Args:
        video_path: Path of the uploaded video (falsy when nothing was uploaded).
        garment_name: Free-form name typed by the user; sanitized before use.

    Yields:
        Tuples of (status message, ``gr.update`` for the garment dropdown).
    """
    if not video_path:
        yield "Please upload a video first.", gr.update(choices=get_garment_list())
        return

    clean_name = _sanitize_garment_name(garment_name)
    if not clean_name.strip("_-"):
        yield "Please enter a valid garment name.", gr.update(choices=get_garment_list())
        return

    yield f"Starting training processing for {clean_name}...", gr.update()

    # 1. Generate Masks
    yield "Creating segmentation masks (this may take a while)...", gr.update()
    try:
        mask_dir = extract_frames_and_masks(video_path, clean_name)
    except Exception as e:
        yield f"Error during masking: {e}", gr.update()
        return

    # 2. Generate Dataset
    yield "Generating dataset...", gr.update()
    try:
        cmd_dataset = [
            sys.executable, "DatasetGeneration/upperbody_dataset_generation.py",
            "--video_path", video_path,
            "--mask_dir", mask_dir,
            "--dataset_name", clean_name,
        ]
        subprocess.check_call(cmd_dataset)
    except Exception as e:
        yield f"Error during dataset generation: {e}", gr.update()
        return

    # 3. Train Model
    yield "Training model (this will take minutes)...", gr.update()
    try:
        # Reduced params for demo speed
        cmd_train = [
            sys.executable, "Training/upperbody_training.py",
            "--model", "pix2pixHD_RGBA",
            "--input_nc", "6",
            "--output_nc", "4",
            "--batchSize", "4",
            "--img_size", "512",
            "--dataset_path", f"./PerGarmentDatasets/{clean_name}",
            "--name", clean_name,
            "--niter", "20",       # Reduced from 80 for demo
            "--niter_decay", "20",  # Reduced from 80 for demo
        ]
        subprocess.check_call(cmd_train)
    except Exception as e:
        yield f"Error during training: {e}", gr.update()
        return

    # The training script is expected to save to ./checkpoints/<name> (TODO confirm),
    # which get_garment_list() scans.
    new_list = get_garment_list()
    yield (
        f"Training complete for {clean_name}! You can now select it in the Try-On tab.",
        gr.update(choices=new_list, value=clean_name),
    )

# --- Inference Logic ---

def init_processor(garment_name):
    """Build a FrameProcessor for a single garment and select it.

    Returns None when no garment is given. Otherwise a fresh processor is created
    on every call, so it always matches the requested garment. Its checkpoints are
    looked up under ``./checkpoints``.
    """
    if garment_name is None:
        return None

    print(f"Loading garment: {garment_name}", flush=True)
    loaded = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
    # Select index 0, the only garment in the list; this presumably triggers the
    # weight load (TODO confirm in FrameProcessor).
    loaded.switch_to_target_garment(0)
    return loaded

@spaces.GPU
def process_frame(image, garment_name, enable_tryon):
    """Run virtual try-on on one streamed webcam frame.

    Args:
        image: RGB frame (numpy array) from the Gradio webcam component, or None.
        garment_name: Garment selected in the dropdown.
        enable_tryon: Current value of the Start/Stop try-on toggle.

    Returns:
        One of the following RGB arrays (or None):
        - None when no frame was given.
        - The input frame unchanged when try-on is disabled or when loading or
          inference fails.
        - The (possibly resized) frame stamped "NO PERSON DETECTED" when the
          processor returns None.
        - Otherwise, the processor's output.
    """
    if image is None:
        return None

    if not enable_tryon:
        return image

    global frame_processor

    try:
        # (Re)load when nothing is loaded yet or the selected garment changed.
        if frame_processor is None or frame_processor.garment_name_list[0] != garment_name:
            pretrained_dir = f"rtv_ckpts/{garment_name}"
            target_dir = f"checkpoints/{garment_name}"
            # init_processor loads from ./checkpoints, so copy pretrained weights there once.
            if os.path.exists(pretrained_dir) and not os.path.exists(target_dir):
                os.makedirs("checkpoints", exist_ok=True)
                shutil.copytree(pretrained_dir, target_dir)

            frame_processor = init_processor(garment_name)
    except Exception as e:
        print(f"Error loading model: {e}", flush=True)
        return image

    # Gradio delivers RGB; FrameProcessor is fed BGR, like frames read with cv2.
    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Cap the height at 1024 px, preserving the aspect ratio.
    h, w = img_bgr.shape[:2]
    if h > 1024:
        scale = 1024 / h
        img_bgr = cv2.resize(img_bgr, (int(w * scale), 1024))

    try:
        if frame_processor is None:
            print("Error: Frame processor is None inside process_frame")
            return image

        # Debug image stats
        h, w = img_bgr.shape[:2]
        print(f"Inference Input - Shape: {h}x{w}, Mean: {img_bgr.mean():.2f}, Max: {img_bgr.max()}", flush=True)

        output_bgr = frame_processor(img_bgr)

        if output_bgr is None:
            print("Warning: FrameProcessor returned None")
            cv2.putText(img_bgr, "NO PERSON DETECTED", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)
            return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        return cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"Inference error: {e}")
        return image


# --- Gradio UI ---
# Built at module level; previously this was indented inside process_frame after
# its final return, so the app was never created or launched.

with gr.Blocks(title="RTV: Real-Time Virtual Try-On") as app:
    gr.Markdown("# Real-Time Virtual Try-On")
    gr.Markdown("Train a custom garment model from a video, or try on existing ones.")

    with gr.Tab("Virtual Try-On"):
        # Try-on toggle set by the Start/Stop buttons and read by process_frame per frame.
        tryon_enabled = gr.State(False)

        with gr.Row():
            with gr.Column():
                # Scan the checkpoint folders once for both the choices and the default.
                available_garments = get_garment_list()
                garment_selector = gr.Dropdown(
                    label="Select Garment",
                    choices=available_garments,
                    value="lab_03" if "lab_03" in available_garments else None,
                    interactive=True,
                )
                input_webcam = gr.Image(sources=["webcam"], streaming=True, label="Live Feed", interactive=True)
                with gr.Row():
                    start_btn = gr.Button("Start Try-On", variant="primary")
                    stop_btn = gr.Button("Stop Try-On", variant="stop")
                with gr.Row():
                    stop_cam_btn = gr.Button("Turn Off Camera", variant="secondary")

            with gr.Column():
                output_display = gr.Image(label="Virtual Try-On Result")

        # Button handlers to toggle state
        start_btn.click(fn=lambda: True, inputs=None, outputs=[tryon_enabled])
        stop_btn.click(fn=lambda: False, inputs=None, outputs=[tryon_enabled])

        # Stop Camera Handler (Clears the input)
        stop_cam_btn.click(fn=lambda: None, inputs=None, outputs=[input_webcam])

        # Each streamed frame is processed with the current garment and toggle state.
        input_webcam.stream(
            fn=process_frame,
            inputs=[input_webcam, garment_selector, tryon_enabled],
            outputs=output_display,
            show_progress=False,
        )

    with gr.Tab("Train New Garment"):
        gr.Markdown("Upload a video of the garment. The system will auto-mask frames and train a model.")
        with gr.Row():
            video_input = gr.Video(label="Upload Video (Mp4)")
            garment_name_input = gr.Textbox(label="Garment Name (e.g., 'red_shirt')", placeholder="my_custom_shirt")
            train_btn = gr.Button("Start Training")

        train_log = gr.Textbox(label="Training Status", interactive=False)

        # train_garment streams (status, dropdown update) pairs; the dropdown lives in the try-on tab.
        train_btn.click(
            fn=train_garment,
            inputs=[video_input, garment_name_input],
            outputs=[train_log, garment_selector],
        )

# Enable the request queue and start the server.
app.queue().launch()