# CutItOut / briacustom.py
# Author: K00B404 (Hugging Face) — commit 69dcc94 ("Update briacustom.py")
# NOTE: the lines above were Hugging Face page chrome captured by the scrape,
# not Python source; converted to comments so the module can be imported.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from enum import Enum, auto
class ClothingType(Enum):
    """Closed set of garment categories the pipeline knows about."""

    SHIRT = 1
    PANTS = 2
    DRESS = 3
    SKIRT = 4
    JACKET = 5
    SHOES = 6
    ACCESSORY = 7
class GarmentFeatures(Enum):
    """Detectable garment details; member names (lowercased) key the
    per-feature detector heads in GarmentAnalyzer."""

    COLLAR = 1
    SLEEVES = 2
    BUTTONS = 3
    POCKETS = 4
    PATTERNS = 5
    TEXTURE = 6
class StyleTransferModel(nn.Module):
    """Style-transfer backbone used by AdvancedClothingManipulator.

    Fixes vs. original:
    - mutable default argument ``config={"style_layers": 5}`` replaced by
      ``None`` (shared-dict pitfall);
    - ``config["style_layers"]`` raised KeyError when the key was absent —
      AdvancedClothingManipulator passes ``{}`` by default — now falls back
      to a default of 5;
    - the module had no ``forward``, so calling it (as transfer_style does)
      raised NotImplementedError; a simple blend forward is provided.
    """

    DEFAULT_STYLE_LAYERS = 5

    def __init__(self, config: Optional[dict] = None):
        super().__init__()
        config = config or {}
        num_style_layers = config.get("style_layers", self.DEFAULT_STYLE_LAYERS)
        self.content_layers = self._build_content_layers()
        self.style_layers = self._build_style_layers(num_style_layers)
        self.pattern_generator = PatternGenerator()
        self.texture_mixer = TextureMixer()

    def _build_content_layers(self) -> nn.ModuleList:
        # Channel widths 64, 128, 256, 512 — one conv per pyramid level.
        return nn.ModuleList([
            nn.Conv2d(64 * (2 ** i), 64 * (2 ** i), 3, padding=1)
            for i in range(4)
        ])

    def _build_style_layers(self, num_layers: int) -> nn.ModuleList:
        # Widths cycle through 64/128/256 as in the original (i % 3).
        return nn.ModuleList([
            nn.Conv2d(64 * (2 ** (i % 3)), 64 * (2 ** (i % 3)), 3, padding=1)
            for i in range(num_layers)
        ])

    def forward(self, content_image: torch.Tensor,
                target_style: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Blend ``content_image`` toward ``target_style["style_image"]``.

        Lightweight placeholder so the module is callable: returns
        ``(1 - s) * content + s * style`` with ``s = style_strength``
        (default 1.0).  Returns the content unchanged when no style image
        is supplied.  TODO(review): replace with a real feature-space
        transfer using content_layers/style_layers.
        """
        style_image = target_style.get("style_image")
        if style_image is None:
            return content_image
        strength = float(target_style.get("style_strength", 1.0))
        if style_image.shape[-2:] != content_image.shape[-2:]:
            # Match spatial size so the blend is well-defined.
            style_image = F.interpolate(
                style_image, size=content_image.shape[-2:],
                mode="bilinear", align_corners=False,
            )
        return (1.0 - strength) * content_image + strength * style_image
class GarmentAnalyzer(nn.Module):
    """Detects garment features (collar, sleeves, ...) as per-feature
    one-channel response maps.

    Fix vs. original: ``detect_features`` referenced ``self.shared_backbone``
    which was never defined, so every call raised AttributeError.  A shared
    trunk lifting 3-channel RGB input to the 128-channel feature map each
    detector head expects is now defined in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: RGB -> 128 channels, consumed by every head below.
        self.shared_backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
        )
        # One detector head per GarmentFeatures member, keyed by lowercase name.
        self.feature_detector = nn.ModuleDict({
            feature.name.lower(): self._build_detector()
            for feature in GarmentFeatures
        })

    def _build_detector(self) -> nn.Sequential:
        """One feature head: 128-channel features -> 1-channel response map."""
        return nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 1),
        )

    def detect_features(self, garment_image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run every head over the shared backbone features.

        Args:
            garment_image: (batch, 3, H, W) image tensor.

        Returns:
            Mapping from lowercase feature name to a (batch, 1, H, W)
            raw-logit response map.
        """
        features = self.shared_backbone(garment_image)
        return {name: head(features) for name, head in self.feature_detector.items()}
class AdvancedClothingManipulator(nn.Module):
    """Facade combining garment analysis, style transfer, pattern
    modification, and detail-preserving blending.

    Fixes vs. original:
    - mutable default argument ``config={}`` replaced by ``None``;
    - ``transfer_style`` called the undefined ``self._create_preservation_masks``
      (AttributeError whenever ``preserve_features`` was supplied) — now
      implemented;
    - feature detection is skipped when no features are preserved (its
      result was discarded in that branch).
    """

    def __init__(self, config: Optional[dict] = None):
        super().__init__()
        config = config or {}  # forwarded to StyleTransferModel
        self.garment_analyzer = GarmentAnalyzer()
        self.style_transfer = StyleTransferModel(config)
        self.pattern_modifier = PatternModificationModule()
        self.detail_preserving_blender = DetailPreservingBlender()

    def _create_preservation_masks(
            self,
            source_features: Dict[str, torch.Tensor],
            preserve_features: List["GarmentFeatures"]) -> Dict[str, torch.Tensor]:
        """Build soft [0, 1] masks for the features to keep.

        Detector outputs are raw logits; squash them with sigmoid so the
        blender can treat them as per-pixel keep-weights.  Features with no
        matching detector output are silently skipped.
        """
        masks = {}
        for feature in preserve_features:
            key = feature.name.lower()
            if key in source_features:
                masks[key] = torch.sigmoid(source_features[key])
        return masks

    def transfer_style(self,
                       source_image: torch.Tensor,
                       target_style: Dict[str, torch.Tensor],
                       preserve_features: Optional[List["GarmentFeatures"]] = None) -> torch.Tensor:
        """Restyle ``source_image``; optionally keep selected features intact.

        Args:
            source_image: (batch, 3, H, W) garment image.
            target_style: dict forwarded to StyleTransferModel (e.g.
                ``{"style_image": ..., "style_strength": 0.8}``).
            preserve_features: GarmentFeatures members whose regions are
                re-blended from the source after styling.
        """
        styled = self.style_transfer(source_image, target_style)
        if preserve_features:
            source_features = self.garment_analyzer.detect_features(source_image)
            preservation_masks = self._create_preservation_masks(
                source_features, preserve_features)
            return self.detail_preserving_blender(
                source_image, styled, preservation_masks)
        return styled

    def modify_pattern(self,
                       garment_image: torch.Tensor,
                       pattern_type: str,
                       pattern_params: Dict[str, float]) -> torch.Tensor:
        """Apply a named pattern ('stripes', 'dots', 'floral', 'geometric')
        to the garment via PatternModificationModule."""
        return self.pattern_modifier(garment_image, pattern_type, pattern_params)
class PatternModificationModule(nn.Module):
    """Applies a named pattern style to a garment image.

    Fix vs. original: the module had no ``forward``, yet
    ``AdvancedClothingManipulator.modify_pattern`` calls it directly, which
    raised NotImplementedError.  ``forward`` now encodes the image with the
    encoder selected by ``pattern_type`` and decodes an RGB result; an
    unknown pattern type raises ValueError instead of a raw KeyError.
    """

    def __init__(self):
        super().__init__()
        # One encoder per supported pattern family.
        self.pattern_encoders = nn.ModuleDict({
            'stripes': self._build_pattern_encoder(),
            'dots': self._build_pattern_encoder(),
            'floral': self._build_pattern_encoder(),
            'geometric': self._build_pattern_encoder(),
        })
        # 128-channel features -> RGB in [-1, 1].
        self.pattern_generator = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.GroupNorm(8, 256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 3, 3, padding=1),
            nn.Tanh(),
        )

    def _build_pattern_encoder(self) -> nn.Sequential:
        """RGB image -> 128-channel pattern features (spatial size kept)."""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
        )

    def forward(self, garment_image: torch.Tensor, pattern_type: str,
                pattern_params: Dict[str, float]) -> torch.Tensor:
        """Encode ``garment_image`` with the ``pattern_type`` encoder and
        decode an RGB pattern of the same spatial size.

        Raises:
            ValueError: if ``pattern_type`` is not a known pattern family.
        """
        try:
            encoder = self.pattern_encoders[pattern_type]
        except KeyError:
            raise ValueError(
                f"unknown pattern_type {pattern_type!r}; expected one of "
                f"{sorted(self.pattern_encoders)}"
            ) from None
        # pattern_params (scale/rotation/density) is accepted for interface
        # compatibility; TODO(review): wire the params into generation.
        features = encoder(garment_image)
        return self.pattern_generator(features)
class DetailPreservingBlender(nn.Module):
    """Blends a styled image back with its source, hard-preserving masked
    feature regions.

    Fix vs. original: the attention map was computed on every forward and
    then discarded (dead compute, clearly intended to drive the blend).
    It is now used as the base source/styled mixing weight, after which the
    per-feature preservation masks are applied on top.
    """

    def __init__(self):
        super().__init__()
        # 6-channel input: source and styled images concatenated.
        # Sigmoid output in [0, 1]: per-pixel "keep source" weight.
        self.attention_module = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),
            nn.GroupNorm(8, 32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, source: torch.Tensor, styled: torch.Tensor,
                preservation_masks: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return styled image with masked regions restored from source.

        Args:
            source: (batch, 3, H, W) original image.
            styled: (batch, 3, H, W) styled image.
            preservation_masks: per-feature (batch, 1, H, W) masks in [0, 1];
                1 means "keep the source pixel".
        """
        attention = self.attention_module(torch.cat([source, styled], dim=1))
        result = attention * source + (1.0 - attention) * styled
        # Sequentially re-impose each preserved feature region.
        for mask in preservation_masks.values():
            result = result * (1 - mask) + source * mask
        return result
class TextureMixer(nn.Module):
    """Mixes two texture images in feature space.

    Fix vs. original: ``mix_textures`` computed a ratio-weighted feature
    average and then discarded it, so ``mix_ratio`` had no effect on the
    output.  The encoded features are now weighted by the ratio before
    being concatenated for the mixer.
    """

    def __init__(self):
        super().__init__()
        # RGB -> 128-channel texture features.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
        )
        # Concatenated 256-channel features -> RGB in [-1, 1].
        self.mixer = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 3, 3, padding=1),
            nn.Tanh(),
        )

    def mix_textures(self, texture1: torch.Tensor, texture2: torch.Tensor,
                     mix_ratio: float = 0.5) -> torch.Tensor:
        """Blend two textures; ``mix_ratio`` weights texture1 vs texture2.

        Args:
            texture1, texture2: (batch, 3, H, W) texture images.
            mix_ratio: weight of texture1 in [0, 1]; texture2 gets the
                complement.
        """
        feat1 = self.encoder(texture1) * mix_ratio
        feat2 = self.encoder(texture2) * (1.0 - mix_ratio)
        return self.mixer(torch.cat([feat1, feat2], dim=1))
