Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| import segmentation_models_pytorch as smp | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
# Configuration
MODEL_REPO_ID = "zyuzuguldu/garment-segmentation-unet-resnet50"  # HF Hub repo holding model.safetensors
INPUT_SIZE = 768  # square side length (pixels) fed to the network
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Cache the model globally
model = None  # lazily populated by load_model() on first request
def load_model():
    """Load the segmentation model from HuggingFace Hub, caching it globally.

    Returns:
        The U-Net (ResNet50 encoder) model on DEVICE in eval mode. The first
        call downloads and builds the model; later calls return the cached
        global instance.
    """
    global model
    if model is None:
        print("📥 Downloading model from HuggingFace Hub...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename="model.safetensors"
        )
        print("🔨 Building model architecture...")
        # encoder_weights=None: smp.Unet defaults to "imagenet", which would
        # download ResNet50 ImageNet weights only to have them overwritten by
        # load_state_dict below. Skipping them saves time and bandwidth.
        model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights=None,
            classes=1,
            activation=None,
            decoder_channels=(256, 128, 64, 32, 16)
        )
        print("⚡ Loading weights...")
        state_dict = load_file(model_path)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        print("✅ Model loaded successfully!")
    return model
def preprocess_image(image):
    """Prepare an input image for model inference.

    Returns:
        A tuple of (tensor of shape (1, 3, INPUT_SIZE, INPUT_SIZE),
        original (height, width), the RGB numpy image at original size).
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Coerce to 3-channel RGB regardless of input format.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    original_size = image.shape[:2]

    resized = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))

    # Normalize with ImageNet statistics to match the encoder's training data.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (resized / 255.0 - imagenet_mean) / imagenet_std

    # (H, W, C) float64 -> (1, C, H, W) float32 batch tensor.
    tensor = torch.from_numpy(normalized).float().permute(2, 0, 1).unsqueeze(0)
    return tensor, original_size, image
def postprocess_mask(mask_logits, original_size, threshold=0.5):
    """Convert raw logits into a binary mask at the original image size.

    Returns:
        A tuple of (uint8 binary mask resized to original_size,
        float probability map at model resolution).
    """
    probs = torch.sigmoid(mask_logits).squeeze().cpu().numpy()
    binary = (probs > threshold).astype(np.uint8)
    # Nearest-neighbor interpolation keeps the mask strictly 0/1 after resize.
    height, width = original_size
    binary_resized = cv2.resize(binary, (width, height),
                                interpolation=cv2.INTER_NEAREST)
    return binary_resized, probs
def create_overlay(image, mask, alpha=0.6):
    """Blend a cyan-tinted mask over the image for visualization."""
    # Cyan tint: red stays zero (from zeros_like), green and blue go to 255
    # wherever the mask is set.
    tinted = np.zeros_like(image)
    tinted[..., 1] = mask * 255
    tinted[..., 2] = mask * 255
    # Additive blend (saturating) of the tint on top of the original image.
    return cv2.addWeighted(image, 1, tinted, alpha, 0)
def extract_garment(image, mask):
    """Return the image with everything outside the mask zeroed (black)."""
    # Broadcasting (H, W, 1) against (H, W, 3) applies the mask per channel,
    # equivalent to stacking the mask three times.
    return image * mask[..., np.newaxis]
def segment_garment(image, threshold=0.5, show_overlay=True):
    """Run garment segmentation and produce visualization outputs.

    Args:
        image: Input PIL image (or RGB numpy array); may be None when the
            Gradio image component is cleared.
        threshold: Probability cutoff for the binary mask.
        show_overlay: When True the first output is the mask/image overlay;
            otherwise it is the binary mask image.

    Returns:
        A 3-tuple of uint8 images — (overlay, extracted garment, mask) when
        show_overlay is True, else (mask, extracted garment, overlay) — or
        (None, None, None) when no image was provided.
    """
    # Guard: input_image.change also fires when the image is cleared, which
    # would otherwise crash preprocess_image with a None input.
    if image is None:
        return None, None, None

    model = load_model()

    image_tensor, original_size, original_image = preprocess_image(image)
    image_tensor = image_tensor.to(DEVICE)

    with torch.no_grad():
        mask_logits = model(image_tensor)

    mask_binary, mask_prob = postprocess_mask(mask_logits, original_size, threshold)

    # Compute all visualizations up front; the original code referenced an
    # undefined `overlay` in the else-branch, raising NameError at runtime.
    overlay = create_overlay(original_image, mask_binary)
    extracted = extract_garment(original_image, mask_binary)
    mask_image = (mask_binary * 255).astype(np.uint8)

    if show_overlay:
        return overlay, extracted, mask_image
    return mask_image, extracted, overlay
# Custom CSS for better styling; the #title / #description / #model-info ids
# and .performance-badge class are referenced by elem ids in the UI markdown.
custom_css = """
#title {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3em;
font-weight: bold;
margin-bottom: 0.5em;
}
#description {
text-align: center;
font-size: 1.2em;
color: #666;
margin-bottom: 2em;
}
#model-info {
background: #f8f9fa;
padding: 1.5em;
border-radius: 10px;
margin: 1em 0;
}
.performance-badge {
background: #28a745;
color: white;
padding: 0.3em 0.8em;
border-radius: 15px;
font-weight: bold;
display: inline-block;
margin: 0.2em;
}
footer {
text-align: center;
margin-top: 2em;
padding: 1em;
color: #888;
}
"""
# Create Gradio Interface: declarative UI layout plus event wiring.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("<h1 id='title'>👗 Garment Segmentation</h1>")
    gr.Markdown(
        "<p id='description'>AI-powered garment extraction for fashion and virtual try-on applications</p>"
    )

    # Model Information (collapsed by default)
    with gr.Accordion("📊 Model Information", open=False):
        gr.Markdown("""
