# Try-On / app.py
# Amit Shamsundar
# GPU error
# 9285254
# Raw
# History Blame Contribute Delete
# 10 kB
import gradio as gr
from PIL import Image
import torch
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
import warnings
import os
# Suppress every Python warning for the whole process (no category or module
# filter is given, so library deprecation notices are hidden as well).
warnings.filterwarnings("ignore")
# Force CPU usage to avoid GPU issues on Hugging Face Spaces
# NOTE(review): this assignment runs after `import torch`; it presumably still
# takes effect because CUDA has not been initialised yet at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Make float32 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float32)
# Global variables for models
# All three start as None, are assigned by load_models(), and are checked for
# None by try_on_cloth() before any inference is attempted.
processor = None
model = None
pipe = None
def get_device():
    """Return the torch device string this app runs on.

    The app is deliberately pinned to the CPU for stability on Hugging Face
    Spaces, so the answer is always ``"cpu"``.  The previous implementation
    wrapped the constant return in ``try``/bare ``except``; returning a
    literal cannot raise, so that handler was unreachable and has been
    removed.

    Returns:
        str: Always ``"cpu"``.
    """
    return "cpu"
def load_models():
    """Populate the module-level ``processor``, ``model`` and ``pipe`` globals.

    Loads the clothes-segmentation checkpoint and the Stable Diffusion
    inpainting pipeline, moves both to the CPU and switches their networks
    to eval mode.

    Returns:
        bool: ``True`` when everything loaded, ``False`` if any step raised
        (the error and its traceback are printed; nothing is re-raised).
    """
    global processor, model, pipe

    segmentation_checkpoint = "mattmdjaga/segformer_b2_clothes"
    # Keyword arguments for the inpainting pipeline; float32 because the
    # pipeline is run on the CPU.
    pipeline_options = {
        "torch_dtype": torch.float32,
        "safety_checker": None,
        "requires_safety_checker": False,
        "use_safetensors": True,
    }

    try:
        print("Loading segmentation model...")
        processor = SegformerImageProcessor.from_pretrained(segmentation_checkpoint)
        model = AutoModelForSemanticSegmentation.from_pretrained(segmentation_checkpoint)
        # Keep the segmentation network on the CPU, in inference mode.
        model = model.to("cpu")
        model.eval()
        print("Segmentation model loaded successfully!")

        print("Loading Stable Diffusion inpainting model...")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            **pipeline_options,
        )
        # Move every pipeline component to the CPU.
        pipe = pipe.to("cpu")

        # Attention slicing is only switched on when the pipeline offers it.
        if hasattr(pipe, 'enable_attention_slicing'):
            pipe.enable_attention_slicing()

        # Put the pipeline's sub-networks into eval mode.
        for component_name in ("unet", "vae"):
            getattr(pipe, component_name).eval()
        if hasattr(pipe, 'text_encoder'):
            pipe.text_encoder.eval()

        print("Stable Diffusion model loaded successfully on CPU!")
    except Exception as load_error:
        print(f"Error loading models: {load_error!s}")
        import traceback
        traceback.print_exc()
        return False

    return True