class PatternGenerator(nn.Module):
    """Decodes a 128-channel latent code into an RGB pattern tile,
    upsampling 4x on the way, with optional bilinear rescaling."""

    def __init__(self):
        super().__init__()
        stages = [
            # 2x upsample: 128 -> 256 channels.
            nn.ConvTranspose2d(128, 256, 4, stride=2, padding=1),
            nn.GroupNorm(8, 256),
            nn.LeakyReLU(0.2),
            # 2x upsample: 256 -> 128 channels.
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2),
            # Project to RGB in [-1, 1].
            nn.Conv2d(128, 3, 3, padding=1),
            nn.Tanh(),
        ]
        self.generator = nn.Sequential(*stages)

    def generate_pattern(self, pattern_code: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Decode ``pattern_code`` and rescale the resulting tile by ``scale``.

        Args:
            pattern_code: (batch, 128, h, w) latent code.
            scale: spatial scale factor applied to the decoded tile.
        """
        decoded = self.generator(pattern_code)
        rescaled = F.interpolate(
            decoded,
            scale_factor=scale,
            mode='bilinear',
            align_corners=True,
        )
        return rescaled
def example_usage():
    """Demonstrate the manipulator API end to end on dummy inputs.

    Fix vs. original: the example referenced undefined names (``source``,
    ``style_ref``, ``garment``) — a NameError when called — and discarded
    its results.  Small random tensors are created locally and both
    outputs are returned for inspection.
    """
    manipulator = AdvancedClothingManipulator()
    source = torch.rand(1, 3, 64, 64)
    style_ref = torch.rand(1, 3, 64, 64)
    garment = torch.rand(1, 3, 64, 64)

    # Transfer style while preserving collar and buttons.
    new_garment = manipulator.transfer_style(
        source_image=source,
        target_style={"style_image": style_ref, "style_strength": 0.8},
        preserve_features=[GarmentFeatures.COLLAR, GarmentFeatures.BUTTONS],
    )

    # Add or modify patterns.
    patterned_garment = manipulator.modify_pattern(
        garment_image=garment,
        pattern_type="stripes",
        pattern_params={
            "scale": 0.5,
            "rotation": 45,
            "density": 0.8,
        },
    )
    return {"styled": new_garment, "patterned": patterned_garment}