|
|
import os |
|
|
import subprocess |
|
|
import sys |
|
|
import shutil |
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from rembg import remove |
|
|
from PIL import Image |
|
|
import threading |
|
|
import spaces |
|
|
from glob import glob |
|
|
|
|
|
|
|
|
# Select PyOpenGL's EGL platform (presumably for headless/offscreen rendering).
# NOTE(review): PyOpenGL reads this when OpenGL is first imported; confirm none
# of the imports above pull OpenGL in before this line runs.
os.environ.update(PYOPENGL_PLATFORM="egl")
|
|
|
|
|
|
|
|
def install_dependencies():
    """Best-effort install of the packages and checkpoints this app needs.

    Steps, each attempted independently (a failing step is reported and the
    remaining steps still run):
      1. upgrade pip / wheel / setuptools / ninja;
      2. pip-install Detectron2 from GitHub if ``import detectron2`` fails;
      3. pip-install ROMP (``simple_romp``) from GitHub if ``import romp`` fails;
      4. ``git clone`` the ``rtv_ckpts`` repo if ``./rtv_ckpts`` is missing.

    Never raises; errors are printed.
    """
    import importlib

    print("Checking and installing dependencies...", flush=True)
    pip_install = [sys.executable, "-m", "pip", "install"]
    failed_steps = []

    def run_step(label, action):
        # Isolate each step so that one failure (e.g. the pip self-upgrade)
        # does not silently skip the installs/downloads that follow it.
        try:
            action()
        except Exception as e:
            failed_steps.append(label)
            print(f"Error installing dependencies ({label}): {e}", flush=True)

    def upgrade_build_tools():
        subprocess.check_call(pip_install + ["--upgrade", "pip", "wheel", "setuptools", "ninja"])

    def ensure_package(module_name, display_name, pip_spec):
        # Importing is the "is it usable?" probe; any ImportError triggers a (re)install.
        try:
            importlib.import_module(module_name)
            print(f"{display_name} already installed.", flush=True)
        except ImportError:
            print(f"Installing {display_name}...", flush=True)
            # NOTE(review): --no-build-isolation presumably lets the build see the
            # already-installed torch; confirm before changing.
            subprocess.check_call(pip_install + ["--no-build-isolation", pip_spec])

    def download_checkpoints():
        if not os.path.exists("./rtv_ckpts"):
            print("Downloading checkpoints...", flush=True)
            subprocess.check_call(["git", "clone", "https://huggingface.co/wuzaiqiang/rtv_ckpts"])

    run_step("build tools", upgrade_build_tools)
    run_step(
        "Detectron2",
        lambda: ensure_package(
            "detectron2",
            "Detectron2",
            "detectron2@git+https://github.com/facebookresearch/detectron2.git",
        ),
    )
    run_step(
        "ROMP",
        lambda: ensure_package(
            "romp",
            "ROMP",
            "git+https://github.com/ZaiqiangWu/ROMP.git#subdirectory=simple_romp",
        ),
    )
    run_step("checkpoints", download_checkpoints)

    # Packages were installed while this interpreter is running; make sure the
    # import system notices them when later imports look for them.
    importlib.invalidate_caches()

    if failed_steps:
        print(f"Dependency setup finished with errors in: {', '.join(failed_steps)}", flush=True)
    else:
        print("Dependencies installed and checkpoints ready.", flush=True)
|
|
|
|
|
|
|
|
install_dependencies()

# Working directories used by mask generation, dataset generation and
# checkpoint loading; created up front so later code can assume they exist.
for _workdir in ("generated_masks", "PerGarmentDatasets", "checkpoints"):
    os.makedirs(_workdir, exist_ok=True)

# NOTE(review): this import runs only after install_dependencies(), presumably
# because VITON relies on packages installed there — keep this ordering.
from VITON.viton_upperbody import FrameProcessor

