| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Tuple, Optional |
| from enum import Enum, auto |
|
|
class ClothingType(Enum):
    """Top-level garment categories handled by the manipulation pipeline."""

    SHIRT = 1
    PANTS = 2
    DRESS = 3
    SKIRT = 4
    JACKET = 5
    SHOES = 6
    ACCESSORY = 7
|
|
class GarmentFeatures(Enum):
    """Localized garment details that an analyzer can detect and preserve."""

    COLLAR = 1
    SLEEVES = 2
    BUTTONS = 3
    POCKETS = 4
    PATTERNS = 5
    TEXTURE = 6
|
|
class StyleTransferModel(nn.Module):
    """Conv stacks intended for content/style feature extraction.

    NOTE(review): this module defines no ``forward``, yet callers invoke it
    directly (``self.style_transfer(source_image, target_style)``), which
    raises at runtime.  The intended transfer algorithm is not visible in
    this file, so it is flagged here rather than guessed at -- confirm with
    the module's author.
    """

    DEFAULT_STYLE_LAYERS = 5

    def __init__(self, config: Optional[dict] = None):
        """
        Args:
            config: optional settings dict; key ``"style_layers"`` (int)
                sets how many style conv layers to build (default 5).
        """
        super().__init__()
        # Bug fixes vs. the original:
        #  * ``config`` defaulted to a shared mutable dict literal -- replaced
        #    with the None-sentinel idiom.
        #  * ``config["style_layers"]`` raised KeyError for configs missing
        #    the key (AdvancedClothingManipulator passes an empty dict) --
        #    use .get() with the documented default instead.
        config = {} if config is None else config
        num_style_layers = config.get("style_layers", self.DEFAULT_STYLE_LAYERS)
        self.content_layers = self._build_content_layers()
        self.style_layers = self._build_style_layers(num_style_layers)
        self.pattern_generator = PatternGenerator()
        self.texture_mixer = TextureMixer()

    def _build_content_layers(self) -> nn.ModuleList:
        """Four same-size 3x3 convs at widths 64, 128, 256, 512 channels."""
        return nn.ModuleList([
            nn.Conv2d(64 * (2 ** i), 64 * (2 ** i), 3, padding=1)
            for i in range(4)
        ])

    def _build_style_layers(self, num_layers: int) -> nn.ModuleList:
        """``num_layers`` 3x3 convs; width cycles 64 -> 128 -> 256 (i % 3)."""
        return nn.ModuleList([
            nn.Conv2d(64 * (2 ** (i % 3)), 64 * (2 ** (i % 3)), 3, padding=1)
            for i in range(num_layers)
        ])
