# Source captured from a Hugging Face Space page (Space status: Sleeping; file size: 7,791 bytes).
c44101e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | import gradio as gr
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from PIL import Image, ImageOps
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
class ImprovedSkySegmentationModel(nn.Module):
    """U-Net sky segmenter that outputs per-pixel probabilities.

    Wraps an ``smp.Unet`` built with no pretrained encoder weights and no
    output activation, and applies ``torch.sigmoid`` to its logits in
    ``forward``.

    Args:
        encoder_name: Encoder backbone name forwarded to ``smp.Unet``.
        classes: Number of output channels forwarded to ``smp.Unet``.
    """

    def __init__(self, encoder_name='resnet50', classes=1):
        super().__init__()
        unet_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=None,  # Don't load pretrained weights
            classes=classes,
            activation=None,  # raw logits; the sigmoid is applied in forward()
        )
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, x):
        """Return sigmoid probabilities for the input batch ``x``."""
        return torch.sigmoid(self.model(x))
# Module-level state, filled in by load_model_once(); everything starts unset.
model = config = device = None
def load_model_once(model_path="sky_segmentation_model.pt"):
    """Load the sky segmentation checkpoint into the module globals, once.

    Sets the globals ``model``, ``config`` and ``device``. If ``model`` is
    already set, the call does nothing. The globals are assigned only after
    the checkpoint has been read and the weights restored, so a failed load
    leaves ``model`` as ``None`` and a later call retries instead of running
    an un-initialised network.

    Args:
        model_path: Checkpoint path (relative paths resolve against the
            current working directory).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Errors from ``torch.load``, missing checkpoint keys, or
        ``load_state_dict`` propagate unchanged.
    """
    global model, config, device
    if model is not None:
        return

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained model.")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=target_device)
    loaded_config = checkpoint['config']
    net = ImprovedSkySegmentationModel(
        encoder_name=loaded_config['encoder_name'],
        classes=loaded_config['classes'],
    )
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    net.to(target_device)

    # Publish only after every step above succeeded.
    model, config, device = net, loaded_config, target_device
    print(f"Model loaded successfully on {device}")
