Spaces:
Running
Running
| import colorsys | |
| import os | |
| import gradio as gr | |
| import matplotlib.colors as mcolors | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
# Resolve the GPU decorator for local vs HuggingFace execution: use
# `spaces.GPU` when the `spaces` package can be imported, otherwise fall
# back to a decorator that leaves the function untouched.
try:
    import spaces
except ImportError:

    def GPU_DECORATOR(func):
        """Return *func* unchanged (fallback when `spaces` is not installed)."""
        return func

else:
    GPU_DECORATOR = spaces.GPU
# ----------------- CONFIG ----------------- #
# Turn on TF32 for CUDA matmul and cuDNN only when a CUDA device is present
# and device 0 reports a major version of 8 or higher.
_tf32_capable = torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8
if _tf32_capable:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Directory (next to this file) holding the legend image and example inputs.
ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
MODEL_ID = "fashn-ai/fashn-human-parser"

# Class names ordered by class id: the position in this tuple IS the id.
_LABEL_NAMES = (
    "Background",
    "Face",
    "Hair",
    "Top",
    "Dress",
    "Skirt",
    "Pants",
    "Belt",
    "Bag",
    "Hat",
    "Scarf",
    "Glasses",
    "Arms",
    "Hands",
    "Legs",
    "Feet",
    "Torso",
    "Jewelry",
)

# Forward (name -> id) and reverse (id -> name) lookup tables.
LABELS_TO_IDS = {name: class_id for class_id, name in enumerate(_LABEL_NAMES)}
IDS_TO_LABELS = dict(enumerate(_LABEL_NAMES))
| # ----------------- HELPERS ----------------- # | |
def constrain_image_size(img: Image.Image, max_width: int = 768, max_height: int = 1152) -> Image.Image:
    """
    Constrain an image to maximum dimensions while maintaining aspect ratio.

    Args:
        img: Image to constrain.
        max_width: Largest allowed width in pixels.
        max_height: Largest allowed height in pixels.

    Returns:
        ``img`` itself when it already fits within both limits; otherwise a
        new image resized with Lanczos resampling. The caller is responsible
        for closing the returned image if it differs from the input.
    """
    width, height = img.size

    # Nothing to do when the image already fits.
    if width <= max_width and height <= max_height:
        return img

    # Use the tighter of the two constraints so both limits are respected.
    scale = min(max_width / width, max_height / height)

    # int() truncates, so for an extreme aspect ratio the short side could
    # come out as 0 and request a zero-sized image; keep it at least 1 px.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))

    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
def get_palette(num_cls: int) -> list[int]:
    """
    Build a flat ``[r, g, b, r, g, b, ...]`` palette with 256 RGB slots.

    Slot 0 is black. Slots ``1 .. num_cls - 1`` get fully saturated colors
    whose hue is spread evenly over the classes; even ids use full
    brightness and odd ids half brightness. Remaining slots stay zero.
    """
    flat = [0] * (256 * 3)
    flat[0:3] = [0, 0, 0]

    for class_id in range(1, num_cls):
        # Alternate brightness so classes with nearby hues stay distinguishable.
        brightness = 0.5 if class_id % 2 else 1.0
        hue = (class_id - 1) / (num_cls - 1)
        channels = colorsys.hsv_to_rgb(hue, 1.0, brightness)
        start = class_id * 3
        flat[start : start + 3] = [int(channel * 255) for channel in channels]

    return flat
def create_colormap(palette: list[int]) -> mcolors.ListedColormap:
    """Turn a flat ``[r, g, b, ...]`` palette of 0-255 values into a ListedColormap."""
    # One row per color, then scale the channels by 1/255.
    rgb_rows = np.array(palette).reshape(-1, 3)
    normalized_rows = rgb_rows / 255.0
    return mcolors.ListedColormap(normalized_rows)
