import logging
import os

import albumentations as A
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageOps

logger = logging.getLogger(__name__)


class ImprovedSkySegmentationModel(nn.Module):
    """U-Net (segmentation_models_pytorch) that outputs sky probabilities.

    The wrapped ``smp.Unet`` is built without an activation and without
    pretrained encoder weights; ``forward`` applies a sigmoid so the output
    lies in [0, 1]. Trained weights are loaded separately from a checkpoint.
    """

    def __init__(self, encoder_name='resnet50', classes=1):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=None,  # Don't load pretrained weights
            classes=classes,
            activation=None,
        )

    def forward(self, x):
        """Return per-pixel sigmoid probabilities for the batch ``x``."""
        output = self.model(x)
        return torch.sigmoid(output)


# Global model state: filled in by load_model_once() and read by the
# inference helpers below.
model = None
config = None
device = None


def load_model_once():
    """Load the trained checkpoint into the module globals (no-op if loaded).

    ``model``, ``config`` and ``device`` are published only after the whole
    load has succeeded, so a failed attempt leaves ``model`` as ``None`` and
    the next call retries instead of running a half-initialized network.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        KeyError: if the checkpoint/config lacks an expected key.
        RuntimeError: if the stored weights don't match the architecture.
    """
    global model, config, device
    if model is not None:
        return

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Resolved relative to the current working directory; on Hugging Face
    # Spaces, upload the file next to this script.
    model_path = "sky_segmentation_model.pt"
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file {model_path} not found. Please upload your trained model."
        )

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=target_device)
    loaded_config = checkpoint['config']
    net = ImprovedSkySegmentationModel(
        encoder_name=loaded_config['encoder_name'],
        classes=loaded_config['classes'],
    )
    net.load_state_dict(checkpoint['model_state_dict'])
    net.to(target_device)
    net.eval()

    # Publish only a fully initialized model.
    model, config, device = net, loaded_config, target_device
    print(f"Model loaded successfully on {device}")


def _load_upright_rgb(image):
    """Return ``image`` (path, PIL image or numpy array) as an upright RGB PIL image.

    Raises:
        TypeError: if ``image`` is none of the supported input types.
    """
    if isinstance(image, (str, os.PathLike)):
        # The context manager closes the file once the pixels are decoded.
        with Image.open(image) as opened:
            return ImageOps.exif_transpose(opened).convert('RGB')
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if not isinstance(image, Image.Image):
        raise TypeError(f"Unsupported image type: {type(image).__name__}")
    # Apply the EXIF orientation before converting, straight from the
    # image's own metadata.
    return ImageOps.exif_transpose(image).convert('RGB')