def preprocess_image(image, img_size=512):
    """Prepare an image for the segmentation model, correcting EXIF orientation.

    Args:
        image: A file path (``str``), a PIL image (anything with ``.convert``),
            or a ``numpy.ndarray`` accepted by ``Image.fromarray``.
        img_size: Side length of the square resize applied before normalization.

    Returns:
        ``(tensor, original_image)``: ``tensor`` is the normalized image with
        a leading batch axis of size 1. ``original_image`` is the
        orientation-corrected RGB PIL image at its original resolution.

    Raises:
        TypeError: If ``image`` is not one of the supported input types
            (e.g. ``None`` when nothing was uploaded).
    """
    if isinstance(image, str):
        # Context manager releases the file handle Image.open keeps open.
        # Transpose while the source image and its EXIF data are at hand.
        with Image.open(image) as opened:
            image = ImageOps.exif_transpose(opened).convert('RGB')
    elif hasattr(image, 'convert'):
        image = ImageOps.exif_transpose(image).convert('RGB')
    elif isinstance(image, np.ndarray):
        # Arrays carry no EXIF metadata, so there is nothing to transpose.
        image = Image.fromarray(image).convert('RGB')
    else:
        raise TypeError(f"Unsupported image input type: {type(image).__name__}")

    # Keep a full-resolution copy for display and the overlay.
    original_image = image.copy()
    transform = A.Compose([
        A.Resize(img_size, img_size),
        # ImageNet channel statistics (presumably what the model was trained with).
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(image))
    return transformed['image'].unsqueeze(0), original_image
def predict_sky_mask(image_tensor):
    """Run the loaded model on a batched tensor and return a NumPy array.

    Reads the module globals ``model`` and ``device``. Gradients are disabled.
    A size-1 channel axis of a 4-D prediction is dropped, then a size-1
    leading batch axis is squeezed, and the result is moved to the CPU.

    Args:
        image_tensor: Input batch, presumably (1, 3, H, W) as produced by
            preprocess_image() — TODO confirm for other callers.

    Returns:
        ``numpy.ndarray``; a (1, 1, H, W) prediction becomes (H, W).
    """
    with torch.no_grad():
        probs = model(image_tensor.to(device))
        has_single_channel = probs.dim() == 4 and probs.size(1) == 1
        if has_single_channel:
            probs = probs.squeeze(1)
        return probs.cpu().squeeze(0).numpy()
def create_overlay(original_image, mask, alpha=0.4):
    """Tint the sky region of an image blue according to a soft mask.

    Each pixel is blended toward pure blue with weight ``alpha * mask``.
    Pixels where the mask is 0 keep their original colour instead of being
    dimmed, and pixels where the mask is 1 become
    ``(1 - alpha) * pixel + alpha * blue``.

    Args:
        original_image: PIL image or ``numpy.ndarray`` shaped (H, W, C) with
            values in 0..255. Channel index 2 is treated as blue (RGB order).
        mask: 2-D array of values in [0, 1]. It is resized to (H, W) when its
            shape differs.
        alpha: Maximum strength of the blue tint.

    Returns:
        ``uint8`` array with the same shape as ``original_image``.
    """
    # np.asarray also accepts PIL images through their array interface.
    image = np.asarray(original_image)
    height, width = image.shape[:2]

    if mask.shape != (height, width):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask = np.asarray(mask_img.resize((width, height), Image.LANCZOS)) / 255.0

    # Per-pixel blend weight, broadcast across the colour channels.
    weight = alpha * mask[..., np.newaxis]
    blue = np.zeros(image.shape, dtype=float)
    blue[..., 2] = 255.0  # Blue channel for sky
    overlay = image.astype(float) * (1.0 - weight) + blue * weight
    return np.clip(overlay, 0, 255).astype(np.uint8)
def segment_sky(image):
    """Gradio callback: segment the sky in ``image``.

    Loads the model on demand if it is not loaded yet. Returns
    ``(original_image, mask_display, overlay_display)`` as PIL images. On any
    failure the exception is logged and three 512x512 red placeholder images
    are returned so the UI still renders something.
    """
    try:
        if model is None:
            load_model_once()
        # Fall back to preprocess_image's default when the checkpoint config
        # does not record an input size.
        img_size = config.get('img_size', 512)
        image_tensor, original_image = preprocess_image(image, img_size)
        predicted_mask = predict_sky_mask(image_tensor)
        # Sigmoid outputs in [0, 1] -> 8-bit grayscale for display.
        mask_display = Image.fromarray((predicted_mask * 255).astype(np.uint8))
        overlay_display = Image.fromarray(create_overlay(original_image, predicted_mask))
        return original_image, mask_display, overlay_display
    except Exception:
        # UI boundary: record the traceback instead of silently swallowing it.
        import logging
        logging.getLogger(__name__).exception("Sky segmentation failed")
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, error_img, error_img
# Load the model when the app starts so the page can report its status.
# On failure the error is shown in the UI, and segment_sky() retries the load.
try:
    load_model_once()
    model_status = "✅ Model loaded successfully!"
except Exception as e:
    model_status = f"❌ Error loading model: {e}"
# Create the Gradio interface.
with gr.Blocks(title="Sky Segmentation App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🌤️ Sky Segmentation App

Upload an image and get an AI-powered sky segmentation mask! This model identifies sky regions in your photos.

**How to use:**
1. Upload an image (JPG, PNG, etc.)
2. The model will automatically detect sky regions
3. View the original image, sky-probability mask, and colored overlay
""")

    # Report whether the startup model load succeeded.
    gr.Markdown(f"**Model Status:** {model_status}")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📁 Upload Your Image",
                type="pil",
                height=400,
            )
            segment_btn = gr.Button("🔍 Segment Sky", variant="primary", size="lg")
        with gr.Column(scale=2):
            with gr.Row():
                original_output = gr.Image(label="📷 Original Image", height=300)
                mask_output = gr.Image(label="🎭 Sky Mask", height=300)
                overlay_output = gr.Image(label="🔵 Sky Overlay", height=300)

    # Info section. The mask is the sigmoid output scaled to 0-255, so it is
    # a grayscale probability map rather than a thresholded binary mask.
    gr.Markdown("""
### 📊 Understanding the Results:
- **Original Image**: Your uploaded image
- **Sky Mask**: Grayscale sky-probability map where white = sky, black = not sky
- **Sky Overlay**: Original image with sky regions highlighted in blue

### ℹ️ About the Model:
This model uses a U-Net architecture with ResNet50 encoder, trained specifically for sky segmentation tasks.
The model can handle various image orientations and lighting conditions.

### 🚀 Made with:
- PyTorch & Segmentation Models PyTorch
- Gradio for the interface
- Hugging Face Spaces for hosting
""")

    outputs = [original_output, mask_output, overlay_output]
    # Run on an explicit button click and also right after an upload.
    segment_btn.click(fn=segment_sky, inputs=[input_image], outputs=outputs)
    input_image.upload(fn=segment_sky, inputs=[input_image], outputs=outputs)
# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)
|