# Mohaaaa's picture
# Upload 4 files
# c44101e verified
import logging
import os

import albumentations as A
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageOps
class ImprovedSkySegmentationModel(nn.Module):
    """Sky segmentation network: an ``smp.Unet`` whose logits go through a sigmoid.

    Args:
        encoder_name: encoder backbone name passed through to ``smp.Unet``.
        classes: number of output channels passed through to ``smp.Unet``.
    """

    def __init__(self, encoder_name='resnet50', classes=1):
        super().__init__()
        unet_kwargs = dict(
            encoder_name=encoder_name,
            # No pretrained download: trained weights are restored afterwards
            # with load_state_dict().
            encoder_weights=None,
            classes=classes,
            # Emit raw logits; forward() applies the sigmoid itself.
            activation=None,
        )
        # Keep the attribute name "model": it becomes the prefix of every
        # state_dict key, so renaming it would break loading saved checkpoints.
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, x):
        """Return per-pixel sigmoid scores in [0, 1] for the batch ``x``."""
        logits = self.model(x)
        return logits.sigmoid()
# Module-level inference state. All three start empty and are populated by
# load_model_once(); the inference helpers read them as globals.
model = config = device = None
def load_model_once(model_path="sky_segmentation_model.pt"):
    """Load the trained checkpoint into the module globals, at most once.

    Populates the module-level ``model``, ``config`` and ``device``. It does
    nothing if ``model`` is already set. The globals are assigned only after
    the checkpoint has been read and its weights loaded successfully. A failed
    attempt therefore leaves ``model`` as None, and the next call retries
    instead of serving a half-initialised network.

    The checkpoint must provide ``'config'`` (with ``'encoder_name'`` and
    ``'classes'``) and ``'model_state_dict'``. segment_sky() additionally
    reads ``config['img_size']``.

    Args:
        model_path: checkpoint location. A relative path is resolved against
            the current working directory. On Hugging Face Spaces that is
            presumably the app directory; confirm for other deployments.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        Errors from ``torch.load``, missing checkpoint keys, or
        ``load_state_dict`` propagate unchanged.
    """
    global model, config, device
    if model is not None:
        return

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained model.")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=target_device)
    checkpoint_config = checkpoint['config']
    net = ImprovedSkySegmentationModel(
        encoder_name=checkpoint_config['encoder_name'],
        classes=checkpoint_config['classes'],
    )
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    net.to(target_device)

    # Publish the fully-initialised state only now that every step succeeded.
    model, config, device = net, checkpoint_config, target_device
    print(f"Model loaded successfully on {device}")
def preprocess_image(image, img_size=512):
    """Build a normalized model batch and a display copy from an input image.

    The image is rotated according to its EXIF orientation tag, if it has one,
    and converted to RGB. It is then resized to ``img_size`` x ``img_size`` and
    normalized with the ImageNet mean/std.

    Args:
        image: a file path (str), a PIL image, or a numpy array. Arrays are
            passed to ``Image.fromarray``; presumably uint8 HxW or HxWxC, so
            confirm this for non-Gradio callers.
        img_size: side length of the square model input.

    Returns:
        ``(batch, original_image)``. ``batch`` is the transformed tensor with a
        leading batch axis of size 1. ``original_image`` is the
        orientation-corrected RGB PIL image at its original resolution.

    Raises:
        ValueError: if ``image`` is None (e.g. no image in the Gradio input).
        TypeError: if ``image`` is not a path, PIL image, or numpy array.
    """
    if image is None:
        raise ValueError("No image provided.")

    if isinstance(image, str):
        # Use a context manager so the file handle is closed. exif_transpose
        # returns a new, fully-loaded image, so it outlives the `with`.
        with Image.open(image) as opened:
            image = ImageOps.exif_transpose(opened).convert('RGB')
    elif isinstance(image, np.ndarray):
        # Arrays carry no EXIF metadata, so there is nothing to transpose.
        image = Image.fromarray(image).convert('RGB')
    elif hasattr(image, 'convert'):
        # Transpose before converting so the orientation tag is read from the
        # source image itself.
        image = ImageOps.exif_transpose(image).convert('RGB')
    else:
        raise TypeError(f"Unsupported image type: {type(image).__name__}")

    transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(image))
    # `image` is not modified above, so it can be returned directly for display.
    return transformed['image'].unsqueeze(0), image
def predict_sky_mask(image_tensor):
    """Run the loaded model on a preprocessed batch and return a numpy array.

    Uses the module-level ``model`` and ``device``, which load_model_once()
    fills in. For a 4-D output with a single channel, the channel axis is
    dropped. A leading axis of size 1 is then dropped too, so a one-image,
    one-channel batch yields an (H, W) array.
    """
    with torch.no_grad():
        scores = model(image_tensor.to(device))
    # `scores` was produced under no_grad, so reshaping and exporting it
    # outside the context records no autograd history.
    if scores.dim() == 4 and scores.size(1) == 1:
        scores = scores.squeeze(1)
    return scores.cpu().squeeze(0).numpy()
