# Sky_Replace / sky_masking.py
# Author: Mohamed Hassanain
# Initial setup: sky replacement with universal edge optimization
# Commit: 1b3cd5d
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import numpy as np
from PIL import Image
import cv2

# Check if transformers is available; when absent, SwinEncoder falls back
# to a small convolutional backbone instead of the pretrained HF model.
try:
    from transformers import AutoBackbone
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
class SwinMattingModel(nn.Module):
    """Swin encoder + U-Net matting decoder for sky masking.

    Expects ``config`` to carry ``encoder.model_name`` and
    ``decoder.{use_attn, refine_channels}`` entries.
    """

    def __init__(self, config):
        super().__init__()
        enc_cfg = config["encoder"]
        dec_cfg = config["decoder"]
        self.encoder = SwinEncoder(model_name=enc_cfg["model_name"])
        self.decoder = MattingDecoder(
            use_attn=dec_cfg["use_attn"],
            refine_channels=dec_cfg["refine_channels"],
        )

    def forward(self, x):
        # The decoder consumes the multi-scale features plus the raw
        # image (used by its detail-refinement head).
        return self.decoder(self.encoder(x), x)
class SwinEncoder(nn.Module):
    """Swin Transformer feature extractor with a convolutional fallback.

    Tries to load a pretrained HuggingFace Swin backbone. If the
    `transformers` package is missing or loading fails, a small stack of
    strided convolutions mimicking Swin's channel progression
    (96/192/384/768 at strides 4/8/16/32) is used instead.
    """

    def __init__(self, model_name="microsoft/swin-small-patch4-window7-224"):
        super().__init__()
        self.use_hf_backbone = False
        if HAS_TRANSFORMERS:
            try:
                self.backbone = AutoBackbone.from_pretrained(
                    model_name,
                    out_indices=(1, 2, 3, 4),
                    use_safetensors=True,
                    trust_remote_code=False,
                )
                self.use_hf_backbone = True
            except Exception as e:
                print(f"Failed to load HuggingFace backbone: {e}")
                self.backbone = self._create_custom_swin()
        else:
            self.backbone = self._create_custom_swin()

    def _create_custom_swin(self):
        """Fallback Swin-like backbone"""
        widths = [3, 96, 192, 384, 768]
        # First stage patchifies at stride 4; later stages halve resolution.
        steps = [4, 2, 2, 2]
        return nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=s, stride=s)
            for c_in, c_out, s in zip(widths[:-1], widths[1:], steps)
        )

    def forward(self, x):
        if self.use_hf_backbone:
            return list(self.backbone(pixel_values=x).feature_maps)
        feats = []
        out = x
        for stage in self.backbone:
            out = stage(out)
            feats.append(out)
        return feats
