# Deploy professional garment segmentation app with custom UI
# (Hugging Face Space by zyuzuguldu, commit 2f6b8cf)
import gradio as gr
import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
# Configuration
MODEL_REPO_ID = "zyuzuguldu/garment-segmentation-unet-resnet50"  # HF Hub repo holding the checkpoint
INPUT_SIZE = 768  # model expects square INPUT_SIZE x INPUT_SIZE inputs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# Cache the model globally; lazily populated by load_model() on first use
model = None
def load_model():
    """Lazily download, build, and cache the segmentation model.

    Downloads the safetensors checkpoint from the HuggingFace Hub on
    first call, builds the U-Net (ResNet50 encoder), loads the weights,
    and caches the result in the module-level ``model`` global.

    Returns:
        torch.nn.Module: the model in eval mode on ``DEVICE``.
    """
    global model
    if model is None:
        print("📥 Downloading model from HuggingFace Hub...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename="model.safetensors"
        )
        print("🔨 Building model architecture...")
        # encoder_weights=None: skip downloading ImageNet encoder weights —
        # the full state dict is loaded from the checkpoint right below,
        # so the default pretrained download would be wasted work.
        model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights=None,
            classes=1,
            activation=None,
            decoder_channels=(256, 128, 64, 32, 16)
        )
        print("⚡ Loading weights...")
        state_dict = load_file(model_path)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        print("✅ Model loaded successfully!")
    return model
def preprocess_image(image):
    """Prepare an input image for model inference.

    Args:
        image: PIL image or numpy array; grayscale (2D or (H, W, 1)),
            RGB, or RGBA layouts are accepted.

    Returns:
        tuple: (image_tensor, original_size, image) where image_tensor is
        a normalized float tensor of shape (1, 3, INPUT_SIZE, INPUT_SIZE),
        original_size is (H, W) of the input, and image is the RGB array.
    """
    # Convert PIL to numpy
    if isinstance(image, Image.Image):
        image = np.array(image)
    # Normalize channel layout to 3-channel RGB.
    if image.ndim == 2 or image.shape[2] == 1:
        # Covers both 2D grayscale and (H, W, 1) single-channel arrays;
        # the latter was previously unhandled and broke normalization.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Store original size so the mask can be resized back later.
    original_size = image.shape[:2]
    # Resize to model input size
    image_resized = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))
    # ImageNet normalization stats (match the encoder's pretraining).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_normalized = (image_resized / 255.0 - mean) / std
    # Convert to tensor: (H, W, C) -> (1, C, H, W)
    image_tensor = torch.from_numpy(image_normalized).float().permute(2, 0, 1).unsqueeze(0)
    return image_tensor, original_size, image
def postprocess_mask(mask_logits, original_size, threshold=0.5):
    """Turn raw model logits into a binary mask at the input resolution.

    Args:
        mask_logits: raw model output tensor (single-channel logits).
        original_size: (height, width) of the original image.
        threshold: probability cutoff for the binary mask.

    Returns:
        tuple: (binary mask resized to original_size as uint8,
        probability map at model resolution as float array).
    """
    # Sigmoid -> probabilities, then drop the batch/channel dims.
    probabilities = torch.sigmoid(mask_logits).squeeze().cpu().numpy()
    binary = (probabilities > threshold).astype(np.uint8)
    # Nearest-neighbor keeps the mask strictly binary during resize.
    height, width = original_size
    resized = cv2.resize(binary, (width, height),
                         interpolation=cv2.INTER_NEAREST)
    return resized, probabilities
def create_overlay(image, mask, alpha=0.6):
    """Blend a cyan-tinted mask over the original image for display.

    Args:
        image: RGB uint8 image array.
        mask: binary (0/1) mask array of the same height/width.
        alpha: blend weight applied to the tinted mask.

    Returns:
        RGB uint8 array with the masked region highlighted in cyan.
    """
    # Cyan tint: green + blue channels at full value, red stays zero.
    tint = np.zeros_like(image)
    tint[..., 1] = mask * 255
    tint[..., 2] = mask * 255
    # Weighted blend: original at full strength plus the tint (saturating).
    return cv2.addWeighted(image, 1, tint, alpha, 0)
def extract_garment(image, mask):
    """Return the image with everything outside the mask blacked out.

    Args:
        image: RGB image array.
        mask: binary (0/1) single-channel mask of matching height/width.

    Returns:
        Image array where unmasked pixels are zeroed.
    """
    # Replicate the single-channel mask across the 3 color channels,
    # then multiply so background pixels become 0 (black).
    mask_rgb = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    return image * mask_rgb