def visualize_mask_with_overlay(img: Image.Image, mask: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """
    Blend a per-class color overlay onto an image.

    Args:
        img: Source image; it is converted to RGB for blending.
        mask: Array of class ids with the same height/width as ``img``
            (assumed to hold integer ids, as produced by ``argmax``).
        alpha: Weight of the overlay; the image is weighted ``1 - alpha``.

    Returns:
        A new image equal to ``img * (1 - alpha) + overlay * alpha``. Pixels
        whose id is 0 (background) or not a known class get a black overlay.
    """
    # Copy the pixels out of a temporary RGB conversion, then release it.
    rgb_img = img.convert("RGB")
    try:
        img_np = np.array(rgb_img)
    finally:
        # Close converted image if it's different from original
        if rgb_img is not img:
            rgb_img.close()

    num_cls = len(LABELS_TO_IDS)

    # Index the integer palette directly as a lookup table. This colors all
    # classes in one vectorized step and uses exactly the RGB values drawn in
    # the legend, with no float round-trip that could truncate a channel by 1.
    lut = np.array(get_palette(num_cls), dtype=np.uint8).reshape(-1, 3)

    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    # Only known, non-background ids are colored; anything else stays black.
    labeled = (mask > 0) & (mask < num_cls)
    overlay[labeled] = lut[mask[labeled].astype(np.intp)]

    blended = img_np * (1 - alpha) + overlay * alpha
    return Image.fromarray(np.uint8(blended))
def create_legend_image() -> Image.Image:
    """
    Render a two-column legend mapping each class id to its palette color.

    Each entry is a color swatch followed by ``"<id>: <label>"``. The caller
    owns the returned image and is responsible for closing it.
    """
    class_count = len(LABELS_TO_IDS)
    palette = get_palette(class_count)

    # Layout metrics (all scaled by one factor)
    scale = 1
    rows_per_col = (class_count + 1) // 2  # 2 columns layout
    col_width = 200 * scale
    row_height = 35 * scale
    box_size = 28 * scale
    margin = 10 * scale

    canvas_size = (col_width * 2, rows_per_col * row_height + 20 * scale)
    legend = Image.new("RGB", canvas_size, "white")
    draw = ImageDraw.Draw(legend)

    # Cross-platform font loading: take the first font file that opens,
    # otherwise fall back to the built-in default font.
    candidate_fonts = (
        "/System/Library/Fonts/Helvetica.ttc",  # macOS
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",  # Linux
    )
    for candidate in candidate_fonts:
        try:
            font = ImageFont.truetype(candidate, 20 * scale)
        except OSError:
            continue
        break
    else:
        font = ImageFont.load_default()

    for idx, label in IDS_TO_LABELS.items():
        # Fill the first column top-to-bottom, then the second.
        col, row = divmod(idx, rows_per_col)
        x = col * col_width + margin
        y = row * row_height + margin
        swatch_color = tuple(palette[idx * 3 : idx * 3 + 3])
        draw.rectangle(
            [x, y, x + box_size, y + box_size],
            fill=swatch_color,
            outline="black",
            width=2,
        )
        draw.text(
            (x + box_size + 10 * scale, y + 5 * scale),
            f"{idx}: {label}",
            fill="black",
            font=font,
        )

    return legend
| # ----------------- MODEL ----------------- # | |
print("Loading model...")

# Prefer CUDA when torch reports it as available, otherwise run on CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Image preprocessor and segmentation network, both loaded from MODEL_ID.
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)

# Switch to eval mode and place the weights on the selected device.
model.eval()
model.to(device)

