"""Gradio demo that runs the ``fashn-ai/fashn-human-parser`` SegFormer model.

Importing this module has side effects: it renders the label legend to
``assets/legend.png``, reads ``banner.html`` and ``tips.html`` from the module
directory, and builds the Gradio ``demo`` app. The model itself is loaded
lazily on the first call to :func:`segment`.
"""

import colorsys
import os

# ZeroGPU: must import before any CUDA-related packages
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except ImportError:

    def GPU_DECORATOR(func):
        """No-op stand-in used when the ``spaces`` package is not installed."""
        return func


import gradio as gr
import matplotlib.colors as mcolors
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# ----------------- CONFIG ----------------- #

ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
MODEL_ID = "fashn-ai/fashn-human-parser"

LABELS_TO_IDS = {
    "Background": 0,
    "Face": 1,
    "Hair": 2,
    "Top": 3,
    "Dress": 4,
    "Skirt": 5,
    "Pants": 6,
    "Belt": 7,
    "Bag": 8,
    "Hat": 9,
    "Scarf": 10,
    "Glasses": 11,
    "Arms": 12,
    "Hands": 13,
    "Legs": 14,
    "Feet": 15,
    "Torso": 16,
    "Jewelry": 17,
}
IDS_TO_LABELS = {v: k for k, v in LABELS_TO_IDS.items()}


# ----------------- HELPERS ----------------- #


def constrain_image_size(img: Image.Image, max_width: int = 768, max_height: int = 1152) -> Image.Image:
    """Fit *img* inside ``max_width`` x ``max_height``, keeping its aspect ratio.

    Args:
        img: Image to constrain.
        max_width: Largest allowed width in pixels.
        max_height: Largest allowed height in pixels.

    Returns:
        ``img`` itself when it already fits; otherwise a new, Lanczos-resampled
        image. The caller is responsible for closing the returned image if it
        differs from the input.
    """
    width, height = img.size

    # Nothing to do when both constraints are already satisfied.
    if width <= max_width and height <= max_height:
        return img

    # The tighter of the two constraints decides the scale factor.
    scale = min(max_width / width, max_height / height)

    # int() truncates, so an extremely elongated image could otherwise end up
    # with a zero-sized dimension; clamp to at least one pixel.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))

    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)


def get_palette(num_cls: int) -> list[int]:
    """Build a flat 256-entry RGB palette (``[r0, g0, b0, r1, ...]``).

    Entry 0 is black. Entries ``1 .. num_cls - 1`` get fully saturated hues
    spread over the hue range, with even ids at full brightness and odd ids at
    half brightness. All remaining entries stay black.
    """
    palette = [0] * (256 * 3)
    palette[0:3] = [0, 0, 0]
    for j in range(1, num_cls):
        hue = (j - 1) / (num_cls - 1)
        saturation = 1.0
        value = 1.0 if j % 2 == 0 else 0.5
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        r, g, b = [int(x * 255) for x in rgb]
        palette[j * 3 : j * 3 + 3] = [r, g, b]
    return palette


def create_colormap(palette: list[int]) -> mcolors.ListedColormap:
    """Convert a flat 0-255 RGB palette into a matplotlib ``ListedColormap``."""
    colormap = np.array(palette).reshape(-1, 3) / 255.0
    return mcolors.ListedColormap(colormap)


