import functools
import io
import os

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
@functools.lru_cache(maxsize=2)
def _load_segmentation_model(model_name: str):
    """Load and cache the (processor, model) pair for *model_name*.

    Loading weights is the expensive part of a call, so repeated calls with
    the same ``model_name`` reuse the same objects instead of re-loading them.
    """
    processor = SegformerImageProcessor.from_pretrained(model_name)
    model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
    model.eval()  # inference only; make the mode explicit
    return processor, model


def _load_rgb_image(image_path: str, timeout: float) -> Image.Image:
    """Open a local path or http(s) URL and return an in-memory RGB image.

    The underlying file/response is closed before returning.
    """
    if image_path.startswith(("http://", "https://")):
        with requests.get(image_path, timeout=timeout) as response:
            response.raise_for_status()
            # Use the decoded body rather than ``response.raw`` so that
            # transfer/content encodings (e.g. gzip) are handled by requests.
            source = io.BytesIO(response.content)
    else:
        source = image_path
    with Image.open(source) as img:
        # convert() returns a new, fully loaded image (a copy when the mode is
        # already RGB), so the result remains valid after the file is closed.
        return img.convert("RGB")


def generate_clothing_mask(
    image_path: str,
    label: int,
    output_path: str = "./test/output_mask.png",
    model_name: str = "mattmdjaga/segformer_b2_clothes",
    *,
    download_timeout: float = 30.0,
) -> Image.Image:
    """Generate a binary mask for one clothing class and save it to disk.

    Args:
        image_path: Local path to the image, or an ``http://``/``https://`` URL.
        label: Class index to segment (0-17 for the default model). Must be in
            ``[0, model.config.num_labels)``.
        output_path: Where to save the mask. Missing parent directories are
            created; a bare filename is saved in the current directory.
        model_name: HuggingFace model name or local model directory.
        download_timeout: Timeout in seconds for fetching ``image_path`` when it
            is a URL.

    Returns:
        PIL.Image: Single-channel mask; 255 (white) where the predicted class
        equals ``label``, 0 (black) elsewhere. Same width/height as the input.

    Raises:
        ValueError: If ``label`` is outside the model's label range.
        requests.HTTPError: If downloading a URL returns an error status.
    """
    processor, model = _load_segmentation_model(model_name)

    num_labels = model.config.num_labels
    if not 0 <= label < num_labels:
        raise ValueError(f"label must be in [0, {num_labels - 1}], got {label}")

    image = _load_rgb_image(image_path, download_timeout)

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Resize logits to the original image resolution. PIL's size is
    # (width, height) while interpolate expects (height, width).
    upsampled_logits = F.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    mask = (pred_seg == label).cpu().numpy().astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)

    out_dir = os.path.dirname(output_path)
    if out_dir:  # dirname is "" for a bare filename and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    mask_image.save(output_path)
    return mask_image