Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.transforms import Compose, Resize, ToTensor, Normalize | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| # Check if transformers is available | |
| try: | |
| from transformers import AutoBackbone | |
| HAS_TRANSFORMERS = True | |
| except ImportError: | |
| HAS_TRANSFORMERS = False | |
class SwinMattingModel(nn.Module):
    """Swin-UNet network for sky masking: a ``SwinEncoder`` feeding a ``MattingDecoder``.

    Args:
        config: nested mapping with an ``"encoder"`` section (key
            ``"model_name"``) and a ``"decoder"`` section (keys ``"use_attn"``
            and ``"refine_channels"``).
    """

    def __init__(self, config):
        super().__init__()
        # Both sections are looked up before anything is built, so a missing
        # section fails fast before the (possibly downloading) encoder is created.
        enc_cfg, dec_cfg = config['encoder'], config['decoder']
        self.encoder = SwinEncoder(model_name=enc_cfg["model_name"])
        self.decoder = MattingDecoder(
            use_attn=dec_cfg["use_attn"],
            refine_channels=dec_cfg["refine_channels"],
        )

    def forward(self, x):
        """Encode ``x`` and decode it; the decoder also receives ``x`` for refinement."""
        return self.decoder(self.encoder(x), x)
class SwinEncoder(nn.Module):
    """Swin Transformer encoder producing a list of four feature maps.

    When ``transformers`` is importable, a HuggingFace ``AutoBackbone`` is
    loaded for ``model_name`` with ``out_indices=(1, 2, 3, 4)``. If that is
    unavailable or loading raises, a fallback stack of four strided
    convolutions is used instead: 96/192/384/768 output channels at
    cumulative strides 4/8/16/32. The fallback is freshly initialised, not
    pretrained.
    """

    def __init__(self, model_name="microsoft/swin-small-patch4-window7-224"):
        super().__init__()
        hf_backbone = None
        if HAS_TRANSFORMERS:
            try:
                hf_backbone = AutoBackbone.from_pretrained(
                    model_name,
                    out_indices=(1, 2, 3, 4),
                    use_safetensors=True,
                    trust_remote_code=False,
                )
            except Exception as e:
                print(f"Failed to load HuggingFace backbone: {e}")

        if hf_backbone is None:
            self.backbone = self._create_custom_swin()
            self.use_hf_backbone = False
        else:
            self.backbone = hf_backbone
            self.use_hf_backbone = True

    def _create_custom_swin(self):
        """Build the fallback backbone: four non-overlapping strided convolutions."""
        widths = (3, 96, 192, 384, 768)
        strides = (4, 2, 2, 2)
        # kernel_size == stride, i.e. patch-merging style downsampling.
        return nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=step, stride=step)
            for c_in, c_out, step in zip(widths[:-1], widths[1:], strides)
        )

    def forward(self, x):
        """Return the list of per-stage feature maps for image batch ``x``.

        For the fallback backbone the list is ordered from the stride-4 map
        to the stride-32 map.
        """
        if not self.use_hf_backbone:
            pyramid = []
            for stage in self.backbone:
                x = stage(x)
                pyramid.append(x)
            return pyramid
        return list(self.backbone(pixel_values=x).feature_maps)