def visualize_mask_with_overlay(img: Image.Image, mask: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """Blend a colour-coded class mask over *img*.

    Args:
        img: Source image; it is converted to RGB internally.
        mask: 2-D array of class ids with the same height/width as *img*.
        alpha: Weight of the colour overlay; the image gets ``1 - alpha``.

    Returns:
        A new RGB image ``img * (1 - alpha) + overlay * alpha``. Background and
        unknown ids use a black overlay colour.
    """
    # Convert to RGB if needed (creates temporary image)
    rgb_img = img.convert("RGB")
    try:
        img_np = np.array(rgb_img)
    finally:
        # Close converted image if it's different from original
        if rgb_img is not img:
            rgb_img.close()

    num_cls = len(LABELS_TO_IDS)

    # Look colours up straight from the integer palette (the same one the
    # legend uses) instead of round-tripping through a float colormap, so the
    # overlay colours cannot drift from the legend by float truncation.
    lut = np.asarray(get_palette(num_cls), dtype=np.uint8).reshape(-1, 3)

    # Ids that are not a known label are treated as background (black).
    class_ids = np.where(np.isin(mask, list(IDS_TO_LABELS)), mask, 0)
    overlay = lut[class_ids]

    blended = Image.fromarray(np.uint8(img_np * (1 - alpha) + overlay * alpha))
    return blended


def create_legend_image() -> Image.Image:
    """Render a two-column legend of ``"<id>: <label>"`` rows with colour swatches.

    Returns:
        A new RGB image; the caller is responsible for closing it.
    """
    num_cls = len(LABELS_TO_IDS)
    palette = get_palette(num_cls)

    # 2 columns layout
    scale = 1
    rows_per_col = (num_cls + 1) // 2
    col_width = 200 * scale
    row_height = 35 * scale
    legend_width = col_width * 2
    legend_height = rows_per_col * row_height + 20 * scale

    legend = Image.new("RGB", (legend_width, legend_height), "white")
    draw = ImageDraw.Draw(legend)

    # Cross-platform font loading: first font file that loads wins.
    font = None
    font_paths = [
        "/System/Library/Fonts/Helvetica.ttc",  # macOS
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",  # Linux
    ]
    for font_path in font_paths:
        try:
            font = ImageFont.truetype(font_path, 20 * scale)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    box_size = 28 * scale
    for idx, label in IDS_TO_LABELS.items():
        # Fill the first column top to bottom, then the second.
        col = idx // rows_per_col
        row = idx % rows_per_col
        x = col * col_width + 10 * scale
        y = row * row_height + 10 * scale
        color = tuple(palette[idx * 3 : idx * 3 + 3])
        draw.rectangle([x, y, x + box_size, y + box_size], fill=color, outline="black", width=2)
        draw.text((x + box_size + 10 * scale, y + 5 * scale), f"{idx}: {label}", fill="black", font=font)

    return legend


# ----------------- MODEL ----------------- #

# Global state (lazy loaded for ZeroGPU compatibility)
_model = None
_processor = None
_device = None


def get_model():
    """Lazy-load model on first use (ensures GPU available on ZeroGPU).

    Returns:
        The cached ``(model, processor, device)`` triple.
    """
    global _model, _processor, _device
    if _model is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Enable TF32 for Ampere+ GPUs
        if _device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        print(f"Loading model on {_device}...")
        _processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
        _model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
        _model.eval()
        _model.to(_device)
        print(f"Model loaded on {_device}!")
    return _model, _processor, _device


@GPU_DECORATOR
def segment(image: Image.Image) -> tuple[Image.Image, Image.Image]:
    """Segment *image* and return ``(overlay, mask)`` images.

    Args:
        image: Input image, or ``None`` when nothing was uploaded.

    Returns:
        The colour overlay blended onto the (possibly downscaled) input, and an
        image whose pixel values are the predicted class ids as ``uint8``.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image")

    # Lazy-load model (ensures GPU available on ZeroGPU)
    model, processor, device = get_model()

    # Constrain output size (max 768w or 1152h, whichever hits first)
    constrained_image = constrain_image_size(image, max_width=768, max_height=1152)
    image_was_resized = constrained_image is not image

    try:
        inputs = processor(images=constrained_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Bring the logits back to the constrained image's resolution before
        # taking the per-pixel argmax.
        upsampled = torch.nn.functional.interpolate(
            logits,
            size=(constrained_image.height, constrained_image.width),
            mode="bilinear",
            align_corners=False,
        )
        mask = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        mask_image = Image.fromarray(mask.astype("uint8"))
        blended_image = visualize_mask_with_overlay(constrained_image, mask, alpha=0.5)
        return blended_image, mask_image
    finally:
        # Clean up resized image if one was created
        if image_was_resized:
            constrained_image.close()


# ----------------- UI ----------------- #

# Pre-generate legend with proper cleanup. Create the assets directory first
# so saving does not fail on a fresh checkout without it.
os.makedirs(ASSETS_DIR, exist_ok=True)
legend_path = os.path.join(ASSETS_DIR, "legend.png")
legend_img = create_legend_image()
try:
    legend_img.save(legend_path)
finally:
    legend_img.close()

# Load examples
examples_dir = os.path.join(ASSETS_DIR, "examples")
if os.path.isdir(examples_dir):
    example_images = sorted(
        os.path.join(examples_dir, img)
        for img in os.listdir(examples_dir)
        if img.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
    )
else:
    example_images = []

# Custom CSS
CUSTOM_CSS = """
.contain img {
    object-fit: contain !important;
}
"""

# Load HTML content (explicit encoding so the result is not locale-dependent)
with open(os.path.join(os.path.dirname(__file__), "banner.html"), "r", encoding="utf-8") as f:
    banner_html = f.read()
with open(os.path.join(os.path.dirname(__file__), "tips.html"), "r", encoding="utf-8") as f:
    tips_html = f.read()

# Build UI
with gr.Blocks() as demo:
    # Header
    gr.HTML(banner_html)
    gr.HTML(tips_html)

    with gr.Row(equal_height=False):
        # Left column: Input
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
                height=864,
                width=576,
            )
            run_button = gr.Button("Run", variant="primary", size="lg")

            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=input_image,
                    examples_per_page=8,
                    label="Examples",
                )

            # Legend below examples
            with gr.Accordion("Label Legend", open=True):
                gr.Image(
                    value=legend_path,
                    label=None,
                    show_label=False,
                    interactive=False,
                )

        # Right column: Results
        with gr.Column(scale=1):
            result_image = gr.Image(
                label="Segmentation Overlay",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
                height=864,
                width=576,
            )
            mask_image = gr.Image(
                label="Segmentation Mask",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
                height=864,
                width=576,
            )

    # Event handler
    run_button.click(
        fn=segment,
        inputs=[input_image],
        outputs=[result_image, mask_image],
    )

# Configure queue for ZeroGPU
demo.queue(default_concurrency_limit=1, max_size=30)

if __name__ == "__main__":
    # NOTE(review): `css` / `css_paths` are passed to launch() here rather than
    # to gr.Blocks(); confirm the pinned Gradio version accepts them there.
    demo.launch(
        share=False,
        css=CUSTOM_CSS,
        css_paths=None,
    )