print(f"Model loaded on {device}!")
# Apply the GPU decorator resolved at the top of the file: `spaces.GPU` when
# the `spaces` package is installed, a pass-through otherwise. It was defined
# but never applied to the inference entry point.
@GPU_DECORATOR
def segment(image: Image.Image) -> tuple[Image.Image, Image.Image]:
    """
    Run the segmentation model on an image.

    Args:
        image: Input image, or ``None`` when nothing was uploaded.

    Returns:
        ``(blended_image, mask_image)``: the input (size-constrained) with a
        50% class-color overlay, and a uint8 image whose pixel values are the
        predicted class ids.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image")

    # Constrain output size (max 768w or 1152h, whichever hits first)
    constrained_image = constrain_image_size(image, max_width=768, max_height=1152)
    image_was_resized = constrained_image is not image

    try:
        inputs = processor(images=constrained_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

            # Bring the logits back to the (constrained) image resolution
            # before taking the per-pixel argmax.
            upsampled = torch.nn.functional.interpolate(
                logits,
                size=(constrained_image.height, constrained_image.width),
                mode="bilinear",
                align_corners=False,
            )
            mask = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        mask_image = Image.fromarray(mask.astype("uint8"))
        blended_image = visualize_mask_with_overlay(constrained_image, mask, alpha=0.5)
        return blended_image, mask_image
    finally:
        # Clean up resized image if one was created; the caller's own image
        # is never closed here.
        if image_was_resized:
            constrained_image.close()
| # ----------------- UI ----------------- # | |
# Pre-generate the legend image on disk so the UI can reference it by path.
legend_path = os.path.join(ASSETS_DIR, "legend.png")

# Saving fails if the assets directory does not exist yet, so create it first.
os.makedirs(ASSETS_DIR, exist_ok=True)

legend_img = create_legend_image()
try:
    legend_img.save(legend_path)
finally:
    # Release the in-memory image whether or not the save succeeded.
    legend_img.close()
# Load examples: sorted paths of the image files in assets/examples, or an
# empty list when that directory does not exist.
examples_dir = os.path.join(ASSETS_DIR, "examples")
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

if os.path.exists(examples_dir):
    example_images = sorted(
        os.path.join(examples_dir, file_name)
        for file_name in os.listdir(examples_dir)
        if file_name.lower().endswith(_EXAMPLE_EXTENSIONS)
    )
else:
    example_images = []
# Custom CSS
CUSTOM_CSS = """
.contain img {
object-fit: contain !important;
}
"""

# Load HTML content that sits next to this file. The encoding is given
# explicitly so the result does not depend on the platform's default locale.
_APP_DIR = os.path.dirname(__file__)

with open(os.path.join(_APP_DIR, "banner.html"), "r", encoding="utf-8") as f:
    banner_html = f.read()

with open(os.path.join(_APP_DIR, "tips.html"), "r", encoding="utf-8") as f:
    tips_html = f.read()
# Build UI
with gr.Blocks() as demo:
    # Header: banner first, then the usage tips
    for header_html in (banner_html, tips_html):
        gr.HTML(header_html)

    with gr.Row(equal_height=False):
        # Left column: image input, run button and (when available) examples
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            run_button = gr.Button("Run", variant="primary", size="lg")

            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=input_image,
                    examples_per_page=8,
                    label="Examples",
                )

        # Right column: the two read-only result panels plus the legend
        with gr.Column(scale=1):
            result_image, mask_image = [
                gr.Image(
                    label=panel_label,
                    type="pil",
                    interactive=False,
                    elem_classes=["contain"],
                )
                for panel_label in ("Segmentation Overlay", "Segmentation Mask")
            ]

            # Legend in accordion, served from the pre-generated file
            with gr.Accordion("Label Legend", open=True):
                gr.Image(
                    value=legend_path,
                    label=None,
                    show_label=False,
                    interactive=False,
                )

    # Event handler: clicking Run segments the input and fills both panels
    run_button.click(
        fn=segment,
        inputs=[input_image],
        outputs=[result_image, mask_image],
    )
if __name__ == "__main__":
    # Launch without a public share link; the custom CSS is passed at launch.
    launch_options = {
        "share": False,
        "css": CUSTOM_CSS,
        "css_paths": None,
    }
    demo.launch(**launch_options)