# Page-scrape residue (not program code): Spaces status banner "Runtime error".
| import torch | |
| from torch import autocast | |
| from diffusers import StableDiffusionInpaintPipeline | |
| import gradio as gr | |
| import traceback | |
| import base64 | |
| from io import BytesIO | |
| import os | |
| # import sys | |
| import PIL | |
| import json | |
| import requests | |
| import logging | |
| import time | |
| import warnings | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import cv2 | |
| warnings.filterwarnings("ignore") | |
| # sys.path.insert(1, './parser') | |
| # from parser.schp_masker import * | |
| from parser.segformer_parser import SegformerParser | |
# Logging setup: INFO and above, each record stamped with time, logger name and level
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('clothquill')

# Identifiers of the pretrained models loaded by init()
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"

# Module-level model handles; all start empty and are assigned by init()
parser = model = inpainter = None
original_image = None  # Store the original uploaded image
# RGBA overlay colour per segmentation label. Every label except the
# background shares the same alpha; the background is fully transparent.
_OVERLAY_ALPHA = 128
_LABEL_RGB = (
    ('Hat', (255, 0, 0)),              # Red
    ('Hair', (0, 255, 0)),             # Green
    ('Glove', (0, 0, 255)),            # Blue
    ('Sunglasses', (255, 255, 0)),     # Yellow
    ('Upper-clothes', (255, 0, 255)),  # Magenta
    ('Dress', (0, 255, 255)),          # Cyan
    ('Coat', (128, 0, 0)),             # Dark Red
    ('Socks', (0, 128, 0)),            # Dark Green
    ('Pants', (0, 0, 128)),            # Dark Blue
    ('Jumpsuits', (128, 128, 0)),      # Dark Yellow
    ('Scarf', (128, 0, 128)),          # Dark Magenta
    ('Skirt', (0, 128, 128)),          # Dark Cyan
    ('Face', (192, 192, 192)),         # Light Gray
    ('Left-arm', (64, 64, 64)),        # Dark Gray
    ('Right-arm', (64, 64, 64)),       # Dark Gray
    ('Left-leg', (32, 32, 32)),        # Very Dark Gray
    ('Right-leg', (32, 32, 32)),       # Very Dark Gray
    ('Left-shoe', (16, 16, 16)),       # Almost Black
    ('Right-shoe', (16, 16, 16)),      # Almost Black
)
CLOTHING_COLORS = {'Background': (0, 0, 0, 0)}  # Transparent
CLOTHING_COLORS.update(
    (label, rgb + (_OVERLAY_ALPHA,)) for label, rgb in _LABEL_RGB
)
def get_device():
    """Pick the torch device string for this process.

    Returns:
        "cuda" when torch reports CUDA as available, otherwise "cpu".
        The choice is also written to the log at INFO level.
    """
    if not torch.cuda.is_available():
        logger.info("Using CPU")
        return "cpu"
    logger.info("Using GPU")
    return "cuda"
def init():
    """Load the segmentation parser and inpainting pipeline into module globals.

    Assigns the module-level ``parser``, ``model`` and ``inpainter`` names.
    When no ``models`` directory exists in the working directory,
    ``download_models.download_models()`` is run first.

    Raises:
        Exception: any failure during set-up is logged together with its
            traceback and then re-raised unchanged.
    """
    global parser
    global model
    global inpainter
    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()
        use_half = device == "cuda"

        # Check if models directory exists
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            # Imported lazily: only needed on the first run
            from download_models import download_models
            download_models()

        # Initialize Segformer parser
        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model (half precision on GPU only)
        # NOTE(review): newer diffusers releases appear to prefer
        # `variant="fp16"` over `revision="fp16"` - confirm against the
        # pinned diffusers version before changing.
        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if use_half else None,
            torch_dtype=torch.float16 if use_half else torch.float32,
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)
        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        # logger.exception records the traceback; bare `raise` re-raises the
        # active exception as-is.
        logger.exception("Error initializing application: %s", e)
        raise