def create_overlay(original_image, mask, alpha=0.4, color=(0, 0, 255)):
    """Tint the sky regions of an image with ``color``, weighted by ``mask``.

    Each pixel is blended as
    ``pixel * (1 - alpha * m) + color * alpha * m``, where ``m`` is the mask
    value at that pixel, clipped to [0, 1]. Pixels with ``m == 0`` keep their
    original colour. The previous uniform blend darkened the whole image by
    ``1 - alpha``.

    Args:
        original_image: a PIL image or an array of shape (H, W, 3). The image
            is expected to be RGB, so the default ``color`` is blue.
        mask: 2-D array of sky scores, expected in [0, 1]. If its shape differs
            from (H, W), it is resized with Lanczos filtering through an 8-bit
            image, which quantizes it to 256 levels.
        alpha: maximum tint strength, applied where ``m == 1``.
        color: RGB tint colour.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    if not isinstance(original_image, np.ndarray):
        original_image = np.asarray(original_image)
    height, width = original_image.shape[:2]

    mask = np.clip(np.asarray(mask, dtype=float), 0.0, 1.0)
    if mask.shape != (height, width):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask = np.asarray(mask_img.resize((width, height), Image.LANCZOS), dtype=float) / 255.0

    # Per-pixel blend weight, with a trailing axis so it broadcasts over channels.
    weight = alpha * mask[..., np.newaxis]
    tint = np.asarray(color, dtype=float)
    overlay = original_image.astype(float) * (1.0 - weight) + tint * weight
    return np.clip(overlay, 0, 255).astype(np.uint8)
def segment_sky(image):
    """Gradio callback: segment the sky in ``image`` and build the display images.

    Args:
        image: value of the input ``gr.Image`` (a PIL image, since the input
            uses ``type="pil"``), or None when no image has been provided.

    Returns:
        ``(original, mask, overlay)``:
        - ``original``: the orientation-corrected input.
        - ``mask``: the predicted sky scores scaled to an 8-bit grayscale
          image.
        - ``overlay``: the input with sky regions tinted blue.

        Returns ``(None, None, None)`` when ``image`` is None, which clears the
        output panes. On any other failure, the exception is logged with its
        traceback and three solid red 512x512 placeholders are returned so the
        UI stays responsive.
    """
    if image is None:
        return None, None, None
    try:
        # The startup load may have failed; retry lazily on first use.
        if model is None:
            load_model_once()
        image_tensor, original_image = preprocess_image(image, config['img_size'])
        predicted_mask = predict_sky_mask(image_tensor)
        mask_display = Image.fromarray((predicted_mask * 255).astype(np.uint8))
        overlay_display = Image.fromarray(create_overlay(original_image, predicted_mask))
        return original_image, mask_display, overlay_display
    except Exception:
        # Top-level UI boundary: keep the best-effort red placeholders, but do
        # not hide the cause any more.
        logging.getLogger(__name__).exception("Sky segmentation failed")
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, error_img, error_img
# Load the model at startup so the UI can report whether it is ready. A
# failure here is not fatal: segment_sky() calls load_model_once() again
# while the model is still missing.
try:
    load_model_once()
except Exception as e:
    logging.getLogger(__name__).exception("Model failed to load at startup")
    model_status = f"❌ Error loading model: {str(e)}"
else:
    model_status = "✅ Model loaded successfully!"
# Build the Gradio interface. Both the button and a fresh upload run
# segment_sky() on the input image.
with gr.Blocks(title="Sky Segmentation App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🌤️ Sky Segmentation App

Upload an image and get an AI-powered sky segmentation mask! This model identifies sky regions in your photos.

**How to use:**
1. Upload an image (JPG, PNG, etc.)
2. The model will automatically detect sky regions
3. View the original image, sky mask, and colored overlay
""")

    # Model status, computed once at startup.
    gr.Markdown(f"**Model Status:** {model_status}")

    with gr.Row():
        with gr.Column(scale=1):
            # Input
            input_image = gr.Image(
                label="📁 Upload Your Image",
                type="pil",
                height=400,
            )
            segment_btn = gr.Button("🔍 Segment Sky", variant="primary", size="lg")
        with gr.Column(scale=2):
            with gr.Row():
                original_output = gr.Image(label="📷 Original Image", height=300)
                mask_output = gr.Image(label="🎭 Sky Mask", height=300)
            overlay_output = gr.Image(label="🔵 Sky Overlay", height=300)

    # Info section
    gr.Markdown("""
### 📊 Understanding the Results:
- **Original Image**: Your uploaded image
- **Sky Mask**: Grayscale sky-probability map where white = sky, black = not sky
- **Sky Overlay**: Original image with sky regions highlighted in blue

### ℹ️ About the Model:
This model uses a U-Net architecture with a ResNet50 encoder, trained specifically for sky segmentation tasks.
The model can handle various image orientations and lighting conditions.

### 🚀 Made with:
- PyTorch & Segmentation Models PyTorch
- Gradio for the interface
- Hugging Face Spaces for hosting
""")

    # Event handlers: the button and a new upload share the same callback.
    segment_btn.click(
        fn=segment_sky,
        inputs=[input_image],
        outputs=[original_output, mask_output, overlay_output],
    )
    input_image.upload(
        fn=segment_sky,
        inputs=[input_image],
        outputs=[original_output, mask_output, overlay_output],
    )
# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 without a public share link.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_kwargs)