class MattingDecoder(nn.Module):
    """U-Net decoder with optional attention gates and a detail-refinement head.

    ``forward`` consumes four encoder feature maps ``(f1, f2, f3, f4)`` with
    96, 192, 384 and 768 channels (highest to lowest resolution) plus the
    3-channel input image. It returns a single-channel alpha matte in
    ``[0, 1]`` at the image's spatial resolution.

    Args:
        use_attn: if true, each skip connection is multiplied by a sigmoid
            gate computed from the upsampled decoder features and the skip
            itself before the two are concatenated.
        refine_channels: width of the hidden layers of the refinement head.
    """

    def __init__(self, use_attn=False, refine_channels=16):
        super().__init__()
        self.use_attn = use_attn
        self.refine_channels = refine_channels

        # NOTE: attribute names below are state_dict keys; keep them stable so
        # existing checkpoints keep loading.

        # Bottom convolution (1/32 scale, applied to f4)
        self.conv_bottom = nn.Conv2d(768, 768, kernel_size=3, padding=1)
        self.bn_bottom = nn.BatchNorm2d(768)

        # Upsample + fuse with skip connections
        self.conv_up3 = nn.Conv2d(768 + 384, 384, kernel_size=3, padding=1)
        self.bn_up3 = nn.BatchNorm2d(384)
        self.conv_up2 = nn.Conv2d(384 + 192, 192, kernel_size=3, padding=1)
        self.bn_up2 = nn.BatchNorm2d(192)
        self.conv_up1 = nn.Conv2d(192 + 96, 96, kernel_size=3, padding=1)
        self.bn_up1 = nn.BatchNorm2d(96)
        self.conv_out = nn.Conv2d(96, 1, kernel_size=3, padding=1)

        # Detail refinement: input is the coarse alpha (1 ch) + the image (3 ch)
        self.refine_conv1 = nn.Conv2d(4, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine1 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv2 = nn.Conv2d(self.refine_channels, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine2 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv3 = nn.Conv2d(self.refine_channels, 1, kernel_size=3, padding=1)

        # Attention gates
        if self.use_attn:
            # 1x1 projections bringing the upsampled decoder features down to
            # the skip connection's channel count so the gate terms can be summed.
            self.reduce_768_to_384 = nn.Conv2d(768, 384, kernel_size=1)
            self.reduce_384_to_192 = nn.Conv2d(384, 192, kernel_size=1)
            self.reduce_192_to_96 = nn.Conv2d(192, 96, kernel_size=1)
            self.gate_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.skip_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.gate_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.skip_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.gate_4 = nn.Conv2d(96, 96, kernel_size=1)
            self.skip_4 = nn.Conv2d(96, 96, kernel_size=1)

    def _up_stage(self, x, skip, conv, bn, gate=None):
        """Upsample ``x`` to ``skip``'s size, optionally gate ``skip``, concatenate and fuse.

        Args:
            x: decoder features from the coarser stage.
            skip: encoder skip connection for this stage.
            conv, bn: fusion convolution and its batch norm.
            gate: ``(reduce, gate_conv, skip_conv)`` 1x1 convs, or ``None`` to
                use ``skip`` unmodified.
        """
        # Resize to the skip's exact size rather than a fixed 2x: odd-sized
        # pyramids (e.g. 25 -> 12 -> 6 -> 3) would otherwise mismatch the skip
        # and make the gate sum / torch.cat fail. When skip is exactly twice
        # x's size this is the same nearest-neighbour sampling as scale_factor=2.
        x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
        if gate is not None:
            reduce, gate_conv, skip_conv = gate
            att = torch.sigmoid(gate_conv(reduce(x)) + skip_conv(skip))
            skip = skip * att
        x = torch.cat([x, skip], dim=1)
        return F.relu(bn(conv(x)))

    def forward(self, features, original_image):
        """Decode ``features`` into an alpha matte at ``original_image``'s resolution.

        Args:
            features: sequence of four maps ``(f1, f2, f3, f4)`` with 96, 192,
                384 and 768 channels.
            original_image: the encoder's input batch (3 channels), reused by
                the refinement head.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` with values in ``[0, 1]``.
        """
        f1, f2, f3, f4 = features

        # Bottom (1/32)
        x = F.relu(self.bn_bottom(self.conv_bottom(f4)))

        if self.use_attn:
            gates = (
                (self.reduce_768_to_384, self.gate_16, self.skip_16),
                (self.reduce_384_to_192, self.gate_8, self.skip_8),
                (self.reduce_192_to_96, self.gate_4, self.skip_4),
            )
        else:
            gates = (None, None, None)

        # 1/16, 1/8 and 1/4 stages
        x = self._up_stage(x, f3, self.conv_up3, self.bn_up3, gates[0])
        x = self._up_stage(x, f2, self.conv_up2, self.bn_up2, gates[1])
        x = self._up_stage(x, f1, self.conv_up1, self.bn_up1, gates[2])

        # Upsample to full resolution and predict the coarse alpha (no activation)
        x = F.interpolate(x, size=original_image.shape[-2:], mode='nearest')
        coarse_alpha = self.conv_out(x)

        # Detail refinement conditioned on the input image
        refine_input = torch.cat([coarse_alpha, original_image], dim=1)
        r = F.relu(self.bn_refine1(self.refine_conv1(refine_input)))
        r = F.relu(self.bn_refine2(self.refine_conv2(r)))
        refined_alpha = self.refine_conv3(r)
        return torch.sigmoid(refined_alpha)
class SkyMaskingPipeline:
    """Main sky masking pipeline.

    The pipeline:

    1. Resizes the input image to 512x512.
    2. Normalises it with the per-channel mean/std below (the common ImageNet
       statistics).
    3. Runs ``SwinMattingModel``.
    4. Resizes the predicted alpha matte back to the input resolution as an
       8-bit mask.

    Args:
        model_path: checkpoint file read with ``torch.load`` and applied with
            ``load_state_dict`` (so it is expected to hold a state dict for
            ``SwinMattingModel``).
    """

    def __init__(self, model_path="swin_small_patch4_window7_224.pt"):
        self.transforms = Compose([
            Resize(size=(512, 512)),
            ToTensor(),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        # None when the checkpoint could not be loaded; generate_mask then raises.
        self.model = self._load_model()
        print(f"🎯 Sky masking pipeline initialized on {self.device}")

    def generate_mask(self, image: Image.Image) -> np.ndarray:
        """Generate a sky mask from ``image``.

        Args:
            image: PIL image of any mode. It is converted to RGB because both
                the 3-channel ``Normalize`` and the model expect three channels.

        Returns:
            ``uint8`` array of shape ``(height, width)`` matching ``image``;
            larger values correspond to larger predicted alpha.

        Raises:
            RuntimeError: if no checkpoint was loaded.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded.")

        # PIL's size is (width, height), which is also cv2.resize's dsize order.
        original_size = image.size

        # RGBA / L / P inputs would otherwise yield a 4- or 1-channel tensor.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Apply transforms and run inference
        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)
        output = output.detach().cpu().numpy()
        output = np.clip(output, a_min=0, a_max=1)

        # (1, 1, H, W) -> (H, W), then back to the original resolution
        alpha_matte = np.squeeze(output, axis=0).squeeze()
        mask_resized = cv2.resize(alpha_matte, original_size, interpolation=cv2.INTER_LINEAR)

        return (mask_resized * 255).astype(np.uint8)

    def _load_model(self):
        """Build the model, load the checkpoint, and move it to ``self.device`` in eval mode.

        Returns:
            The ready model, or ``None`` if the checkpoint could not be loaded.
            Returning ``None`` avoids silently serving masks from randomly
            initialised decoder weights.
        """
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        if not self._load_checkpoint(model):
            return None
        model.to(self.device)
        model.eval()
        return model

    def _load_checkpoint(self, model):
        """Load weights from ``self.model_path`` into ``model`` with ``strict=False``.

        Returns:
            True if a state dict was applied and at least one of its keys
            matched the model; False otherwise. Failures are reported via print.
        """
        try:
            checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
        except Exception as e:
            print(f"Safe loading failed: {e}")
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects and can execute code embedded in the file. This is only
            # safe if model_path is a trusted checkpoint - confirm.
            try:
                checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=False)
                print("Warning: Used weights_only=False. Only use trusted model files.")
            except Exception as e2:
                print(f"Failed to load checkpoint: {e2}")
                return False

        try:
            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
        except Exception as e:
            print(f"Failed to load state dict: {e}")
            return False

        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")
        # strict=False does not raise on missing/unexpected keys, so detect the
        # case where nothing at all was loaded (wrapped or unrelated checkpoint).
        if len(missing_keys) == len(model.state_dict()):
            print("Failed to load state dict: no checkpoint keys matched the model")
            return False
        print("✅ Model loaded successfully!")
        return True