<div id='model-info'>

### Architecture
- **Model**: U-Net with ResNet50 encoder
- **Input Size**: 768 × 768 pixels
- **Training Dataset**: DeepFashion2
- **Performance**: <span class='performance-badge'>Val IoU: 89.64%</span>

### Key Features
- 🎯 High-precision garment segmentation
- ⚡ Fast inference (GPU-accelerated)
- 🎨 Multiple visualization options
- 🔧 Adjustable confidence threshold

### Use Cases
- Virtual try-on applications
- Fashion e-commerce product editing
- Garment dataset preprocessing
- Clothing item extraction and isolation

</div>
""")

    # Main Interface: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📤 Upload Image",
                type="pil",
                height=400
            )
            threshold = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.05,
                label="🎚️ Confidence Threshold",
                info="Adjust to refine the segmentation mask"
            )
            submit_btn = gr.Button("🚀 Segment Garment", variant="primary", size="lg")
            gr.Markdown("### 💡 Tips:")
            gr.Markdown("""
- Upload clear photos with visible garments
- Works best with upper-body clothing
- Adjust threshold if mask is too loose/tight
- Try different angles for best results
""")
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            # NOTE(review): original indentation was lost; all three outputs
            # are assumed to sit inside this Row — confirm intended layout.
            with gr.Row():
                output_overlay = gr.Image(
                    label="🎨 Overlay (Mask + Original)",
                    height=300
                )
                output_extracted = gr.Image(
                    label="✂️ Extracted Garment",
                    height=300
                )
                output_mask = gr.Image(
                    label="🎭 Binary Mask",
                    height=300
                )

    # Examples (not cached: inference runs on click)
    gr.Markdown("### 🖼️ Try These Examples")
    gr.Examples(
        examples=[
            ["examples/fashion1.jpg", 0.5],
            ["examples/fashion2.jpg", 0.5],
            ["examples/fashion3.jpg", 0.5],
        ],
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask],
        fn=segment_garment,
        cache_examples=False,
    )

    # Event handlers: explicit button click...
    submit_btn.click(
        fn=segment_garment,
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask]
    )

    # ...and auto-run on image upload (also fires on clear with a None image).
    input_image.change(
        fn=segment_garment,
        inputs=[input_image, threshold],
        outputs=[output_overlay, output_extracted, output_mask]
    )

    # Footer
    gr.Markdown("""
<footer>
<hr>
<p>
Built with ❤️ using <a href="https://gradio.app">Gradio</a> |
Model: <a href="https://huggingface.co/zyuzuguldu/garment-segmentation-unet-resnet50">garment-segmentation-unet-resnet50</a> |
<a href="https://github.com/zyuzuguldu">GitHub</a>
</p>
</footer>
""")
# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()