# Shared module state.
frame_processor = None  # built lazily by the streaming try-on handler
current_garment_id = -1
garment_list = ["lab_03"]  # fallback; get_garment_list() overwrites it after scanning
|
|
|
|
|
def get_garment_list():
    """Return the sorted names of all available garment checkpoints.

    A garment is any subdirectory of ``checkpoints/`` or ``rtv_ckpts/`` whose
    name does not contain "label". A name present in both folders is listed
    once. The result is also stored in the module-level ``garment_list``.
    """
    global garment_list

    names = {
        os.path.basename(path)
        for path in glob("checkpoints/*") + glob("rtv_ckpts/*")
        if os.path.isdir(path) and "label" not in os.path.basename(path)
    }
    # Sort so the dropdown order is stable: iteration order of a set of str
    # changes between processes because of hash randomization.
    garment_list = sorted(names)
    return garment_list
|
|
|
|
|
def extract_frames_and_masks(video_path, garment_name):
    """Write a foreground mask PNG for every frame of *video_path*.

    Masks are written to ``generated_masks/<garment_name>/NNNNN.png``, where
    NNNNN is the zero-padded frame index. Any previous contents of that
    directory are deleted first. Each mask is the alpha channel of rembg's
    background-removed frame.

    Args:
        video_path: path of the video to read with OpenCV.
        garment_name: single directory-name component under ``generated_masks``.

    Returns:
        The mask directory path (``generated_masks/<garment_name>``).

    Raises:
        ValueError: *garment_name* is not a single component directly under
            ``generated_masks``, the video cannot be opened, or no frame
            could be decoded.
        OSError: a mask file could not be written.
    """
    from rembg import new_session, remove

    print(f"Processing video: {video_path}")

    mask_dir = f"generated_masks/{garment_name}"
    # The directory below is rmtree'd and the name originates from user
    # input, so make sure it cannot resolve anywhere else (e.g. "../x", "", ".").
    if os.path.dirname(os.path.abspath(mask_dir)) != os.path.abspath("generated_masks"):
        raise ValueError(f"Invalid garment name: {garment_name!r}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    frame_count = 0
    try:
        if os.path.exists(mask_dir):
            shutil.rmtree(mask_dir)
        os.makedirs(mask_dir)

        # One rembg session for the whole video; calling remove() without a
        # session creates a new model session on every call.
        session = new_session("u2net")

        while True:
            ok, frame = cap.read()
            if not ok:
                break

            pil_im = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # The cut-out has 4 channels; channel 3 (alpha) is the mask.
            mask = np.array(remove(pil_im, session=session))[:, :, 3]

            mask_path = os.path.join(mask_dir, f"{frame_count:05d}.png")
            if not cv2.imwrite(mask_path, mask):
                raise OSError(f"Failed to write mask: {mask_path}")

            frame_count += 1
            if frame_count % 50 == 0:
                print(f"Processed {frame_count} frames...")
    finally:
        cap.release()

    if frame_count == 0:
        raise ValueError(f"No frames could be read from video: {video_path}")
    return mask_dir
|
|
|
|
|
@spaces.GPU(duration=300)
def train_garment(video_path, garment_name):
    """Train a per-garment try-on model from an uploaded video.

    Pipeline:
      1. rembg masks for every frame (``extract_frames_and_masks``);
      2. the dataset-generation script;
      3. the ``pix2pixHD_RGBA`` training script.

    This is a generator. Every step yields a ``(status_text, dropdown_update)``
    pair for the ``train_log`` textbox and the garment selector. Terminal
    messages (errors and success) are *yielded* as well, because Gradio does
    not display a generator's ``return`` value.

    Args:
        video_path: path of the uploaded video (falsy when nothing was uploaded).
        garment_name: user-entered name; surrounding whitespace is stripped
            and inner spaces become underscores.
    """
    import re

    if not video_path:
        yield "Please upload a video first.", gr.update(choices=get_garment_list())
        return

    clean_name = (garment_name or "").strip().replace(" ", "_")
    if not clean_name:
        yield "Please enter a valid garment name.", gr.update(choices=get_garment_list())
        return
    # The name becomes a directory component (generated_masks/<name> is deleted
    # and recreated, and it names the dataset/checkpoint dirs), so refuse
    # separators and "."/".." that could escape those directories.
    if not re.fullmatch(r"[\w.-]+", clean_name) or clean_name in (".", ".."):
        yield (
            "Please enter a valid garment name (letters, digits, '_', '-' and '.' only).",
            gr.update(choices=get_garment_list()),
        )
        return

    yield f"Starting training processing for {clean_name}...", gr.update()

    # Step 1: per-frame segmentation masks.
    yield "Creating segmentation masks (this may take a while)...", gr.update()
    try:
        mask_dir = extract_frames_and_masks(video_path, clean_name)
    except Exception as e:
        yield f"Error during masking: {e}", gr.update()
        return

    # Step 2: build the per-garment dataset.
    yield "Generating dataset...", gr.update()
    cmd_dataset = [
        sys.executable, "DatasetGeneration/upperbody_dataset_generation.py",
        "--video_path", video_path,
        "--mask_dir", mask_dir,
        "--dataset_name", clean_name,
    ]
    try:
        subprocess.check_call(cmd_dataset)
    except Exception as e:
        yield f"Error during dataset generation: {e}", gr.update()
        return

    # Step 3: train the model on ./PerGarmentDatasets/<name>.
    yield "Training model (this will take minutes)...", gr.update()
    cmd_train = [
        sys.executable, "Training/upperbody_training.py",
        "--model", "pix2pixHD_RGBA",
        "--input_nc", "6",
        "--output_nc", "4",
        "--batchSize", "4",
        "--img_size", "512",
        "--dataset_path", f"./PerGarmentDatasets/{clean_name}",
        "--name", clean_name,
        "--niter", "20",
        "--niter_decay", "20",
    ]
    try:
        subprocess.check_call(cmd_train)
    except Exception as e:
        yield f"Error during training: {e}", gr.update()
        return

    # Refresh the selector so the new garment can be picked immediately.
    yield (
        f"Training complete for {clean_name}! You can now select it in the Try-On tab.",
        gr.update(choices=get_garment_list(), value=clean_name),
    )
|
|
|
|
|
|
|
|
|
|
|
def init_processor(garment_name):
    """Build a FrameProcessor for one garment, or return None when none is given.

    The processor is created with ``ckpt_dir='./checkpoints'`` for the single
    garment *garment_name* and switched to that garment (index 0) before
    being returned.
    """
    if garment_name is not None:
        print(f"Loading garment: {garment_name}", flush=True)
        loaded = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
        # Only one garment was registered, so it lives at index 0.
        loaded.switch_to_target_garment(0)
        return loaded
    return None
|
|
|
|
|
def _ensure_processor(garment_name):
    """Make the module-level ``frame_processor`` match *garment_name*.

    Does nothing when the loaded processor was already built for this garment.
    Otherwise, a garment that exists only under ``rtv_ckpts`` is first copied
    into ``checkpoints`` (``init_processor`` passes ``ckpt_dir='./checkpoints'``),
    then a new processor is built. For a ``None`` garment, ``frame_processor``
    becomes ``None``. Copy or load errors propagate to the caller.
    """
    global frame_processor

    if frame_processor is not None and frame_processor.garment_name_list[0] == garment_name:
        return

    if garment_name is not None:
        src = f"rtv_ckpts/{garment_name}"
        dst = f"checkpoints/{garment_name}"
        if os.path.exists(src) and not os.path.exists(dst):
            os.makedirs("checkpoints", exist_ok=True)
            shutil.copytree(src, dst)

    frame_processor = init_processor(garment_name)


@spaces.GPU
def process_frame(image, garment_name, enable_tryon):
    """Apply virtual try-on to one streamed webcam frame.

    Args:
        image: RGB frame array from the webcam component, or None.
        garment_name: garment currently selected in the dropdown.
        enable_tryon: per-session flag toggled by the Start/Stop buttons.

    Returns:
        None for a None frame, otherwise an RGB array:
          - the input frame unchanged when try-on is disabled or any step fails;
          - the frame (downscaled to 1024 px height if taller) stamped
            "NO PERSON DETECTED" when the processor returns None;
          - otherwise the try-on result.
    """
    if image is None:
        return None
    if not enable_tryon:
        return image

    try:
        _ensure_processor(garment_name)
    except Exception as e:
        print(f"Error loading model: {e}", flush=True)
        return image

    try:
        if frame_processor is None:
            print("Error: Frame processor is None inside process_frame")
            return image

        # Conversion/resizing sit inside the try so a malformed frame falls back
        # to returning the input instead of erroring the stream.
        img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Cap the working height at 1024 px, keeping the aspect ratio.
        h, w, _ = img_bgr.shape
        if h > 1024:
            scale = 1024 / h
            img_bgr = cv2.resize(img_bgr, (int(w * scale), 1024))
            h, w, _ = img_bgr.shape

        print(f"Inference Input - Shape: {h}x{w}, Mean: {img_bgr.mean():.2f}, Max: {img_bgr.max()}", flush=True)

        output_bgr = frame_processor(img_bgr)

        if output_bgr is None:
            print("Warning: FrameProcessor returned None")
            cv2.putText(img_bgr, "NO PERSON DETECTED", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)
            return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        return cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
    except Exception as e:
        import traceback

        traceback.print_exc()
        print(f"Inference error: {e}")
        return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="RTV: Real-Time Virtual Try-On") as app:
    gr.Markdown("# Real-Time Virtual Try-On")
    gr.Markdown("Train a custom garment model from a video, or try on existing ones.")

    with gr.Tab("Virtual Try-On"):
        # Per-session flag passed to process_frame; frames pass through
        # unchanged while it is False.
        tryon_enabled = gr.State(False)

        with gr.Row():
            with gr.Column():
                # Scan the checkpoint folders once so the choices and the
                # default value come from the same snapshot.
                initial_garments = get_garment_list()
                garment_selector = gr.Dropdown(
                    label="Select Garment",
                    choices=initial_garments,
                    value="lab_03" if "lab_03" in initial_garments else None,
                    interactive=True,
                )
                input_webcam = gr.Image(sources=["webcam"], streaming=True, label="Live Feed", interactive=True)
                with gr.Row():
                    start_btn = gr.Button("Start Try-On", variant="primary")
                    stop_btn = gr.Button("Stop Try-On", variant="stop")
                with gr.Row():
                    stop_cam_btn = gr.Button("Turn Off Camera", variant="secondary")

            with gr.Column():
                output_display = gr.Image(label="Virtual Try-On Result")

        start_btn.click(fn=lambda: True, inputs=None, outputs=[tryon_enabled])
        stop_btn.click(fn=lambda: False, inputs=None, outputs=[tryon_enabled])

        # Clears the webcam component's value.
        # NOTE(review): confirm this actually stops the camera in the Gradio version used.
        stop_cam_btn.click(fn=lambda: None, inputs=None, outputs=[input_webcam])

        # Every streamed webcam frame goes through process_frame together with
        # the current garment selection and try-on flag.
        input_webcam.stream(
            fn=process_frame,
            inputs=[input_webcam, garment_selector, tryon_enabled],
            outputs=output_display,
            show_progress=False,
        )

    with gr.Tab("Train New Garment"):
        gr.Markdown("Upload a video of the garment. The system will auto-mask frames and train a model.")
        with gr.Row():
            video_input = gr.Video(label="Upload Video (Mp4)")
            garment_name_input = gr.Textbox(label="Garment Name (e.g., 'red_shirt')", placeholder="my_custom_shirt")
            train_btn = gr.Button("Start Training")

        train_log = gr.Textbox(label="Training Status", interactive=False)

        # train_garment streams (status, dropdown update) pairs; the selector is
        # refreshed so a newly trained garment becomes selectable.
        train_btn.click(
            fn=train_garment,
            inputs=[video_input, garment_name_input],
            outputs=[train_log, garment_selector],
        )

app.queue().launch()
|
|
|