Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation | |
| import numpy as np | |
| from diffusers import StableDiffusionInpaintPipeline | |
| import warnings | |
| import os | |
warnings.filterwarnings("ignore")  # silence every library warning for the whole process

# CPU-only configuration to avoid GPU issues on Hugging Face Spaces: hide all
# CUDA devices from torch and make float32 the default dtype for new tensors.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})
torch.set_default_dtype(torch.float32)

# Model handles; populated by load_models() and read by the inference functions.
processor = model = pipe = None
def get_device():
    """Return the torch device string used for inference.

    The app deliberately runs CPU-only (CUDA devices are hidden at module
    level), so this always returns ``"cpu"``. The previous ``try``/bare
    ``except`` wrapper could never fire and has been removed.

    Returns:
        str: the device name, always ``"cpu"``.
    """
    return "cpu"
def load_models():
    """Load the clothes-segmentation model and the SD2 inpainting pipeline.

    Every model is moved to the CPU and switched to eval mode. On success the
    module globals ``processor``, ``model`` and ``pipe`` are populated and
    ``True`` is returned. On any exception the error and its traceback are
    printed and ``False`` is returned. Globals assigned before the failure
    keep their new values.

    Returns:
        bool: whether every model loaded successfully.
    """
    global processor, model, pipe
    segformer_repo = "mattmdjaga/segformer_b2_clothes"
    inpaint_repo = "stabilityai/stable-diffusion-2-inpainting"
    try:
        # --- Clothes segmentation (SegFormer) ---
        print("Loading segmentation model...")
        processor = SegformerImageProcessor.from_pretrained(segformer_repo)
        model = AutoModelForSemanticSegmentation.from_pretrained(segformer_repo)
        model = model.to("cpu")
        model.eval()
        print("Segmentation model loaded successfully!")

        # --- Stable Diffusion inpainting ---
        print("Loading Stable Diffusion inpainting model...")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            inpaint_repo,
            torch_dtype=torch.float32,  # float32 for CPU inference
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
        )
        pipe = pipe.to("cpu")
        # Attention slicing lowers peak memory at some cost in speed.
        if hasattr(pipe, "enable_attention_slicing"):
            pipe.enable_attention_slicing()

        components = [pipe.unet, pipe.vae]
        if hasattr(pipe, "text_encoder"):
            components.append(pipe.text_encoder)
        for component in components:
            component.eval()

        print("Stable Diffusion model loaded successfully on CPU!")
        return True
    except Exception as err:
        print(f"Error loading models: {str(err)}")
        import traceback
        traceback.print_exc()
        return False
def _default_upper_body_mask(height, width):
    """Return a (height, width) uint8 mask: 255 inside a central box, 0 elsewhere.

    The box spans rows ``height//4 .. 3*height//4`` and columns
    ``width//3 .. 2*width//3``. It is a fixed heuristic, not a detection.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[height // 4:3 * height // 4, width // 3:2 * width // 3] = 255
    return mask


def segment_clothes(human_image, min_clothes_pixels=100):
    """Build a binary clothing mask for ``human_image`` with the SegFormer model.

    Images larger than 512 px on either side are resized to 512x512 before
    segmentation. The mask is then resized back to the input size with
    nearest-neighbour sampling, so it stays strictly 0/255.

    Relies on the module-level ``processor`` and ``model``. If segmentation
    raises, the error is printed and a default box mask of the input size is
    returned instead.

    Args:
        human_image: PIL image of a person (the caller passes an RGB image).
        min_clothes_pixels: if fewer than this many pixels are classified as
            clothing, the default box mask is used instead.

    Returns:
        PIL.Image.Image: single-channel mask the same size as ``human_image``;
        255 marks pixels treated as clothing.
    """
    # Capture before any resizing so every return path matches the input size.
    original_size = human_image.size
    try:
        if human_image.size[0] > 512 or human_image.size[1] > 512:
            human_image = human_image.resize((512, 512), Image.Resampling.LANCZOS)

        inputs = processor(images=human_image, return_tensors="pt")
        for key in inputs:
            if torch.is_tensor(inputs[key]):
                inputs[key] = inputs[key].to("cpu")

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits.cpu()
            # Logits come out at reduced resolution; upsample to (H, W) of the image.
            upsampled_logits = torch.nn.functional.interpolate(
                logits,
                size=human_image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

        # Class ids treated as clothing. NOTE(review): presumably these index the
        # model's id2label table -- confirm against model.config.id2label.
        clothes_labels = [4, 5, 6, 7, 8, 9, 10]
        clothes_mask = np.isin(pred_seg, clothes_labels).astype(np.uint8) * 255

        # Count pixels, not the sum of 0/255 values: summing made a single
        # clothing pixel (255) already exceed a threshold of 100.
        if np.count_nonzero(clothes_mask) < min_clothes_pixels:
            print("Creating default upper body mask")
            clothes_mask = _default_upper_body_mask(*pred_seg.shape)

        mask_image = Image.fromarray(clothes_mask)
        if mask_image.size != original_size:
            # NEAREST keeps the mask binary; LANCZOS would introduce grey ringing.
            mask_image = mask_image.resize(original_size, Image.Resampling.NEAREST)
        return mask_image
    except Exception as e:
        print(f"Error in segmentation: {str(e)}")
        width, height = original_size
        return Image.fromarray(_default_upper_body_mask(height, width))
def _to_rgb_image(image, size):
    """Return ``image`` (PIL image or numpy array) as an RGB PIL image resized to ``size``."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image.resize(size, Image.Resampling.LANCZOS)


