Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation | |
| import torch.nn.functional as F | |
| import logging | |
| import time | |
| from typing import Tuple, Optional | |
# Module-level logger used by SegformerParser; the logger name is a fixed
# string rather than being derived from __name__.
logger = logging.getLogger('looks.studio.segformer')
class SegformerParser:
    """Segment clothing regions in PIL images with a pretrained SegFormer model.

    The image processor and the segmentation model are loaded from
    ``model_path`` at construction time, and the model is moved to CUDA when
    ``torch.cuda.is_available()`` is true, otherwise it stays on the CPU.

    Attributes:
        start_time: ``time.time()`` captured at the start of ``__init__``.
        processor: Image processor loaded from ``model_path``.
        model: Semantic-segmentation model loaded from ``model_path``.
        device: ``"cuda"`` or ``"cpu"``.
        clothing_labels: Mapping of model label id -> part name for every
            label that is treated as "clothing" when building masks.
    """

    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        """Load the processor and model and select the compute device.

        Args:
            model_path: Identifier or path passed to ``from_pretrained``.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")
        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)
            # Label ids that count as clothing. Note that the arm labels
            # (14, 15) are included alongside the garment labels.
            # NOTE(review): presumably so sleeves are covered - confirm.
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf",
            }
            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to initialize SegformerParser: {str(e)}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Downscale ``image`` so that neither side exceeds ``max_size``.

        The aspect ratio is preserved. Images already within the limit are
        returned unchanged.

        Args:
            image: Image to resize.
            max_size: Maximum allowed width/height in pixels.

        Returns:
            ``(image, scale)`` where ``scale`` is the factor that was applied
            (``1.0`` when no resize happened).
        """
        width, height = image.size
        scale = 1.0
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            # Clamp to 1 so an extreme aspect ratio can never yield a
            # zero-sized dimension, which resize() would reject.
            new_width = max(1, int(width * scale))
            new_height = max(1, int(height * scale))
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Return True when ``image`` is acceptable input, logging why if not.

        An image is accepted when it is a PIL image in mode ``RGB`` or
        ``RGBA`` whose width and height are both within ``[64, 4096]``.
        """
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False
        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False
        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False
        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False
        return True

    def _prepare_image(self, image: Image.Image) -> Optional[Tuple[Image.Image, Tuple[int, int]]]:
        """Validate, RGB-convert and downscale ``image`` for inference.

        Returns:
            ``(prepared_image, original_size)`` or ``None`` when validation
            fails. ``original_size`` is the ``(width, height)`` of the input
            *before* any downscaling, so masks can be restored to exactly
            that size.
        """
        if not self._validate_image(image):
            return None
        # Remember the true input size now; recomputing it later from the
        # resized image and the scale factor loses pixels to truncation.
        original_size = image.size
        if image.mode == 'RGBA':
            logger.info("Converting RGBA to RGB")
            image = image.convert('RGB')
        image, _ = self._resize_image(image)
        return image, original_size

    def _predict_labels(self, image: Image.Image) -> np.ndarray:
        """Run the model on ``image`` and return a per-pixel label-id array.

        The logits are bilinearly upsampled to the size of ``image`` before
        the argmax, so the returned 2-D array has shape ``(height, width)``
        of ``image``.
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits.cpu()
        upsampled_logits = F.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (width, height); torch wants (height, width)
            mode="bilinear",
            align_corners=False,
        )
        return upsampled_logits.argmax(dim=1)[0].numpy()

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate a single binary mask covering every clothing label.

        Args:
            image: RGB or RGBA PIL image.

        Returns:
            A mask image with the same size as ``image`` where pixels
            predicted as any label in ``self.clothing_labels`` are 255 and
            all others are 0, or ``None`` if the input is invalid or an
            error occurs during segmentation.
        """
        start_time = time.time()
        # getattr keeps this log line from raising when the input is not an image.
        logger.info(f"Starting segmentation for image size: {getattr(image, 'size', None)}")
        try:
            prepared = self._prepare_image(image)
            if prepared is None:
                return None
            image, original_size = prepared

            logger.info("Processing image with Segformer")
            pred_seg = self._predict_labels(image)

            # Mark every pixel whose predicted label is a clothing label.
            is_clothing = np.isin(pred_seg, list(self.clothing_labels))
            mask = Image.fromarray(is_clothing.astype(np.uint8) * 255)

            # Restore the exact input size if the image was downscaled.
            if mask.size != original_size:
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)

            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            # Best-effort API: report the failure (with traceback) and return None.
            logger.error(f"Error during segmentation: {str(e)}", exc_info=True)
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks, one per clothing part label.

        Args:
            image: RGB or RGBA PIL image.

        Returns:
            Mapping of part name (the values of ``self.clothing_labels``) to
            a mask image with the same size as ``image`` (255 where that
            label was predicted, 0 elsewhere). The dict is empty when the
            input is invalid; if an error occurs part-way through, the masks
            built so far are returned.
        """
        start_time = time.time()
        # getattr keeps this log line from raising when the input is not an image.
        logger.info(f"Starting per-part segmentation for image size: {getattr(image, 'size', None)}")
        masks = {}
        try:
            prepared = self._prepare_image(image)
            if prepared is None:
                return masks
            image, original_size = prepared

            logger.info("Processing image with Segformer for all masks")
            pred_seg = self._predict_labels(image)

            for label_id, part_name in self.clothing_labels.items():
                mask_img = Image.fromarray((pred_seg == label_id).astype(np.uint8) * 255)
                # Restore the exact input size if the image was downscaled.
                if mask_img.size != original_size:
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img

            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            # Best-effort API: report the failure (with traceback) and return what we have.
            logger.error(f"Error during per-part segmentation: {str(e)}", exc_info=True)
            return masks