class MattingDecoder(nn.Module):
    """U-Net style decoder with optional attention gates.

    Upsamples the 1/32 bottleneck feature, fusing encoder skips at
    1/16, 1/8 and 1/4, predicts a coarse alpha at full resolution, then
    sharpens it with a small refinement head conditioned on the image.
    """

    def __init__(self, use_attn=False, refine_channels=16):
        super().__init__()
        self.use_attn = use_attn
        self.refine_channels = refine_channels

        # Bottleneck at 1/32 resolution.
        self.conv_bottom = nn.Conv2d(768, 768, kernel_size=3, padding=1)
        self.bn_bottom = nn.BatchNorm2d(768)

        # Upsampling path: concat skip features, then conv + BN + ReLU.
        self.conv_up3 = nn.Conv2d(768 + 384, 384, kernel_size=3, padding=1)
        self.bn_up3 = nn.BatchNorm2d(384)
        self.conv_up2 = nn.Conv2d(384 + 192, 192, kernel_size=3, padding=1)
        self.bn_up2 = nn.BatchNorm2d(192)
        self.conv_up1 = nn.Conv2d(192 + 96, 96, kernel_size=3, padding=1)
        self.bn_up1 = nn.BatchNorm2d(96)
        self.conv_out = nn.Conv2d(96, 1, kernel_size=3, padding=1)

        # Detail refinement: coarse alpha (1 ch) + image (3 ch) in.
        self.refine_conv1 = nn.Conv2d(4, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine1 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv2 = nn.Conv2d(self.refine_channels, self.refine_channels, kernel_size=3, padding=1)
        self.bn_refine2 = nn.BatchNorm2d(self.refine_channels)
        self.refine_conv3 = nn.Conv2d(self.refine_channels, 1, kernel_size=3, padding=1)

        # Attention gates: project the decoder signal down to the skip
        # width, then gate the skip with a sigmoid of the sum.
        if self.use_attn:
            self.reduce_768_to_384 = nn.Conv2d(768, 384, kernel_size=1)
            self.reduce_384_to_192 = nn.Conv2d(384, 192, kernel_size=1)
            self.reduce_192_to_96 = nn.Conv2d(192, 96, kernel_size=1)
            self.gate_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.skip_16 = nn.Conv2d(384, 384, kernel_size=1)
            self.gate_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.skip_8 = nn.Conv2d(192, 192, kernel_size=1)
            self.gate_4 = nn.Conv2d(96, 96, kernel_size=1)
            self.skip_4 = nn.Conv2d(96, 96, kernel_size=1)

    def _gate(self, x, skip_feat, reduce, gate, skip_proj):
        """Modulate *skip_feat* by sigmoid(gate(reduce(x)) + skip_proj(skip_feat))."""
        att = torch.sigmoid(gate(reduce(x)) + skip_proj(skip_feat))
        return skip_feat * att

    def forward(self, features, original_image):
        f1, f2, f3, f4 = features

        # Bottom (1/32)
        x = F.relu(self.bn_bottom(self.conv_bottom(f4)))

        # 1/16 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            f3 = self._gate(x, f3, self.reduce_768_to_384, self.gate_16, self.skip_16)
        x = F.relu(self.bn_up3(self.conv_up3(torch.cat([x, f3], dim=1))))

        # 1/8 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            f2 = self._gate(x, f2, self.reduce_384_to_192, self.gate_8, self.skip_8)
        x = F.relu(self.bn_up2(self.conv_up2(torch.cat([x, f2], dim=1))))

        # 1/4 stage
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.use_attn:
            f1 = self._gate(x, f1, self.reduce_192_to_96, self.gate_4, self.skip_4)
        x = F.relu(self.bn_up1(self.conv_up1(torch.cat([x, f1], dim=1))))

        # Full resolution: coarse alpha prediction.
        x = F.interpolate(x, size=original_image.shape[-2:], mode='nearest')
        coarse_alpha = self.conv_out(x)

        # Detail refinement conditioned on the input image.
        r = torch.cat([coarse_alpha, original_image], dim=1)
        r = F.relu(self.bn_refine1(self.refine_conv1(r)))
        r = F.relu(self.bn_refine2(self.refine_conv2(r)))
        return torch.sigmoid(self.refine_conv3(r))
class SkyMaskingPipeline:
    """End-to-end sky masking: PIL image in, uint8 alpha mask out.

    Builds a SwinMattingModel, loads checkpoint weights from
    ``model_path``, and exposes :meth:`generate_mask` for inference.
    """

    def __init__(self, model_path="swin_small_patch4_window7_224.pt"):
        """Initialize transforms, device, and the matting model.

        Args:
            model_path: Path to the checkpoint file to load.
        """
        # ImageNet normalization to match the pretrained Swin backbone.
        self.transforms = Compose([
            Resize(size=(512, 512)),
            ToTensor(),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        print(f"🎯 Sky masking pipeline initialized on {self.device}")

    def generate_mask(self, image: Image.Image) -> np.ndarray:
        """Generate a sky alpha mask for *image*.

        Args:
            image: Input PIL image. Non-RGB modes (RGBA, grayscale, ...)
                are converted to RGB before inference.

        Returns:
            uint8 mask of shape (H, W) in [0, 255], matching the input size.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded.")
        # Normalize() assumes 3 channels; RGBA/grayscale inputs would
        # otherwise crash or produce wrong statistics.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # PIL's size is (width, height).
        original_size = image.size
        # Apply transforms and run inference.
        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)
        output = output.detach().cpu().numpy()
        # Model output is already sigmoid-bounded; clip defensively.
        output = np.clip(output, a_min=0, a_max=1)
        # (1, 1, H, W) -> (H, W)
        alpha_matte = np.squeeze(output, axis=0).squeeze()
        # cv2.resize takes dsize as (width, height) — same order as PIL size.
        mask_resized = cv2.resize(alpha_matte, original_size, interpolation=cv2.INTER_LINEAR)
        # Scale to 8-bit mask.
        return (mask_resized * 255).astype(np.uint8)

    def _load_model(self):
        """Build the model, load checkpoint weights, and move it to the device."""
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model)
        model.to(self.device)
        model.eval()
        return model

    def _load_checkpoint(self, model):
        """Load checkpoint weights into *model*, best-effort.

        Tries the safe ``weights_only=True`` path first, falls back to
        unsafe pickle loading (trusted files only). Failures are logged,
        not raised, leaving the model with random weights.
        """
        try:
            checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
        except Exception as e:
            print(f"Safe loading failed: {e}")
            try:
                # SECURITY: weights_only=False unpickles arbitrary objects;
                # only acceptable for checkpoints from a trusted source.
                checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=False)
                print("Warning: Used weights_only=False. Only use trusted model files.")
            except Exception as e2:
                print(f"Failed to load checkpoint: {e2}")
                return
        try:
            # strict=False tolerates key drift between code and checkpoint;
            # report any mismatch so it is visible.
            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
            if missing_keys:
                print(f"Missing keys: {missing_keys}")
            if unexpected_keys:
                print(f"Unexpected keys: {unexpected_keys}")
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"Failed to load state dict: {e}")