# Spaces: Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
# EuroSAT land-cover class names: the 10 classes, listed alphabetically.
EUROSAT_CLASSES = (
    "AnnualCrop Forest HerbaceousVegetation Highway Industrial "
    "Pasture PermanentCrop Residential River SeaLake"
).split()
# Human-readable, emoji-prefixed description of each EuroSAT class, shown in the
# prediction summary. Keys must match the names in EUROSAT_CLASSES.
CLASS_DESCRIPTIONS = {
    "AnnualCrop": "🌾 Agricultural land with annual crops",
    "Forest": "🌲 Dense forest areas with trees",
    "HerbaceousVegetation": "🌿 Areas with herbaceous vegetation",
    "Highway": "🛣️ Major roads and highway infrastructure",
    "Industrial": "🏭 Industrial areas and facilities",
    "Pasture": "🐄 Pasture land for livestock",
    "PermanentCrop": "🍇 Permanent crop areas (vineyards, orchards)",
    "Residential": "🏘️ Residential areas and neighborhoods",
    "River": "🏞️ Rivers and waterways",
    "SeaLake": "🏖️ Seas, lakes, and large water bodies",
}
class EuroSATClassifier:
    """Land-cover classifier for satellite/aerial images (EuroSAT, 10 classes).

    Wraps a Hugging Face image-classification checkpoint (image processor +
    model) and turns its output into what the Gradio UI displays: a
    class -> probability mapping, a Plotly bar chart and a Markdown summary.
    """

    # Checkpoint loaded when ``model_name`` cannot be. It is a general-purpose
    # ImageNet classifier, not a EuroSAT model, so its scores are not
    # land-cover predictions; ``using_fallback`` records when it is in use.
    FALLBACK_MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

    def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
        """Store ``model_name`` and load its processor and model immediately.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.
        """
        self.model_name = model_name
        self.processor = None
        self.model = None
        # Set by load_model(); True when FALLBACK_MODEL_NAME had to be used.
        self.using_fallback = False
        self.load_model()

    def load_model(self):
        """Load the processor and model for ``self.model_name``.

        On any failure, loads FALLBACK_MODEL_NAME instead and sets
        ``self.using_fallback`` so predictions can be flagged as unreliable.
        Errors raised while loading the fallback itself propagate to the caller.
        """
        try:
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForImageClassification.from_pretrained(self.model_name)
            self.using_fallback = False
            print(f"✅ Model {self.model_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print(
                f"⚠️ Falling back to {self.FALLBACK_MODEL_NAME}; "
                "its outputs are NOT EuroSAT land-cover classes."
            )
            self.processor = AutoImageProcessor.from_pretrained(self.FALLBACK_MODEL_NAME)
            self.model = AutoModelForImageClassification.from_pretrained(self.FALLBACK_MODEL_NAME)
            self.using_fallback = True
        # Inference only: put whichever model was loaded into evaluation mode.
        self.model.eval()

    def _class_labels(self, num_outputs):
        """Return the EuroSAT class name for each of the model's output indices.

        When the model config's ``id2label`` names exactly the EuroSAT classes,
        that order is used, so a checkpoint whose label order differs from
        EUROSAT_CLASSES is still mapped correctly. Otherwise output ``i`` is
        taken to be ``EUROSAT_CLASSES[i]``, and outputs beyond the 10 classes
        are ignored.
        """
        config = getattr(self.model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        labels = [id2label.get(i) for i in range(num_outputs)]
        if len(labels) == len(EUROSAT_CLASSES) and set(labels) == set(EUROSAT_CLASSES):
            return labels
        return EUROSAT_CLASSES[:num_outputs]

    def predict(self, image):
        """Classify ``image`` and build everything the UI displays.

        Args:
            image: The uploaded image (a PIL image from the Gradio component), or None.

        Returns:
            ``(results, figure, markdown)``. ``results`` maps every EuroSAT class
            to its probability, ordered from most to least likely; ``figure`` is
            the Plotly bar chart; ``markdown`` is the summary text. When no image
            is given or prediction fails, ``results`` and ``figure`` are None and
            ``markdown`` carries the message.
        """
        if image is None:
            return None, None, "Please upload an image first!"
        try:
            # Uploaded PNGs may be RGBA / grayscale / palette images; the
            # processor's per-channel normalisation expects 3 channels.
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Softmax over the class dimension, then take the single batch item.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

            # Every EuroSAT class gets an entry; classes the model has no output
            # for keep probability 0.0.
            results = dict.fromkeys(EUROSAT_CLASSES, 0.0)
            for class_name, prob in zip(self._class_labels(len(probabilities)), probabilities):
                results[class_name] = float(prob)
            sorted_results = dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))

            confidence_plot = self.create_confidence_plot(sorted_results)
            result_text = self._format_result_text(sorted_results)
            return sorted_results, confidence_plot, result_text
        except Exception as e:
            # UI boundary: report the failure in the results panel instead of crashing.
            return None, None, f"❌ Error during prediction: {str(e)}"

    def _format_result_text(self, sorted_results):
        """Build the Markdown summary (top class, confidence, description, top 3)."""
        ranked = list(sorted_results.items())
        top_class, top_confidence = ranked[0]
        parts = []
        if self.using_fallback:
            parts.append(
                "⚠️ **Warning: the EuroSAT model could not be loaded. These scores come "
                "from a generic fallback model and are not reliable land-cover predictions.**\n\n"
            )
        parts.append(f"🎯 **Prediction: {top_class}**\n\n")
        parts.append(f"📊 **Confidence: {top_confidence:.1%}**\n\n")
        description = CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')
        parts.append(f"📝 **Description: {description}**\n\n")
        parts.append("### Top 3 Predictions:\n")
        for rank, (class_name, confidence) in enumerate(ranked[:3], start=1):
            parts.append(f"{rank}. **{class_name}**: {confidence:.1%}\n")
        return "".join(parts)

    def create_confidence_plot(self, results):
        """Return a horizontal Plotly bar chart of per-class confidence.

        Args:
            results: Mapping class name -> probability in [0, 1], already ordered
                best-first. The first entry is drawn green, the rest blue, and
                the y-axis is reversed so the first entry appears at the top.
        """
        classes = list(results.keys())
        confidences = [results[cls] * 100 for cls in classes]
        # Highlight the top prediction (first entry) in green, others in blue.
        colors = ['#2E8B57' if i == 0 else '#4682B4' for i in range(len(classes))]
        fig = go.Figure(data=[
            go.Bar(
                x=confidences,
                y=classes,
                orientation='h',
                marker_color=colors,
                text=[f'{conf:.1f}%' for conf in confidences],
                textposition='inside',
                textfont=dict(color='white', size=12),
            )
        ])
        fig.update_layout(
            title="🎯 Classification Confidence Scores",
            xaxis_title="Confidence (%)",
            yaxis_title="Land Cover Classes",
            height=500,
            margin=dict(l=10, r=10, t=40, b=10),
            plot_bgcolor='white',
            paper_bgcolor='white',
            font=dict(family="Arial", size=12, color="#333"),
            xaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                range=[0, 100],
            ),
            yaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                autorange="reversed",
            ),
        )
        return fig