|
|
class GarmentAnalyzer(nn.Module):
    """Produces per-feature response maps (collar, sleeves, ...) for a garment image."""

    def __init__(self):
        super().__init__()
        # Bug fix: ``detect_features`` referenced ``self.shared_backbone``,
        # which was never created anywhere, so every call raised
        # AttributeError.  Build a small 3 -> 128 channel trunk whose output
        # width matches the 128-channel input each detector head expects.
        # NOTE(review): assumes 3-channel (RGB) input images -- confirm.
        self.shared_backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
        )
        # One lightweight head per GarmentFeatures member, keyed by the
        # lower-cased member name (e.g. "collar").
        self.feature_detector = nn.ModuleDict({
            feature.name.lower(): self._build_detector()
            for feature in GarmentFeatures
        })

    def _build_detector(self) -> nn.Sequential:
        """Return a 128-in / 1-out conv head producing a single response map."""
        return nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 1)
        )

    def detect_features(self, garment_image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the shared trunk once, then every feature head on its output.

        Args:
            garment_image: image batch; assumed (N, 3, H, W) -- TODO confirm.

        Returns:
            Mapping of feature name -> (N, 1, H, W) raw response map (logits).
        """
        features = self.shared_backbone(garment_image)
        return {
            name: detector(features)
            for name, detector in self.feature_detector.items()
        }
|
|
class AdvancedClothingManipulator(nn.Module):
    """Facade combining garment analysis, style transfer, pattern edits, and blending."""

    def __init__(self, config: Optional[dict] = None):
        """
        Args:
            config: optional settings forwarded to ``StyleTransferModel``
                (e.g. ``{"style_layers": 5}``).
        """
        super().__init__()
        # Bug fix: the original default ``config: dict = {}`` was a shared
        # mutable default argument -- replaced with the None sentinel.
        config = {} if config is None else config
        self.garment_analyzer = GarmentAnalyzer()
        self.style_transfer = StyleTransferModel(config)
        self.pattern_modifier = PatternModificationModule()
        self.detail_preserving_blender = DetailPreservingBlender()

    def transfer_style(self,
                       source_image: torch.Tensor,
                       target_style: Dict[str, torch.Tensor],
                       preserve_features: Optional[List[GarmentFeatures]] = None) -> torch.Tensor:
        """Restyle ``source_image``, optionally keeping selected regions intact.

        Args:
            source_image: garment image batch to restyle.
            target_style: style specification passed through to the
                style-transfer model (e.g. style image + strength).
            preserve_features: features whose detected regions are blended
                back from the source; ``None``/empty means full restyle.

        Returns:
            The styled image tensor.
        """
        styled = self.style_transfer(source_image, target_style)
        if not preserve_features:
            return styled
        # Only analyze the source when preservation is actually requested
        # (the original ran the analyzer unconditionally and discarded it).
        source_features = self.garment_analyzer.detect_features(source_image)
        masks = self._create_preservation_masks(source_features, preserve_features)
        return self.detail_preserving_blender(source_image, styled, masks)

    def _create_preservation_masks(self,
                                   source_features: Dict[str, torch.Tensor],
                                   preserve_features: List[GarmentFeatures]) -> Dict[str, torch.Tensor]:
        """Turn raw detector responses into [0, 1] blending masks.

        Bug fix: this method was called by ``transfer_style`` but never
        defined anywhere, so any ``preserve_features`` request crashed.
        Sigmoid squashes detector logits into soft masks.
        NOTE(review): a threshold or temperature may be needed once real
        detector behavior is known -- confirm against training code.
        """
        return {
            feature.name.lower(): torch.sigmoid(source_features[feature.name.lower()])
            for feature in preserve_features
        }

    def modify_pattern(self,
                       garment_image: torch.Tensor,
                       pattern_type: str,
                       pattern_params: Dict[str, float]) -> torch.Tensor:
        """Delegate a pattern edit (e.g. 'stripes') to the pattern modifier."""
        return self.pattern_modifier(garment_image, pattern_type, pattern_params)
|
|
class PatternModificationModule(nn.Module):
    """Re-renders a garment image through a pattern-specific encoder/generator pair."""

    def __init__(self):
        super().__init__()
        # One encoder per supported pattern family; all share the same shape
        # (3 -> 128 channels) so any of them can feed the generator.
        self.pattern_encoders = nn.ModuleDict({
            'stripes': self._build_pattern_encoder(),
            'dots': self._build_pattern_encoder(),
            'floral': self._build_pattern_encoder(),
            'geometric': self._build_pattern_encoder()
        })

        # Decodes 128-channel features back to a 3-channel image in [-1, 1].
        self.pattern_generator = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.GroupNorm(8, 256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 3, 3, padding=1),
            nn.Tanh()
        )

    def _build_pattern_encoder(self) -> nn.Sequential:
        """Return a 3 -> 128 channel conv encoder."""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2)
        )

    def forward(self,
                garment_image: torch.Tensor,
                pattern_type: str,
                pattern_params: Dict[str, float]) -> torch.Tensor:
        """Encode the image with the chosen pattern encoder and re-render it.

        Bug fix: this module was invoked directly by callers
        (``self.pattern_modifier(image, type, params)``) but defined no
        ``forward``, so every call raised at runtime.

        Args:
            garment_image: image batch; assumed (N, 3, H, W) -- TODO confirm.
            pattern_type: one of 'stripes', 'dots', 'floral', 'geometric'.
            pattern_params: numeric knobs (scale/rotation/density).
                NOTE(review): currently unused by the network itself; kept in
                the signature for caller compatibility -- confirm intent.

        Raises:
            KeyError: if ``pattern_type`` has no registered encoder.
        """
        if pattern_type not in self.pattern_encoders:
            raise KeyError(f"unknown pattern type: {pattern_type!r}")
        encoded = self.pattern_encoders[pattern_type](garment_image)
        return self.pattern_generator(encoded)
|
|
class DetailPreservingBlender(nn.Module):
    """Blends styled pixels back toward source pixels inside preservation masks."""

    def __init__(self):
        super().__init__()
        # NOTE(review): this attention module is parameterized but its output
        # was never used by forward() in the original code; it is kept so that
        # existing checkpoints still load.  Confirm whether attention was
        # meant to modulate the blend.
        self.attention_module = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),
            nn.GroupNorm(8, 32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, source: torch.Tensor, styled: torch.Tensor,
                preservation_masks: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Apply each mask sequentially: mask==1 keeps source, mask==0 keeps styled.

        Bug fix: the original computed an attention map from
        ``self.attention_module`` on every call and then discarded it --
        pure wasted compute.  The unused computation is removed; the output
        values are unchanged.

        Args:
            source: original image batch.
            styled: restyled image batch, same shape as ``source``.
            preservation_masks: feature name -> mask tensor broadcastable
                against the images (e.g. (N, 1, H, W) in [0, 1]).
        """
        result = styled
        # Only the mask tensors matter for blending; keys are just labels.
        for mask in preservation_masks.values():
            result = result * (1 - mask) + source * mask
        return result
|
|
class TextureMixer(nn.Module):
    """Blends two texture images in feature space and decodes back to RGB."""

    def __init__(self):
        super().__init__()
        # Shared 3 -> 128 channel texture encoder.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2)
        )

        # Takes the two concatenated 128-channel maps (256 in) and decodes
        # a 3-channel image in [-1, 1].
        self.mixer = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 3, 3, padding=1),
            nn.Tanh()
        )

    def mix_textures(self, texture1: torch.Tensor, texture2: torch.Tensor,
                     mix_ratio: float = 0.5) -> torch.Tensor:
        """Return the decoded blend of the two textures.

        Bug fix: the original computed ``mixed_features`` from ``mix_ratio``
        and then discarded it, feeding the raw concatenation to the mixer --
        ``mix_ratio`` had no effect whatsoever.  The ratio now weights each
        feature map before concatenation (the mixer requires 256 input
        channels, i.e. two 128-channel maps), so ``mix_ratio=1.0`` keeps only
        ``texture1``'s features and ``0.0`` only ``texture2``'s.

        Args:
            texture1: first texture batch; assumed (N, 3, H, W) -- TODO confirm.
            texture2: second texture batch, same shape.
            mix_ratio: weight of ``texture1`` in [0, 1].
        """
        feat1 = self.encoder(texture1)
        feat2 = self.encoder(texture2)
        weighted = torch.cat([feat1 * mix_ratio, feat2 * (1 - mix_ratio)], dim=1)
        return self.mixer(weighted)