class ClothingInpainter:
    """Segment clothing in an image and repaint chosen parts with an inpainting pipeline.

    The instance keeps two pieces of UI state that other code in this file
    reads and writes: ``last_mask`` (the most recent mask dictionary) and
    ``original_image`` (the untouched upload used as the visualisation base).
    """

    # Labels repainted by inpaint() when the caller selects no parts.
    DEFAULT_PARTS = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')

    def __init__(self, model_path=None, model=None, parser=None):
        """Create the inpainter from a ready pipeline or a model path.

        Args:
            model_path: identifier passed to
                ``StableDiffusionInpaintPipeline.from_pretrained``; takes
                precedence over ``model`` when both are given.
            model: an already constructed pipeline, used as-is.
            parser: object providing ``get_all_masks(image)``.

        Raises:
            ValueError: if neither ``model_path`` nor ``model`` is given.
        """
        self.device = get_device()
        self.last_mask = None  # Store the last generated mask
        self.original_image = None  # Store the original image
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            use_half = self.device == "cuda"
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if use_half else None,
                torch_dtype=torch.float16 if use_half else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Centre ``im`` on a square canvas and return the result as RGB.

        The canvas side is ``max(min_size, width, height)`` and the padding
        is filled with ``fill_color``.
        """
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Crop the area occupied by ``init_im`` out of a squared, resized image.

        ``op_im`` is expected to be the ``rs_size`` x ``rs_size`` result of
        resizing ``make_square(init_im, min_size)``; the crop box is the
        padded image's position scaled by ``rs_size / square_side``.
        """
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        ))

    @staticmethod
    def _color_key(part_name):
        """Map a parser label such as ``'upper-clothes'`` to its CLOTHING_COLORS key."""
        # The colour table capitalises only the first letter
        # ('Upper-clothes', 'Left-arm'); title-casing each word never matched.
        return part_name.capitalize()

    @staticmethod
    def _paint_mask(canvas, mask_array, color):
        """Write ``color`` into ``canvas`` wherever ``mask_array`` is positive.

        ``canvas`` is an (H, W, 4) uint8 array modified in place. Mask pixels
        outside the canvas are ignored, and a multi-channel mask counts as set
        where any channel is positive.
        """
        hits = np.asarray(mask_array) > 0
        if hits.ndim == 3:
            hits = hits.any(axis=2)
        rows = min(hits.shape[0], canvas.shape[0])
        cols = min(hits.shape[1], canvas.shape[1])
        # Basic slicing yields a view, so the boolean assignment lands in canvas.
        canvas[:rows, :cols][hits[:rows, :cols]] = color

    @staticmethod
    def _dilate(mask, iterations):
        """Return ``mask`` dilated with a 5x5 kernel as a new PIL image."""
        kernel = np.ones((5, 5), np.uint8)
        # int(): the value may originate from a UI slider
        dilated = cv2.dilate(np.array(mask), kernel, iterations=int(iterations))
        return Image.fromarray(dilated)

    def _combine_masks(self, masks, part_names, dilation_iterations):
        """Merge the dilated masks of ``part_names`` into one 512x512 'L' mask."""
        combined_mask = Image.new('L', (512, 512), 0)
        white = Image.new('L', (512, 512), 255)
        for part in part_names:
            if part not in masks:
                continue
            dilated_mask = self._dilate(masks[part], dilation_iterations)
            combined_mask = Image.composite(white, combined_mask, dilated_mask)
        return combined_mask

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Visualize segmentation with colored overlays for selected parts and gray for unselected.

        Args:
            image: fallback base image, used only when ``self.original_image``
                is not set.
            masks: mapping of part name to mask image; the overlay is built on
                a 512x512 canvas, so masks are expected at that size.
            selected_parts: part names drawn in their CLOTHING_COLORS colour
                (magenta when the label has no entry); all others are drawn
                in faint gray.

        Returns:
            An RGBA image the size of the base image with the overlay composited on top.
        """
        # Always use original image if available
        image_to_use = self.original_image if self.original_image is not None else image
        original_size = image_to_use.size
        vis_image = image_to_use.copy().convert('RGBA')

        # Build the overlay at 512x512; later masks overwrite earlier ones
        canvas = np.zeros((512, 512, 4), dtype=np.uint8)
        for part_name, mask in masks.items():
            if selected_parts and part_name in selected_parts:
                # Default to magenta if no color found
                color = CLOTHING_COLORS.get(self._color_key(part_name), (255, 0, 255, 128))
            else:
                color = (180, 180, 180, 80)  # Faint gray for unselected
            self._paint_mask(canvas, np.array(mask), color)

        # Resize overlay to match original image size, then composite
        overlay = Image.fromarray(canvas).resize(original_size, Image.Resampling.LANCZOS)
        return Image.alpha_composite(vis_image, overlay)

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> list:
        """Repaint the selected clothing parts of ``init_image``.

        Args:
            prompt: mapping whose ``'pos'`` entry is the text prompt.
            init_image: PIL image to edit.
            selected_parts: part names to repaint; when empty or None, every
                detected label listed in ``DEFAULT_PARTS`` is repainted.
            dilation_iterations: number of 5x5 dilation passes applied to each
                mask before merging.

        Returns:
            A list with one image per generated sample (three samples), each
            cropped back to the aspect ratio of ``init_image``.

        Raises:
            ValueError: if no parser was supplied at construction time.
        """
        if self.parser is None:
            raise ValueError('Image Parser is Missing')

        image = self.make_square(init_image).resize((512, 512))
        masks = self.parser.get_all_masks(image)
        masks = {k: v.resize((512, 512)) for k, v in masks.items()}
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Create combined mask for selected parts
        if selected_parts:
            part_names = selected_parts
        else:
            # If no parts selected, use all clothing parts (in mask order)
            part_names = [part for part in masks if part in self.DEFAULT_PARTS]
        combined_mask = self._combine_masks(masks, part_names, dilation_iterations)

        # Run the model
        guidance_scale = 7.5
        num_samples = 3
        # Mixed precision is only requested when the instance runs on CUDA
        with autocast("cuda", enabled=self.device == "cuda"), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=combined_mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        images_output = []
        for img in images:
            # Keep original pixels outside the mask, generated pixels inside
            ch = PIL.Image.composite(img, image, combined_mask)
            images_output.append(self.unmake_square(init_image, ch))
        return images_output
def process_segmentation(image, dilation_iterations=2):
    """Segment a freshly uploaded image and build the part-selection preview.

    Stores a copy of ``image`` and the full mask dictionary on the global
    ``inpainter``. ``dilation_iterations`` is accepted but not used here.

    Returns:
        A pair: the overlay preview image, and a ``gr.update`` that sets the
        checkbox choices to the detected main clothing parts with nothing
        selected.

    Raises:
        gr.Error: when no image is given, nothing is detected, or any other
            failure occurs while processing.
    """
    if image is None:
        raise gr.Error("Please upload an image")
    try:
        # Store original image
        inpainter.original_image = image.copy()

        # Segmentation runs on a 512x512 processing copy
        proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)
        all_masks = inpainter.parser.get_all_masks(proc_image)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        inpainter.last_mask = all_masks

        # Only show main clothing parts for selection
        main_parts = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
        masks = {}
        for name, mask in all_masks.items():
            if name in main_parts:
                masks[name] = mask

        vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        return vis_image, gr.update(choices=list(masks), value=[])
    except gr.Error:
        # User-facing errors pass through untouched
        raise
    except Exception as e:
        logger.error(f"Error processing segmentation: {str(e)}")
        raise gr.Error("Error processing the image. Please try a different image.")
def update_dilation(image, selected_parts, dilation_iterations):
    """Redraw the segmentation preview after the dilation slider changes.

    Re-dilates the stored masks of the main clothing parts and renders them
    over the stored original image.

    Args:
        image: current value of the image component; returned unchanged when
            there is nothing to redraw or an error occurs.
        selected_parts: currently ticked part names (may be None or empty).
        dilation_iterations: number of 5x5 dilation passes.

    Returns:
        The refreshed preview image, or ``image`` as a best-effort fallback.
    """
    try:
        if image is None or inpainter.last_mask is None:
            return image

        # Redilate all stored masks
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        kernel = np.ones((5, 5), np.uint8)
        masks = {}
        for part in main_parts:
            if part in inpainter.last_mask:
                mask_array = np.array(inpainter.last_mask[part])
                dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
                masks[part] = Image.fromarray(dilated_mask)

        # Normalise names the same way update_selected_parts does, so both
        # preview handlers highlight the same parts
        selected_parts = [p.lower() for p in selected_parts] if selected_parts else []

        # Use original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts
        )
    except Exception as e:
        # Best-effort UI refresh: keep the current image, but log the traceback
        logger.exception("Error updating dilation: %s", e)
        return image
def process_image(prompt, image, selected_parts, dilation_iterations):
    """Validate a generation request and run inpainting on the selected parts.

    Args:
        prompt: text describing the clothing to generate.
        image: current value of the image component.
        selected_parts: ticked part names; lower-cased before use.
        dilation_iterations: number of dilation passes applied to each mask.

    Returns:
        The list of images returned by ``inpainter.inpaint``.

    Raises:
        gr.Error: on missing input (message passed through unchanged) or on
            any failure during inpainting (prefixed with
            "Error processing image: ").
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image is not None else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")

        # The upload handler writes the overlay preview back into the image
        # component, so `image` may carry the coloured/gray tint. Inpaint the
        # stored original instead. The size check guards against a stale
        # original left over from an earlier upload.
        # NOTE(review): assumes the preview always has the original's size.
        source_image = image
        stored = inpainter.original_image
        if stored is not None and stored.size == image.size:
            source_image = stored

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")

        # Convert selected_parts to lowercase/dash format
        selected_parts = [p.lower() for p in selected_parts]

        # Generate inpainted images
        images = inpainter.inpaint(prompt_dict, source_image, selected_parts, dilation_iterations)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")
        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except gr.Error:
        # Validation messages are already user-facing; do not re-wrap them
        raise
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise gr.Error(f"Error processing image: {str(e)}") from e
def update_selected_parts(image, selected_parts, dilation_iterations):
    """Redraw the segmentation preview after the part selection changes.

    Dilates the stored masks of the main clothing parts, lower-cases the
    selected names and renders the overlay on the stored original image.
    On any error, or when there is nothing to redraw, ``image`` is returned
    unchanged.
    """
    try:
        stored = inpainter.last_mask
        if image is None or stored is None:
            return image

        main_parts = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
        kernel = np.ones((5, 5), np.uint8)
        masks = {
            part: Image.fromarray(
                cv2.dilate(np.array(stored[part]), kernel, iterations=dilation_iterations)
            )
            for part in main_parts
            if part in stored
        }

        # Lowercase the selected_parts for comparison
        if selected_parts:
            selected_parts = [name.lower() for name in selected_parts]
        else:
            selected_parts = []

        # Use original image for visualization
        return inpainter.visualize_segmentation(
            inpainter.original_image, masks, selected_parts=selected_parts
        )
    except Exception as e:
        logger.error(f"Error updating selected parts: {str(e)}")
        return image
# Load the models before the UI is built
init()

# Gradio interface: inputs in the left column, generated results on the right
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                scale=1,  # This ensures the image maintains its aspect ratio
                height=None,  # Allow dynamic height based on content
            )
            dilation_slider = gr.Slider(
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification",
                minimum=0,
                maximum=5,
                step=1,
                value=2,
            )
            selected_parts = gr.CheckboxGroup(
                label="Select parts to modify",
                choices=[],
                value=[],
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket",
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None,  # Allow dynamic height
                object_fit="contain",  # Maintain aspect ratio
            )

    # Both preview-refresh handlers take the same three inputs
    preview_inputs = [input_image, selected_parts, dilation_slider]

    # Upload: segment the image, show the overlay, fill the part choices
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts],
    )

    # Slider moved: redraw the overlay with the new dilation
    dilation_slider.change(
        fn=update_dilation,
        inputs=preview_inputs,
        outputs=input_image,
    )

    # Generate: run inpainting and show the results in the gallery
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery,
    )

    # Selection changed: redraw the overlay with the new highlights
    selected_parts.change(
        fn=update_selected_parts,
        inputs=preview_inputs,
        outputs=input_image,
    )

if __name__ == "__main__":
    demo.launch(share=True)