def _simple_blend(human_image, cloth_image, mask):
    """Linearly blend ``cloth_image`` over ``human_image`` using ``mask`` as alpha.

    Computes ``human * (1 - m) + cloth * m`` per pixel with ``m = mask / 255``.
    The mask is coerced to a single channel at the human image's size, and the
    cloth image is resized to match.
    """
    size = human_image.size
    mask = mask.convert("L")
    if mask.size != size:
        mask = mask.resize(size, Image.Resampling.NEAREST)
    alpha = (np.asarray(mask) / 255.0)[..., None]  # (H, W, 1), broadcasts over RGB
    human_array = np.asarray(human_image).astype(np.float32)
    cloth_array = np.asarray(cloth_image.resize(size)).astype(np.float32)
    blended = human_array * (1 - alpha) + cloth_array * alpha
    return Image.fromarray(blended.astype(np.uint8))


def try_on_cloth(human_image, cloth_image, progress=gr.Progress()):
    """Run a virtual try-on for one (person, garment) pair; Gradio click handler.

    Steps:
      1. Normalise both inputs to 512x512 RGB.
      2. Build a clothing mask with ``segment_clothes``.
      3. Inpaint the masked region with the Stable Diffusion pipeline, using a
         fixed text prompt and seed 42.

    NOTE(review): the diffusion call receives only ``human_image``, the mask
    and a text prompt; ``cloth_image`` is never passed to it. The garment photo
    therefore only affects the output through the fallback blend.

    Args:
        human_image: uploaded person photo (PIL image or numpy array).
        cloth_image: uploaded garment photo (PIL image or numpy array).
        progress: Gradio progress tracker, supplied by Gradio.

    Returns:
        tuple: ``(image_or_None, status_message)``. If inpainting fails after
        the mask was built, a mask-weighted blend of the two images is
        returned. If that also fails, or the failure happened before the mask
        existed, the image is ``None`` and the message describes the error.
    """
    if human_image is None or cloth_image is None:
        return None, "Please upload both human and cloth images."
    if processor is None or model is None or pipe is None:
        return None, "Models not loaded. Please refresh the page and try again."

    # Stays None until segmentation succeeds; the fallback blend needs a mask.
    mask = None
    try:
        progress(0.1, desc="Processing images...")
        target_size = (512, 512)
        human_image = _to_rgb_image(human_image, target_size)
        cloth_image = _to_rgb_image(cloth_image, target_size)

        progress(0.3, desc="Generating clothing mask...")
        mask = segment_clothes(human_image)

        progress(0.6, desc="Generating try-on result (this may take a few minutes on CPU)...")
        prompt = "a person wearing the clothing, realistic, high quality, natural lighting"
        negative_prompt = "blurry, low quality, distorted, deformed, extra limbs"
        # Fixed seed so repeated runs on the same inputs give the same result.
        generator = torch.Generator(device='cpu').manual_seed(42)
        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=human_image,
                mask_image=mask,
                num_inference_steps=15,  # Reduced for CPU
                strength=0.75,
                guidance_scale=7.0,
                generator=generator,
            ).images[0]

        progress(1.0, desc="Complete!")
        return result, "Try-on completed successfully! (Processed on CPU)"
    except Exception as e:
        error_msg = f"Error during try-on: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()

        if mask is None:
            # Failed before a mask existed: nothing meaningful to blend.
            return None, error_msg
        try:
            progress(0.8, desc="Attempting simple blend fallback...")
            result = _simple_blend(human_image, cloth_image, mask)
            return result, "Used simple blending due to processing error."
        except Exception as blend_error:
            print(f"Fallback blend failed: {blend_error}")
            return None, error_msg
# Load every model once, at import time, so all requests share them. The
# boolean result selects which status banner the UI shows.
print("Initializing models for CPU processing...")
models_loaded: bool = load_models()
# Gradio interface
with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Default()) as interface:
    # The blank lines inside these Markdown strings are deliberate. Markdown
    # joins consecutive non-blank lines into one paragraph, so without them the
    # CPU note and "Instructions" header run together, and the final "Expected
    # processing time" line is swallowed into the last bullet.
    gr.Markdown("""
# 🧥 Virtual Cloth Try-On AI (CPU Version)

Upload a photo of a person and a clothing item to see how the outfit would look!

**⚠️ Note: This app runs on CPU, so processing will take 2-5 minutes per image.**

**Instructions:**

1. Upload a clear photo of a person (front-facing works best)
2. Upload an image of the clothing item you want to try on
3. Click "Generate Try-On" and be patient - CPU processing is slow but works!
""")

    # Status banner reflecting the module-level load_models() result.
    if not models_loaded:
        gr.Markdown("❌ **Models failed to load. Please refresh the page.**")
    else:
        gr.Markdown("✅ **Models loaded successfully! Ready for try-on.**")

    with gr.Row():
        with gr.Column():  # inputs
            human_input = gr.Image(
                type="pil",
                label="👤 Human Photo"
            )
            cloth_input = gr.Image(
                type="pil",
                label="👕 Clothing Item"
            )
        with gr.Column():  # outputs
            result_output = gr.Image(
                type="pil",
                label="✨ Try-On Result"
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Upload images and click 'Generate Try-On'"
            )

    generate_btn = gr.Button(
        "🎨 Generate Try-On (Takes 2-5 minutes)",
        variant="primary",
        size="lg"
    )
    generate_btn.click(
        fn=try_on_cloth,
        inputs=[human_input, cloth_input],
        outputs=[result_output, status_output],
        show_progress=True
    )

    gr.Markdown("""
---

**Tips for better results:**

- Use clear, high-resolution images with good lighting
- Person should be facing forward with visible torso
- Clothing items should be clearly visible and unfolded
- Simple backgrounds work better than busy ones
- Be patient - CPU processing takes time but produces good results!

**Expected processing time: 2-5 minutes per try-on**
""")
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860; show_error=True surfaces
    # Python exceptions in the web UI instead of only in the server log.
    interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)