def segment_garment(image, threshold=0.5, show_overlay=True):
    """Run garment segmentation and build the three output images.

    Args:
        image: input image (PIL or numpy array).
        threshold: probability cutoff for the binary mask.
        show_overlay: when True, the first output is the overlay
            visualization; otherwise the binary mask is shown there.

    Returns:
        tuple of three arrays: (overlay-or-mask, extracted garment,
        binary mask scaled to 0/255 uint8).

    Fixes: the original ``show_overlay=False`` branch returned an
    ``overlay`` variable that was never assigned, raising NameError.
    An unused resized probability map was also removed.
    """
    net = load_model()
    # Preprocess
    image_tensor, original_size, original_image = preprocess_image(image)
    image_tensor = image_tensor.to(DEVICE)
    # Inference
    with torch.no_grad():
        mask_logits = net(image_tensor)
    # Postprocess
    mask_binary, mask_prob = postprocess_mask(mask_logits, original_size, threshold)
    mask_image = (mask_binary * 255).astype(np.uint8)
    extracted = extract_garment(original_image, mask_binary)
    if show_overlay:
        overlay = create_overlay(original_image, mask_binary)
        return overlay, extracted, mask_image
    # No overlay requested: show the binary mask in the first slot instead.
    return mask_image, extracted, mask_image
# Custom CSS injected into the Gradio Blocks UI: gradient page title,
# centered description, model-info card, green performance badges, footer.
custom_css = """
#title {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3em;
font-weight: bold;
margin-bottom: 0.5em;
}
#description {
text-align: center;
font-size: 1.2em;
color: #666;
margin-bottom: 2em;
}
#model-info {
background: #f8f9fa;
padding: 1.5em;
border-radius: 10px;
margin: 1em 0;
}
.performance-badge {
background: #28a745;
color: white;
padding: 0.3em 0.8em;
border-radius: 15px;
font-weight: bold;
display: inline-block;
margin: 0.2em;
}
footer {
text-align: center;
margin-top: 2em;
padding: 1em;
color: #888;
}
"""
def _run_segmentation(image, threshold):
    """UI callback wrapper: run segmentation, skipping empty input.

    Guards the handlers below — clearing the uploaded image fires
    input_image.change with image=None, which would crash preprocessing.
    """
    if image is None:
        return None, None, None
    return segment_garment(image, threshold)


# Create Gradio Interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("<h1 id='title'>👗 Garment Segmentation</h1>")
    gr.Markdown(
        "<p id='description'>AI-powered garment extraction for fashion and virtual try-on applications</p>"
    )

    # Model Information
    with gr.Accordion("📊 Model Information", open=False):
        gr.Markdown("""
        <div id='model-info'>

        ### Architecture
        - **Model**: U-Net with ResNet50 encoder
        - **Input Size**: 768 × 768 pixels
        - **Training Dataset**: DeepFashion2
        - **Performance**: <span class='performance-badge'>Val IoU: 89.64%</span>

        ### Key Features
        - 🎯 High-precision garment segmentation
        - ⚡ Fast inference (GPU-accelerated)
        - 🎨 Multiple visualization options
        - 🔧 Adjustable confidence threshold

        ### Use Cases
        - Virtual try-on applications
        - Fashion e-commerce product editing
        - Garment dataset preprocessing
        - Clothing item extraction and isolation
        </div>
        """)

    # Main Interface: input controls on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📤 Upload Image",
                type="pil",
                height=400
            )
            threshold = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.05,
                label="🎚️ Confidence Threshold",
                info="Adjust to refine the segmentation mask"
            )
            submit_btn = gr.Button("🚀 Segment Garment", variant="primary", size="lg")

            gr.Markdown("### 💡 Tips:")
            gr.Markdown("""
            - Upload clear photos with visible garments
            - Works best with upper-body clothing
            - Adjust threshold if mask is too loose/tight
            - Try different angles for best results
            """)

        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            with gr.Row():
                output_overlay = gr.Image(
                    label="🎨 Overlay (Mask + Original)",
                    height=300
                )
                output_extracted = gr.Image(
                    label="✂️ Extracted Garment",
                    height=300
                )
            output_mask = gr.Image(
                label="🎭 Binary Mask",
                height=300
            )

    # Examples
    gr.Markdown("### 🖼️ Try These Examples")
    gr.Examples(
        examples=[
            ["examples/fashion1.jpg", 0.5],
            ["examples/fashion2.jpg", 0.5],
            ["examples/fashion3.jpg", 0.5],
        ],
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask],
        fn=_run_segmentation,
        cache_examples=False,
    )

    # Event handlers
    submit_btn.click(
        fn=_run_segmentation,
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask]
    )

    # Auto-run on image upload (guarded against cleared images by the wrapper)
    input_image.change(
        fn=_run_segmentation,
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask]
    )

    # Footer
    gr.Markdown("""
    <footer>
    <hr>
    <p>
    Built with ❤️ using <a href="https://gradio.app">Gradio</a> |
    Model: <a href="https://huggingface.co/zyuzuguldu/garment-segmentation-unet-resnet50">garment-segmentation-unet-resnet50</a> |
    <a href="https://github.com/zyuzuguldu">GitHub</a>
    </p>
    </footer>
    """)
# Entry point: start the Gradio server when executed as a script.
if __name__ == "__main__":
    demo.launch()