Spaces:
Sleeping
Sleeping
import os
import traceback

import albumentations as A
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageOps
class ImprovedSkySegmentationModel(nn.Module):
    """U-Net sky-segmentation network with sigmoid-activated output.

    Wraps an ``smp.Unet`` whose raw logits are squashed to [0, 1]
    probabilities in :meth:`forward`. Weights are expected to come from a
    checkpoint, so no pretrained encoder weights are downloaded.
    """

    def __init__(self, encoder_name='resnet50', classes=1):
        super().__init__()
        # activation=None: the head emits raw logits; sigmoid is applied
        # explicitly in forward().
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=None,  # checkpoint supplies all weights
            classes=classes,
            activation=None,
        )

    def forward(self, x):
        # Per-pixel sky probability in [0, 1].
        return torch.sigmoid(self.model(x))
# Lazily-populated module globals; all three are initialized on the first
# call to load_model_once() and treated as read-only afterwards.
model = None   # ImprovedSkySegmentationModel loaded from the checkpoint
config = None  # dict stored in the checkpoint ('encoder_name', 'classes', 'img_size')
device = None  # 'cuda' if available, else 'cpu'
def load_model_once():
    """Load the trained segmentation model into module globals (idempotent).

    Populates the module-level ``model``, ``config`` and ``device`` globals
    on the first call; subsequent calls are no-ops.

    Raises:
        FileNotFoundError: if the checkpoint file is not present next to
            the app (it must be uploaded to the Space separately).
    """
    global model, config, device
    if model is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # On Hugging Face Spaces the checkpoint is expected alongside app.py.
        model_path = "sky_segmentation_model.pt"
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained model.")
        # weights_only=False is required: the checkpoint stores a plain
        # config dict in addition to tensors, and PyTorch >= 2.6 defaults
        # weights_only to True, which would reject it. Only load checkpoints
        # from a trusted source with this setting (it allows pickle execution).
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']
        model = ImprovedSkySegmentationModel(
            encoder_name=config['encoder_name'],
            classes=config['classes'],
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        model.to(device)
        print(f"Model loaded successfully on {device}")
def preprocess_image(image, img_size=512):
    """Preprocess an input image for model inference.

    Accepts a file path, a PIL image, or a numpy array; applies EXIF
    orientation correction, then resizes to ``img_size`` x ``img_size`` and
    normalizes with ImageNet statistics.

    Args:
        image: str path, PIL.Image.Image, or np.ndarray.
        img_size: side length of the square model input.

    Returns:
        (tensor, original): a (1, 3, img_size, img_size) float tensor ready
        for the model, and the orientation-corrected PIL image at its
        original resolution (for display).

    Raises:
        TypeError: if ``image`` is not one of the supported input types
            (previously such inputs fell through unconverted and crashed
            later with a confusing AttributeError).
    """
    if isinstance(image, str):
        # File path on disk.
        image = Image.open(image).convert('RGB')
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    elif hasattr(image, 'convert'):
        # Already a PIL Image (or PIL-compatible object).
        image = image.convert('RGB')
    else:
        raise TypeError(f"Unsupported image type: {type(image).__name__}")
    # Respect camera EXIF orientation so portrait shots are not rotated.
    image = ImageOps.exif_transpose(image)
    # Keep a copy at original resolution for display / overlay.
    original_image = image.copy()
    transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(image))
    return transformed['image'].unsqueeze(0), original_image
def predict_sky_mask(image_tensor):
    """Run the loaded model on a preprocessed batch and return a 2-D mask.

    Expects the module-level ``model`` and ``device`` globals to be
    initialized by load_model_once(). Returns the sigmoid probabilities as
    a numpy array of shape (H, W) for a batch of one image.
    """
    global model, device
    with torch.no_grad():
        output = model(image_tensor.to(device))
        # Drop a singleton channel axis if present: (B, 1, H, W) -> (B, H, W).
        if output.dim() == 4 and output.size(1) == 1:
            output = output.squeeze(1)
        # Drop the batch axis and move to host memory.
        return output.cpu().squeeze(0).numpy()