|
|
class PatternGenerator(nn.Module):
    """Decodes a 128-channel latent pattern code into an RGB pattern image."""

    def __init__(self):
        super().__init__()
        # Two stride-2 transposed convs upsample 4x overall; a final plain
        # conv + tanh projects to a 3-channel image in [-1, 1].
        stages = [
            nn.ConvTranspose2d(128, 256, 4, stride=2, padding=1),
            nn.GroupNorm(8, 256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 3, 3, padding=1),
            nn.Tanh(),
        ]
        self.generator = nn.Sequential(*stages)

    def generate_pattern(self, pattern_code: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Decode ``pattern_code`` and bilinearly resample the result by ``scale``."""
        decoded = self.generator(pattern_code)
        resampled = F.interpolate(
            decoded,
            scale_factor=scale,
            mode='bilinear',
            align_corners=True,
        )
        return resampled
|
|
def example_usage():
    """Demonstrate the manipulator API on random dummy images.

    Bug fix: the original referenced ``source``, ``style_ref``, and
    ``garment`` without ever defining them, so calling it raised NameError
    immediately.  Random tensors stand in for real images here.

    NOTE(review): ``transfer_style`` still depends on ``StyleTransferModel``
    gaining a ``forward`` implementation -- confirm before running this end
    to end.
    """
    manipulator = AdvancedClothingManipulator()

    # Dummy (N, C, H, W) image batches standing in for real garment photos.
    source = torch.randn(1, 3, 64, 64)
    style_ref = torch.randn(1, 3, 64, 64)
    garment = torch.randn(1, 3, 64, 64)

    # Restyle while keeping the collar and buttons from the source image.
    new_garment = manipulator.transfer_style(
        source_image=source,
        target_style={"style_image": style_ref, "style_strength": 0.8},
        preserve_features=[GarmentFeatures.COLLAR, GarmentFeatures.BUTTONS]
    )

    # Overlay a stripe pattern with explicit pattern parameters.
    patterned_garment = manipulator.modify_pattern(
        garment_image=garment,
        pattern_type="stripes",
        pattern_params={
            "scale": 0.5,
            "rotation": 45,
            "density": 0.8
        }
    )
    return new_garment, patterned_garment