def _default_torso_mask(size):
    """Return a rectangular fallback mask for an image of the given size.

    Args:
        size: ``(width, height)`` tuple in PIL order.

    Returns:
        A PIL image of that size whose pixels are 255 inside the central
        rectangle (rows h/4..3h/4, columns w/3..2w/3) and 0 elsewhere.
    """
    width, height = size
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[height // 4:3 * height // 4, width // 3:2 * width // 3] = 255
    return Image.fromarray(mask)


def segment_clothes(human_image):
    """Build a 0/255 clothing mask for ``human_image`` using CPU-only inference.

    The image is squashed to 512x512 for segmentation when either side
    exceeds 512 pixels, the module-level ``processor``/``model`` predict a
    per-pixel class map, and pixels whose class is in ``clothes_labels``
    become 255 in the mask.

    Args:
        human_image: PIL image of the person.

    Returns:
        A PIL mask image with the same size as the ``human_image`` that was
        passed in.  When no clothing pixel is found, or when segmentation
        raises, a default central rectangle is returned instead; the error
        is printed and not re-raised.
    """
    # Captured before any resizing so that every return path -- including the
    # error fallback -- produces a mask matching the caller's image.  The
    # previous fallback used the size of the already-resized working copy.
    original_size = human_image.size

    try:
        # Work on a separate name so the caller's image is never shadowed.
        work_image = human_image
        if original_size[0] > 512 or original_size[1] > 512:
            work_image = human_image.resize((512, 512), Image.Resampling.LANCZOS)

        inputs = processor(images=work_image, return_tensors="pt")
        # Ensure every tensor input lives on the CPU.
        for key in inputs:
            if torch.is_tensor(inputs[key]):
                inputs[key] = inputs[key].to("cpu")

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits.cpu()
            # PIL sizes are (width, height); interpolate expects (height, width).
            upsampled_logits = torch.nn.functional.interpolate(
                logits,
                size=work_image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

        # Class ids this app treats as clothing.
        # NOTE(review): meaning of each id comes from the checkpoint's label
        # map -- confirm against the model card.
        clothes_labels = [4, 5, 6, 7, 8, 9, 10]
        clothes_mask = np.isin(pred_seg, clothes_labels).astype(np.uint8) * 255

        # Every selected pixel is 255, so "no non-zero pixel" is exactly the
        # original "sum below 100" condition.
        if not clothes_mask.any():
            print("Creating default upper body mask")
            return _default_torso_mask(original_size)

        mask_image = Image.fromarray(clothes_mask)
        if mask_image.size != original_size:
            # NEAREST keeps the mask strictly binary; a Lanczos filter would
            # introduce intermediate gray values and ringing at the edges.
            mask_image = mask_image.resize(original_size, Image.Resampling.NEAREST)
        return mask_image
    except Exception as e:
        print(f"Error in segmentation: {str(e)}")
        # Best-effort fallback: a default center mask at the caller's size.
        return _default_torso_mask(original_size)
def _blend_cloth_fallback(human_image, cloth_image, mask):
    """Paste ``cloth_image`` over ``human_image`` wherever ``mask`` is set.

    Args:
        human_image: RGB PIL image of the person.
        cloth_image: RGB PIL image of the garment; resized to the person image.
        mask: Single-channel PIL mask (0..255) used as the blend weight.

    Returns:
        A new RGB PIL image: person pixels where the mask is 0, garment
        pixels where it is 255, and a linear mix in between.
    """
    alpha = np.array(mask) / 255.0
    cloth_resized = cloth_image.resize(human_image.size)
    human_array = np.array(human_image).astype(np.float32)
    cloth_array = np.array(cloth_resized).astype(np.float32)
    # Repeat the single-channel weight across the three colour channels.
    alpha_3d = np.stack([alpha] * 3, axis=2)
    blended = human_array * (1 - alpha_3d) + cloth_array * alpha_3d
    return Image.fromarray(blended.astype(np.uint8))


def try_on_cloth(human_image, cloth_image, progress=gr.Progress()):
    """Run the virtual try-on and return ``(result_image, status_message)``.

    Both inputs are converted to RGB, resized to 512x512, a clothing mask is
    produced with ``segment_clothes`` and the masked region is repainted by
    the module-level inpainting pipeline ``pipe``.

    NOTE: the diffusion call receives only the person image, the mask and a
    fixed text prompt.  ``cloth_image`` is used solely by the simple-blend
    fallback that runs when the diffusion path raises.

    Args:
        human_image: PIL image or numpy array of the person, or ``None``.
        cloth_image: PIL image or numpy array of the garment, or ``None``.
        progress: Gradio progress tracker used to report stage updates.

    Returns:
        tuple: ``(image, message)``.  ``image`` is the generated PIL image,
        the fallback blend, or ``None`` when inputs/models are missing or
        when both the main path and the fallback fail.  ``message`` is the
        status text shown to the user.
    """
    if human_image is None or cloth_image is None:
        return None, "Please upload both human and cloth images."
    if processor is None or model is None or pipe is None:
        return None, "Models not loaded. Please refresh the page and try again."

    # Bound up front so the error handler can tell whether segmentation was
    # reached; previously an early failure left ``mask`` undefined and the
    # resulting NameError was hidden by a bare ``except``.
    mask = None
    try:
        progress(0.1, desc="Processing images...")
        # Ensure images are PIL Images
        if isinstance(human_image, np.ndarray):
            human_image = Image.fromarray(human_image)
        if isinstance(cloth_image, np.ndarray):
            cloth_image = Image.fromarray(cloth_image)

        # Convert to RGB
        if human_image.mode != 'RGB':
            human_image = human_image.convert('RGB')
        if cloth_image.mode != 'RGB':
            cloth_image = cloth_image.convert('RGB')

        # Resize for processing
        target_size = (512, 512)
        human_image = human_image.resize(target_size, Image.Resampling.LANCZOS)
        cloth_image = cloth_image.resize(target_size, Image.Resampling.LANCZOS)

        progress(0.3, desc="Generating clothing mask...")
        mask = segment_clothes(human_image)

        progress(0.6, desc="Generating try-on result (this may take a few minutes on CPU)...")
        prompt = "a person wearing the clothing, realistic, high quality, natural lighting"
        negative_prompt = "blurry, low quality, distorted, deformed, extra limbs"
        # Fixed seed on a CPU generator so repeated runs are reproducible.
        generator = torch.Generator(device='cpu').manual_seed(42)

        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=human_image,
                mask_image=mask,
                num_inference_steps=15,  # Reduced for CPU
                strength=0.75,
                guidance_scale=7.0,
                generator=generator,
            ).images[0]

        progress(1.0, desc="Complete!")
        return result, "Try-on completed successfully! (Processed on CPU)"
    except Exception as e:
        error_msg = f"Error during try-on: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()

        if mask is None:
            # The failure happened before a mask existed, so there is
            # nothing to blend with; report the original error.
            return None, error_msg

        # Attempt simple fallback
        try:
            progress(0.8, desc="Attempting simple blend fallback...")
            result = _blend_cloth_fallback(human_image, cloth_image, mask)
            return result, "Used simple blending due to processing error."
        except Exception as fallback_error:
            # ``Exception`` rather than a bare except, so KeyboardInterrupt
            # and SystemExit still propagate.
            print(f"Fallback blending failed: {fallback_error}")
            return None, error_msg
# Initialize models
# Runs at import time, before the interface is built. load_models() returns
# True/False; the result is kept in models_loaded, which the interface code
# reads to decide which status banner to display.
print("Initializing models for CPU processing...")
models_loaded = load_models()
# Gradio interface
# The static Markdown texts are named up front so the layout code below reads
# as structure only.
_intro_markdown = """
# 🧥 Virtual Cloth Try-On AI (CPU Version)
Upload a photo of a person and a clothing item to see how the outfit would look!
**⚠️ Note: This app runs on CPU, so processing will take 2-5 minutes per image.**
**Instructions:**
1. Upload a clear photo of a person (front-facing works best)
2. Upload an image of the clothing item you want to try on
3. Click "Generate Try-On" and be patient - CPU processing is slow but works!
"""
_tips_markdown = """
---
**Tips for better results:**
- Use clear, high-resolution images with good lighting
- Person should be facing forward with visible torso
- Clothing items should be clearly visible and unfolded
- Simple backgrounds work better than busy ones
- Be patient - CPU processing takes time but produces good results!
**Expected processing time: 2-5 minutes per try-on**
"""
# Banner reflecting whether load_models() succeeded at startup.
_model_status_markdown = (
    "✅ **Models loaded successfully! Ready for try-on.**"
    if models_loaded
    else "❌ **Models failed to load. Please refresh the page.**"
)

with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Default()) as interface:
    gr.Markdown(_intro_markdown)
    gr.Markdown(_model_status_markdown)

    with gr.Row():
        # Left column: the two uploads.
        with gr.Column():
            human_input = gr.Image(type="pil", label="👤 Human Photo")
            cloth_input = gr.Image(type="pil", label="👕 Clothing Item")
        # Right column: generated image plus a read-only status line.
        with gr.Column():
            result_output = gr.Image(type="pil", label="✨ Try-On Result")
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Upload images and click 'Generate Try-On'",
            )

    generate_btn = gr.Button(
        "🎨 Generate Try-On (Takes 2-5 minutes)",
        variant="primary",
        size="lg",
    )
    # try_on_cloth returns (image, message), matching the two outputs in order.
    generate_btn.click(
        fn=try_on_cloth,
        inputs=[human_input, cloth_input],
        outputs=[result_output, status_output],
        show_progress=True,
    )

    gr.Markdown(_tips_markdown)
if __name__ == "__main__":
    # Start the web server only when this file is executed as a script:
    # bind to all network interfaces on port 7860 and surface errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    interface.launch(**launch_options)