def create_overlay(original_image, mask, alpha=0.4):
    """Blend a sky-probability mask over the image as a blue highlight.

    Args:
        original_image: PIL image or HxWx3 uint8 numpy array.
        mask: float mask in [0, 1]; resized to the image size if it differs.
        alpha: blend weight of the colored mask.

    Returns:
        HxWx3 uint8 numpy array of the blended overlay.
    """
    if isinstance(original_image, Image.Image):
        original_image = np.array(original_image)
    target_h, target_w = original_image.shape[:2]
    if mask.shape == (target_h, target_w):
        mask_resized = mask
    else:
        # Upscale the mask through PIL (LANCZOS) to match the original image.
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask_resized = np.array(mask_img.resize((target_w, target_h), Image.LANCZOS)) / 255.0
    # Paint the mask into the blue channel only.
    colored = np.zeros_like(original_image, dtype=float)
    colored[:, :, 2] = mask_resized * 255
    # Alpha-blend, then clamp back into valid uint8 range.
    blended = (1 - alpha) * original_image.astype(float) + alpha * colored
    return np.clip(blended, 0, 255).astype(np.uint8)
def segment_sky(image):
    """Gradio handler: segment sky in the uploaded image.

    Returns a (original, binary-mask, blue-overlay) triple of PIL images.
    On any failure it returns three solid-red placeholder images rather than
    raising, so the UI stays responsive; the traceback is printed to the
    Space logs instead of being silently discarded (the original handler
    swallowed the exception without any record of it).
    """
    try:
        # Lazy-load the model if import-time loading failed or was skipped.
        if model is None:
            load_model_once()
        image_tensor, original_image = preprocess_image(image, config['img_size'])
        predicted_mask = predict_sky_mask(image_tensor)
        # Scale probabilities to 0-255 grayscale for display.
        mask_display = Image.fromarray((predicted_mask * 255).astype(np.uint8))
        overlay_display = Image.fromarray(create_overlay(original_image, predicted_mask))
        return original_image, mask_display, overlay_display
    except Exception:
        # Log the failure instead of swallowing it silently, then degrade
        # gracefully with placeholder images so Gradio does not error out.
        traceback.print_exc()
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, error_img, error_img
# Eagerly load the model at import time so the UI can report its status
# immediately; failures are captured into the status string instead of
# crashing the app. NOTE(review): the 'β' in both status strings looks like
# a mojibake'd emoji from the original source — confirm intended characters.
try:
    load_model_once()
    model_status = "β Model loaded successfully!"
except Exception as e:
    model_status = f"β Error loading model: {str(e)}"
| # Create Gradio interface | |
with gr.Blocks(title="Sky Segmentation App", theme=gr.themes.Soft()) as demo:
    # Header and usage instructions.
    gr.Markdown("""
    # π€οΈ Sky Segmentation App
    Upload an image and get an AI-powered sky segmentation mask! This model identifies sky regions in your photos.
    **How to use:**
    1. Upload an image (JPG, PNG, etc.)
    2. The model will automatically detect sky regions
    3. View the original image, binary mask, and colored overlay
    """)

    # Surface the load status computed at import time.
    gr.Markdown(f"**Model Status:** {model_status}")

    with gr.Row():
        with gr.Column(scale=1):
            upload_box = gr.Image(type="pil", height=400, label="π Upload Your Image")
            run_button = gr.Button("π Segment Sky", size="lg", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                out_original = gr.Image(label="π· Original Image", height=300)
                out_mask = gr.Image(label="π Sky Mask", height=300)
            out_overlay = gr.Image(label="π΅ Sky Overlay", height=300)

    # Explanations and credits.
    gr.Markdown("""
    ### π Understanding the Results:
    - **Original Image**: Your uploaded image
    - **Sky Mask**: Binary mask where white = sky, black = not sky
    - **Sky Overlay**: Original image with sky regions highlighted in blue
    ### βΉοΈ About the Model:
    This model uses a U-Net architecture with ResNet50 encoder, trained specifically for sky segmentation tasks.
    The model can handle various image orientations and lighting conditions.
    ### π Made with:
    - PyTorch & Segmentation Models PyTorch
    - Gradio for the interface
    - Hugging Face Spaces for hosting
    """)

    # Run segmentation both on explicit button click and on image upload.
    result_views = [out_original, out_mask, out_overlay]
    run_button.click(fn=segment_sky, inputs=[upload_box], outputs=result_views)
    upload_box.upload(fn=segment_sky, inputs=[upload_box], outputs=result_views)
| # Launch the app | |
if __name__ == "__main__":
    # Bind to all interfaces on 7860 — the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)