"""Sky masking: a Swin-UNet matting model and its inference pipeline."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import numpy as np
from PIL import Image
import cv2

# Check if transformers is available
try:
    from transformers import AutoBackbone
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False


class SwinMattingModel(nn.Module):
    """Swin-UNet model for sky masking.

    Wires a Swin Transformer encoder to a U-Net style matting decoder
    that predicts a single-channel alpha matte in [0, 1].
    """

    def __init__(self, config):
        """Build encoder/decoder from a nested config dict.

        Args:
            config: dict with keys ``encoder`` (``model_name``) and
                ``decoder`` (``use_attn``, ``refine_channels``).
        """
        super().__init__()
        encoder_config = config['encoder']
        decoder_config = config['decoder']
        self.encoder = SwinEncoder(model_name=encoder_config["model_name"])
        self.decoder = MattingDecoder(
            use_attn=decoder_config["use_attn"],
            refine_channels=decoder_config["refine_channels"]
        )

    def forward(self, x):
        """Return the predicted alpha matte for a batch of RGB images.

        The raw input image is also passed to the decoder so the
        refinement head can recover fine detail at full resolution.
        """
        features = self.encoder(x)
        return self.decoder(features, x)


class SwinEncoder(nn.Module):
    """Swin Transformer encoder producing a 4-level feature pyramid.

    Prefers a HuggingFace ``AutoBackbone``; falls back to a small
    strided-conv pyramid with matching channel counts (96/192/384/768)
    when transformers is unavailable or the download fails.
    """

    def __init__(self, model_name="microsoft/swin-small-patch4-window7-224"):
        super().__init__()
        if HAS_TRANSFORMERS:
            try:
                self.backbone = AutoBackbone.from_pretrained(
                    model_name,
                    out_indices=(1, 2, 3, 4),  # all four Swin stages
                    use_safetensors=True,
                    trust_remote_code=False
                )
                self.use_hf_backbone = True
            except Exception as e:
                # Best-effort fallback: keep the pipeline usable offline.
                print(f"Failed to load HuggingFace backbone: {e}")
                self.backbone = self._create_custom_swin()
                self.use_hf_backbone = False
        else:
            self.backbone = self._create_custom_swin()
            self.use_hf_backbone = False

    def _create_custom_swin(self):
        """Fallback Swin-like backbone.

        Four strided convolutions reproducing Swin's downsampling
        schedule (1/4, 1/8, 1/16, 1/32) and channel widths.
        """
        layers = nn.ModuleList()
        layers.append(nn.Conv2d(3, 96, kernel_size=4, stride=4))
        layers.append(nn.Conv2d(96, 192, kernel_size=2, stride=2))
        layers.append(nn.Conv2d(192, 384, kernel_size=2, stride=2))
        layers.append(nn.Conv2d(384, 768, kernel_size=2, stride=2))
        return layers

    def forward(self, x):
        """Return a list of 4 feature maps, shallowest (1/4) first."""
        if self.use_hf_backbone:
            outputs = self.backbone(pixel_values=x)
            features = outputs.feature_maps
            return list(features)
        features = []
        current = x
        for layer in self.backbone:
            current = layer(current)
            features.append(current)
        return features


class MattingDecoder(nn.Module):
    """U-Net decoder with optional attention gates and detail refinement.

    Consumes the 4-level encoder pyramid, upsamples stage by stage while
    fusing skip connections, predicts a coarse full-resolution alpha,
    then refines it jointly with the original RGB image.
    """

    def __init__(self, use_attn=False, refine_channels=16):
        """
        Args:
            use_attn: enable additive attention gates on skip connections.
            refine_channels: width of the detail-refinement head.
        """
        super().__init__()
        self.use_attn = use_attn
        self.refine_channels = refine_channels

        # Bottom convolution (operates on the 1/32 feature map)
        self.conv_bottom = nn.Conv2d(768, 768, kernel_size=3, padding=1)
        self.bn_bottom = nn.BatchNorm2d(768)

        # Upsample + fuse with skip connections (channels: up + skip)
        self.conv_up3 = nn.Conv2d(768 + 384, 384, kernel_size=3, padding=1)
        self.bn_up3 = nn.BatchNorm2d(384)
        self.conv_up2 = nn.Conv2d(384 + 192, 192, kernel_size=3, padding=1)
        self.bn_up2 = nn.BatchNorm2d(192)
        self.conv_up1 = nn.Conv2d(192 + 96, 96, kernel_size=3, padding=1)
        self.bn_up1 = nn.BatchNorm2d(96)
        self.conv_out = nn.Conv2d(96, 1, kernel_size=3, padding=1)

        # Detail refinement: input = coarse alpha (1) + RGB image (3)
        self.refine_conv1 = nn.Conv2d(4, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine1 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv2 = nn.Conv2d(self.refine_channels, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine2 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv3 = nn.Conv2d(self.refine_channels, 1, kernel_size=3, padding=1)

        # Attention gates (additive gating: sigmoid(gate(x) + skip(f)))
        if self.use_attn:
            # 1x1 reductions so the gating signal matches the skip width
            self.reduce_768_to_384 = nn.Conv2d(768, 384, kernel_size=1)
            self.reduce_384_to_192 = nn.Conv2d(384, 192, kernel_size=1)
            self.reduce_192_to_96 = nn.Conv2d(192, 96, kernel_size=1)
            self.gate_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.skip_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.gate_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.skip_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.gate_4 = nn.Conv2d(96, 96, kernel_size=1)
            self.skip_4 = nn.Conv2d(96, 96, kernel_size=1)

    def forward(self, features, original_image):
        """Decode pyramid features into an alpha matte.

        Args:
            features: [f1, f2, f3, f4] at strides 4, 8, 16, 32.
            original_image: the network input, used for refinement and
                to fix the output spatial size.

        Returns:
            Sigmoid alpha matte, shape (N, 1, H, W) matching the input.
        """
        f1, f2, f3, f4 = features

        # Bottom (1/32)
        x = F.relu(self.bn_bottom(self.conv_bottom(f4)))

        # 1/16 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            x_reduced = self.reduce_768_to_384(x)  # gating signal only; x stays 768ch for concat
            g = self.gate_16(x_reduced)
            skip = self.skip_16(f3)
            att = torch.sigmoid(g + skip)
            f3 = f3 * att
        x = torch.cat([x, f3], dim=1)
        x = F.relu(self.bn_up3(self.conv_up3(x)))

        # 1/8 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            x_reduced = self.reduce_384_to_192(x)
            g = self.gate_8(x_reduced)
            skip = self.skip_8(f2)
            att = torch.sigmoid(g + skip)
            f2 = f2 * att
        x = torch.cat([x, f2], dim=1)
        x = F.relu(self.bn_up2(self.conv_up2(x)))

        # 1/4 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            x_reduced = self.reduce_192_to_96(x)
            g = self.gate_4(x_reduced)
            skip = self.skip_4(f1)
            att = torch.sigmoid(g + skip)
            f1 = f1 * att
        x = torch.cat([x, f1], dim=1)
        x = F.relu(self.bn_up1(self.conv_up1(x)))

        # Upsample to full resolution and predict coarse alpha
        x = F.interpolate(x, size=original_image.shape[-2:], mode='nearest')
        coarse_alpha = self.conv_out(x)

        # Detail refinement: condition the matte on the raw image
        refine_input = torch.cat([coarse_alpha, original_image], dim=1)
        r = F.relu(self.bn_refine1(self.refine_conv1(refine_input)))
        r = F.relu(self.bn_refine2(self.refine_conv2(r)))
        refined_alpha = self.refine_conv3(r)
        return torch.sigmoid(refined_alpha)


class SkyMaskingPipeline:
    """Main sky masking pipeline: preprocessing, inference, postprocessing."""

    def __init__(self, model_path="swin_small_patch4_window7_224.pt"):
        """Load the model and prepare the 512x512 ImageNet-normalized transform."""
        self.transforms = Compose([
            Resize(size=(512, 512)),
            ToTensor(),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        print(f"🎯 Sky masking pipeline initialized on {self.device}")

    def generate_mask(self, image: Image.Image) -> np.ndarray:
        """Generate sky mask from input image.

        Args:
            image: input PIL image; any mode is accepted and converted
                to RGB before inference.

        Returns:
            uint8 mask (0-255) at the original image resolution.

        Raises:
            RuntimeError: if the model failed to load.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded.")

        # Store original size. PIL's .size is (width, height), which is
        # exactly the dsize order cv2.resize expects below.
        original_size = image.size

        # FIX: normalize to 3-channel RGB. RGBA / grayscale / palette
        # inputs would otherwise yield 4- or 1-channel tensors and break
        # the 3-channel encoder stem and Normalize transform.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Apply transforms and run inference
        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)

        output = output.detach().cpu().numpy()
        output = np.clip(output, a_min=0, a_max=1)

        # (1, 1, H, W) -> (H, W), then resize back to original dimensions
        alpha_matte = np.squeeze(output, axis=0).squeeze()
        mask_resized = cv2.resize(alpha_matte, original_size, interpolation=cv2.INTER_LINEAR)

        # Convert to uint8
        mask_uint8 = (mask_resized * 255).astype(np.uint8)
        return mask_uint8

    def _load_model(self):
        """Build the model, load weights from disk, move to device, eval."""
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model)
        model.to(self.device)
        model.eval()
        return model

    def _load_checkpoint(self, model):
        """Load checkpoint with error handling.

        Tries safe loading first (``weights_only=True``), then falls back
        to unsafe loading with a warning. Unwraps common checkpoint
        wrapper keys before applying the state dict. Never raises: on
        failure the model keeps its initialization weights.
        """
        try:
            checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
        except Exception as e:
            print(f"Safe loading failed: {e}")
            try:
                checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=False)
                print("Warning: Used weights_only=False. Only use trusted model files.")
            except Exception as e2:
                print(f"Failed to load checkpoint: {e2}")
                return

        # FIX: many training frameworks save {"state_dict": ...} or
        # {"model": ...} wrappers; previously these loaded zero weights
        # (every key "missing" under strict=False). Unwrap them first.
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model"):
                inner = checkpoint.get(wrapper_key)
                if isinstance(inner, dict):
                    checkpoint = inner
                    break

        try:
            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
            if missing_keys:
                print(f"Missing keys: {missing_keys}")
            if unexpected_keys:
                print(f"Unexpected keys: {unexpected_keys}")
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"Failed to load state dict: {e}")