# Build one shared classifier at import time. The constructor loads the model,
# so every Gradio callback reuses it instead of reloading per request.
classifier = EuroSATClassifier(model_name="Adilbai/EuroSAT-Swin")
def classify_image(image):
    """Gradio callback: classify ``image`` with the shared module-level classifier.

    Returns the ``(results, figure, markdown)`` triple from
    ``EuroSATClassifier.predict`` unchanged, matching the event's three outputs.
    """
    prediction = classifier.predict(image)
    return prediction
def get_sample_images():
    """Return Markdown describing the kinds of satellite images worth trying.

    Despite the name, no images are returned, only descriptive text.
    """
    examples = [
        ("🌾 Agricultural fields", "Crop lands and farmland"),
        ("🌲 Forest areas", "Dense tree coverage"),
        ("🏘️ Residential zones", "Urban neighborhoods"),
        ("🏭 Industrial sites", "Factories and industrial areas"),
        ("🛣️ Highway systems", "Major roads and intersections"),
        ("💧 Water bodies", "Rivers, lakes, and seas"),
        ("🌿 Natural vegetation", "Grasslands and natural areas"),
    ]
    lines = ["", "### 🖼️ Try these types of satellite images:"]
    lines += [f"- **{title}** - {detail}" for title, detail in examples]
    # The blank line ends the list; without it Markdown renders the closing
    # sentence as a lazy continuation of the last bullet.
    lines += ["", "Upload a satellite/aerial image to see the land cover classification!", ""]
    return "\n".join(lines)
# Custom CSS for better styling.
# Panel backgrounds use the Gradio theme variable --background-fill-secondary
# instead of hard-coded near-black colours, so the theme's default text colour
# stays readable in both light and dark mode.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 2rem;
    text-align: center;
    background: var(--background-fill-secondary);
}
.result-text {
    background: var(--background-fill-secondary);
    padding: 1.5rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}
"""
# Create the Gradio interface
with gr.Blocks(css=custom_css, title="🛰️ EuroSAT Land Cover Classifier") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>🛰️ EuroSAT Land Cover Classifier</h1>
        <p>Advanced satellite image classification using Swin Transformer</p>
        <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Upload Satellite Image</h3>")
            image_input = gr.Image(
                label="Upload a satellite/aerial image",
                type="pil",
                height=400,
                elem_classes="upload-area",
            )
            classify_btn = gr.Button(
                "🔍 Classify Land Cover",
                variant="primary",
                size="lg",
            )
            # Each gr.HTML is its own component, so an opening <div> and a
            # closing </div> in two separate components cannot wrap the Markdown
            # between them. Use a self-contained spacer instead.
            gr.HTML("<div style='margin-top: 2rem;'></div>")
            gr.Markdown(get_sample_images())

        with gr.Column(scale=1):
            gr.HTML("<h3>📊 Classification Results</h3>")
            result_text = gr.Markdown(
                value="Upload an image and click 'Classify Land Cover' to see results!",
                elem_classes="result-text",
            )
            confidence_plot = gr.Plot(
                label="Confidence Scores",
            )
            # Hidden component holding the raw class -> probability mapping.
            raw_results = gr.JSON(visible=False)

    # Event handlers: classify on button click and whenever the image changes
    # (clearing the image yields predict()'s "Please upload" message).
    classification_outputs = [raw_results, confidence_plot, result_text]
    classify_btn.click(
        fn=classify_image,
        inputs=[image_input],
        outputs=classification_outputs,
    )
    image_input.change(
        fn=classify_image,
        inputs=[image_input],
        outputs=classification_outputs,
    )

    # Footer. The theme variable background keeps the default text readable in
    # both light and dark mode.
    gr.HTML("""
    <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: var(--background-fill-secondary); border-radius: 10px;">
        <h4>🔬 About This Model</h4>
        <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
        <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
        <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! 🌍</p>
        <br>
        <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank" rel="noopener noreferrer">Adilbai/EuroSAT-Swin</a></p>
    </div>
    """)
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    launch_kwargs = {
        # NOTE(review): share=True also requests a public gradio.live link when
        # run locally; confirm that public exposure is intended.
        "share": True,
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "show_error": True,  # surface Python errors in the browser UI
    }
    demo.launch(**launch_kwargs)