def preprocess_image(image, img_size=512):
    """Preprocess an image for inference, correcting EXIF orientation.

    Args:
        image: File path, PIL image, or numpy array.
        img_size: Side length the image is resized to before inference.

    Returns:
        ``(tensor, original_image)``: a ``(1, 3, img_size, img_size)``
        normalized tensor and the upright RGB PIL image (for display).

    Raises:
        TypeError: if ``image`` is not a path, PIL image or numpy array.
    """
    original_image = _load_upright_rgb(image)

    transform = A.Compose([
        A.Resize(img_size, img_size),
        # ImageNet statistics.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(original_image))
    return transformed['image'].unsqueeze(0), original_image


def predict_sky_mask(image_tensor):
    """Run the loaded model on a batch from ``preprocess_image``.

    Requires ``load_model_once()`` to have succeeded. For a single-class
    model, returns an ``(H, W)`` numpy array of sigmoid probabilities.
    """
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
        if prediction.dim() == 4 and prediction.size(1) == 1:
            prediction = prediction.squeeze(1)  # drop the single class channel
        return prediction.cpu().squeeze(0).numpy()


def create_overlay(original_image, mask, alpha=0.4):
    """Tint sky pixels of ``original_image`` blue in proportion to ``mask``.

    Args:
        original_image: RGB image as a PIL image or ``(H, W, 3)`` array.
        mask: 2-D sky probabilities in [0, 1]; resized to the image size
            when the shapes differ.
        alpha: Blend weight of the blue tint where ``mask == 1``.

    Returns:
        ``(H, W, 3)`` uint8 array. Pixels with ``mask == 0`` are unchanged.
    """
    if isinstance(original_image, Image.Image):
        original_image = np.array(original_image)

    height, width = original_image.shape[:2]
    if mask.shape != (height, width):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask = np.asarray(mask_img.resize((width, height), Image.LANCZOS)) / 255.0

    # Per-pixel blend weight: only sky regions are tinted. A constant blend
    # against a zero background would darken every non-sky pixel too.
    weight = alpha * mask[..., np.newaxis]
    sky_blue = np.array([0.0, 0.0, 255.0])  # blue channel for sky
    overlay = (1.0 - weight) * original_image.astype(float) + weight * sky_blue
    return np.clip(overlay, 0, 255).astype(np.uint8)


def segment_sky(image):
    """Gradio handler: return ``(original, mask, overlay)`` for ``image``.

    Returns ``(None, None, None)`` when no image was provided, which clears
    the outputs. On any other failure the traceback is logged and three red
    512x512 placeholders are returned so the UI signals the error.
    """
    if image is None:
        return None, None, None

    try:
        if model is None:
            load_model_once()

        image_tensor, original_image = preprocess_image(image, config['img_size'])
        predicted_mask = predict_sky_mask(image_tensor)

        # Mask as an 8-bit grayscale image (0-255).
        mask_display = Image.fromarray((predicted_mask * 255).astype(np.uint8))
        overlay_display = Image.fromarray(create_overlay(original_image, predicted_mask))
        return original_image, mask_display, overlay_display
    except Exception:
        # UI boundary: keep the app responsive but don't lose the error.
        logger.exception("Sky segmentation failed")
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, error_img, error_img


# Load model when the app starts
try:
    load_model_once()
    model_status = "✅ Model loaded successfully!"
except Exception as e:
    logger.exception("Failed to load model at startup")
    model_status = f"❌ Error loading model: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Sky Segmentation App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🌤️ Sky Segmentation App

    Upload an image and get an AI-powered sky segmentation mask!
    This model identifies sky regions in your photos.

    **How to use:**
    1. Upload an image (JPG, PNG, etc.)
    2. The model will automatically detect sky regions
    3. View the original image, binary mask, and colored overlay
    """)

    # Model status
    gr.Markdown(f"**Model Status:** {model_status}")

    with gr.Row():
        with gr.Column(scale=1):
            # Input
            input_image = gr.Image(
                label="📁 Upload Your Image",
                type="pil",
                height=400,
            )
            segment_btn = gr.Button("🔍 Segment Sky", variant="primary", size="lg")

        with gr.Column(scale=2):
            with gr.Row():
                original_output = gr.Image(label="📷 Original Image", height=300)
                mask_output = gr.Image(label="🎭 Sky Mask", height=300)
            overlay_output = gr.Image(label="🔵 Sky Overlay", height=300)

    # Info section
    gr.Markdown("""
    ### 📊 Understanding the Results:
    - **Original Image**: Your uploaded image
    - **Sky Mask**: Binary mask where white = sky, black = not sky
    - **Sky Overlay**: Original image with sky regions highlighted in blue

    ### ℹ️ About the Model:
    This model uses a U-Net architecture with ResNet50 encoder, trained specifically for sky segmentation tasks.
    The model can handle various image orientations and lighting conditions.

    ### 🚀 Made with:
    - PyTorch & Segmentation Models PyTorch
    - Gradio for the interface
    - Hugging Face Spaces for hosting
    """)

    # Event handlers
    segment_btn.click(
        fn=segment_sky,
        inputs=[input_image],
        outputs=[original_output, mask_output, overlay_output],
    )

    # Also trigger on image upload
    input_image.upload(
        fn=segment_sky,
        inputs=[input_image],
        outputs=[original_output, mask_output, overlay_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )