# File size: 7,791 Bytes
# c44101e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import gradio as gr
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from PIL import Image, ImageOps
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

class ImprovedSkySegmentationModel(nn.Module):
    """Sky segmentation network: an ``smp.Unet`` followed by a sigmoid.

    The wrapped U-Net is created with no pretrained encoder weights and no
    built-in activation; ``forward`` applies the sigmoid itself.
    """

    def __init__(self, encoder_name='resnet50', classes=1):
        """Build the wrapped U-Net.

        Args:
            encoder_name: Encoder backbone name passed to ``smp.Unet``.
            classes: Number of output classes passed to ``smp.Unet``.
        """
        super().__init__()
        unet_options = dict(
            encoder_name=encoder_name,
            encoder_weights=None,  # do not fetch pretrained encoder weights
            classes=classes,
            activation=None,  # sigmoid is applied in forward()
        )
        self.model = smp.Unet(**unet_options)

    def forward(self, x):
        """Return the sigmoid of the wrapped network's output for ``x``."""
        return torch.sigmoid(self.model(x))

# Module-level inference state. All three names stay None until
# load_model_once() has populated them.
model = config = device = None

def load_model_once():
    """Load the segmentation model into the module globals, at most once.

    Returns immediately when ``model`` is already set. Otherwise it selects
    CUDA when available (CPU if not), reads ``sky_segmentation_model.pt``
    relative to the current working directory, rebuilds the network from the
    checkpoint's ``config`` entry, restores ``model_state_dict`` and puts the
    network in eval mode on the selected device.

    ``config`` and ``model`` are published only after every step succeeded.
    A failure part-way through (for example a state dict that does not match
    the architecture) therefore never leaves a randomly initialised network
    bound to ``model``, which later ``model is None`` checks would otherwise
    mistake for a usable model.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    global model, config, device

    if model is not None:
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # For Hugging Face Spaces, the model file should be in the same directory
    model_path = "sky_segmentation_model.pt"  # You'll upload this file

    if not os.path.exists(model_path):
        # Fallback for testing - you can remove this in production
        raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained model.")

    # NOTE(review): torch.load unpickles the file; only ship checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Build everything in locals first so that an exception below leaves the
    # globals untouched.
    loaded_config = checkpoint['config']
    network = ImprovedSkySegmentationModel(
        encoder_name=loaded_config['encoder_name'],
        classes=loaded_config['classes']
    )
    network.load_state_dict(checkpoint['model_state_dict'])
    network.eval()
    network.to(device)

    # Publish config before model: anything that sees a non-None model can
    # rely on config being set as well.
    config = loaded_config
    model = network

    print(f"Model loaded successfully on {device}")

def preprocess_image(image, img_size=512):
    """Preprocess an image for inference, with EXIF orientation correction.

    Args:
        image: A file path (``str``), a PIL image (any object exposing
            ``convert``), or a numpy array.
        img_size: Side length of the square the image is resized to.

    Returns:
        A ``(tensor, original_image)`` tuple: the resized, normalised image
        with a leading batch dimension added, and an RGB PIL copy of the
        orientation-corrected input for display.

    Raises:
        TypeError: If ``image`` is none of the supported input types
            (including ``None``, e.g. when no image has been uploaded).
    """
    # Handle different input types
    if isinstance(image, str):
        # convert() loads the pixel data, so the file can be closed right away
        with Image.open(image) as source:
            image = source.convert('RGB')
    elif hasattr(image, 'convert'):
        # Already a PIL image
        image = image.convert('RGB')
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    else:
        # Fail with a clear message instead of an AttributeError from the
        # EXIF handling further down.
        raise TypeError(
            "image must be a file path, a PIL image or a numpy array, "
            f"got {type(image).__name__}"
        )

    # Automatically correct orientation based on EXIF data
    image = ImageOps.exif_transpose(image)

    # Keep an untouched copy for display
    original_image = image.copy()

    transform = A.Compose([
        A.Resize(img_size, img_size),
        # presumably the ImageNet channel statistics; confirm against training
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    transformed = transform(image=np.array(image))
    return transformed['image'].unsqueeze(0), original_image

def predict_sky_mask(image_tensor):
    """Run the global model on ``image_tensor`` and return a numpy array.

    The tensor is moved to the global ``device`` and evaluated with gradient
    tracking disabled. A 4-D output whose second axis has size 1 has that
    axis squeezed; the result is then moved to the CPU, its leading axis is
    squeezed, and it is converted to numpy.
    """
    with torch.no_grad():
        scores = model(image_tensor.to(device))

        has_single_channel = scores.dim() == 4 and scores.size(1) == 1
        if has_single_channel:
            scores = scores.squeeze(1)

        on_cpu = scores.cpu()
        return on_cpu.squeeze(0).numpy()

def create_overlay(original_image, mask, alpha=0.4):
    """Blend a blue tint, weighted by ``mask``, onto ``original_image``.

    Args:
        original_image: PIL image or array; a PIL image is converted with
            ``np.array``.
        mask: 2-D array of weights. It is multiplied by 255 before use, so
            values are presumably in [0, 1]; confirm against the caller.
        alpha: Blend weight of the tint; the image keeps ``1 - alpha``.

    Returns:
        The blended image as a ``uint8`` array clipped to the 0..255 range.
    """
    if isinstance(original_image, Image.Image):
        original_image = np.array(original_image)

    # Bring the mask to the image's height/width when they differ
    if mask.shape == original_image.shape[:2]:
        mask_resized = mask
    else:
        mask_as_image = Image.fromarray((mask * 255).astype(np.uint8))
        # PIL sizes are (width, height), the reverse of the array shape
        target_size = (original_image.shape[1], original_image.shape[0])
        stretched = mask_as_image.resize(target_size, Image.LANCZOS)
        mask_resized = np.array(stretched) / 255.0

    # Sky areas are painted into the blue channel only
    base = original_image.copy().astype(float)
    tint = np.zeros_like(original_image, dtype=float)
    tint[:, :, 2] = mask_resized * 255

    blended = (1 - alpha) * base + alpha * tint
    return np.clip(blended, 0, 255).astype(np.uint8)

def segment_sky(image):
    """Main function for the Gradio interface.

    Loads the model on first use, preprocesses ``image`` at the configured
    ``img_size``, predicts the mask and builds the blue overlay.

    Args:
        image: The image handed over by Gradio.

    Returns:
        A 3-tuple ``(original_image, mask_display, overlay_display)`` of PIL
        images. If anything fails, the failure is logged with its traceback
        and the same 512x512 red placeholder image is returned in all three
        positions.
    """
    import logging

    try:
        # Ensure model is loaded
        if model is None:
            load_model_once()

        image_tensor, original_image = preprocess_image(image, config['img_size'])
        predicted_mask = predict_sky_mask(image_tensor)

        # Convert mask to PIL Image for display (0-255 range)
        mask_display = Image.fromarray((predicted_mask * 255).astype(np.uint8))

        overlay = create_overlay(original_image, predicted_mask)
        overlay_display = Image.fromarray(overlay)

        return original_image, mask_display, overlay_display

    except Exception:
        # Keep the UI contract (three images) but record why it failed, so a
        # red panel can be diagnosed from the server log.
        logging.getLogger(__name__).exception("Sky segmentation failed")
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, error_img, error_img

# Load the model when the app starts. A failure is not fatal here: the
# message is shown in the UI through model_status instead.
try:
    load_model_once()
    model_status = "✅ Model loaded successfully!"
except Exception as e:
    model_status = f"❌ Error loading model: {str(e)}"

# Create Gradio interface: one input column (upload + button) and one output
# column showing the original image, the mask and the overlay.
# NOTE(review): the emoji in the labels below were restored from mis-encoded
# text; 📁 and 🔍 are best guesses because their last byte was lost - confirm.
with gr.Blocks(title="Sky Segmentation App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""

    # 🌀️ Sky Segmentation App

    

    Upload an image and get an AI-powered sky segmentation mask! This model identifies sky regions in your photos.

    

    **How to use:**

    1. Upload an image (JPG, PNG, etc.)

    2. The model will automatically detect sky regions

    3. View the original image, binary mask, and colored overlay

    """)
    
    # Model status (computed once at import time)
    gr.Markdown(f"**Model Status:** {model_status}")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Input
            input_image = gr.Image(
                label="📁 Upload Your Image",
                type="pil",
                height=400
            )
            
            segment_btn = gr.Button("🔍 Segment Sky", variant="primary", size="lg")
        
        with gr.Column(scale=2):
            with gr.Row():
                original_output = gr.Image(label="📷 Original Image", height=300)
                mask_output = gr.Image(label="🎭 Sky Mask", height=300)
                overlay_output = gr.Image(label="🔵 Sky Overlay", height=300)
    
    # Info section
    gr.Markdown("""

    ### 📊 Understanding the Results:

    - **Original Image**: Your uploaded image

    - **Sky Mask**: Binary mask where white = sky, black = not sky

    - **Sky Overlay**: Original image with sky regions highlighted in blue

    

    ### ℹ️ About the Model:

    This model uses a U-Net architecture with ResNet50 encoder, trained specifically for sky segmentation tasks.

    The model can handle various image orientations and lighting conditions.

    

    ### 🚀 Made with:

    - PyTorch & Segmentation Models PyTorch

    - Gradio for the interface  

    - Hugging Face Spaces for hosting

    """)
    
    # Event handlers: the button click and a fresh upload both run the same
    # function and fill the same three output panels.
    result_panels = [original_output, mask_output, overlay_output]

    segment_btn.click(
        fn=segment_sky,
        inputs=[input_image],
        outputs=result_panels
    )
    
    # Also trigger on image upload
    input_image.upload(
        fn=segment_sky,
        inputs=[input_image], 
        outputs=result_panels
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)