File size: 8,246 Bytes
c3da629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import streamlit as st
import sys
import os
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import io
import cv2
from config import DATA_CONFIG, CLASS_NAMES

# Register the folder that holds this script on the module search path
# (presumably so the project-local imports that follow resolve no matter
# which working directory the app is launched from).
current_dir = os.path.split(os.path.abspath(__file__))[0]
if sys.path.count(current_dir) == 0:
    sys.path.append(current_dir)

from model_handler import ModelHandler
from data_processor import DataProcessor
from visualizer import Visualizer

class WebApp:
    """Streamlit front end for the land cover classification demo.

    The class owns a ``ModelHandler``, a ``DataProcessor`` and a
    ``Visualizer`` and exposes one method per page of the app
    (``main_page``, ``charts_page``, ``classes_page``).
    """

    def __init__(self):
        """Create the helper objects and attempt to load the model artifacts.

        A falsy result from ``load_model`` or ``load_class_indices`` is
        reported through ``st.error``; construction still completes so the
        pages that do not need the model remain usable.
        """
        self.model_handler = ModelHandler()
        self.data_processor = DataProcessor()
        self.visualizer = Visualizer()

        # Load model and class indices
        if not self.model_handler.load_model():
            st.error("Failed to load model. Please check if models/ResNet50_eurosat.h5 exists.")
        if not self.model_handler.load_class_indices():
            st.error("Failed to load class indices. Please check if models/class_indices.npy exists.")

    @staticmethod
    def _uploaded_bytes(uploaded_file):
        """Return the full content of an uploaded file, or ``None`` if absent.

        ``getvalue()`` is used instead of ``read()`` because it ignores the
        current stream position: a buffer that was already consumed (for
        example on an earlier rerun) still yields its complete content
        rather than an empty byte string.
        """
        if uploaded_file is None:
            return None
        return uploaded_file.getvalue()

    @staticmethod
    def _class_labels(class_indices, count):
        """Build ``count`` display labels, one per prediction index.

        Each index ``i`` is looked up as ``str(i)`` in ``class_indices``;
        a missing entry (or a missing mapping altogether) falls back to
        ``"Class_<i>"``.
        """
        if class_indices is None:
            class_indices = {}
        return [class_indices.get(str(i), f"Class_{i}") for i in range(count)]

    def main_page(self):
        """Render the home page: banner, uploader, preview and classify button."""
        st.title("Project DEPI - Land Cover Classification")
        st.caption("Welcome to the Land Cover Classification System")

        # Display welcome image
        # NOTE(review): `use_column_width` is reported as deprecated on recent
        # Streamlit releases -- confirm the pinned version before renaming it.
        st.image("assets/satellite.jpg", use_column_width=True)

        # Image upload section
        uploaded_image = self.upload()

        if uploaded_image is not None:
            # Store the uploaded image in session state so the charts page
            # can analyse the same bytes without a second upload.
            st.session_state.uploaded_image = uploaded_image

            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Image Preview")
                st.image(uploaded_image, use_column_width=True)
            with col2:
                st.subheader("Classification")
                if st.button('Run Classification'):
                    self.run_classification(uploaded_image)

    def upload(self):
        """Show the file uploader and return the chosen file's raw bytes.

        Returns:
            The uploaded file content as ``bytes``, or ``None`` when no file
            has been selected.
        """
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=DATA_CONFIG['allowed_formats']
        )
        return self._uploaded_bytes(uploaded_file)

    def run_classification(self, image_data):
        """Classify ``image_data`` and display the result with a confidence chart.

        Args:
            image_data: Raw image bytes handed to ``ModelHandler.predict``.

        Any exception raised while predicting or plotting is caught and
        shown with ``st.error`` instead of propagating.
        """
        try:
            with st.spinner('Running classification...'):
                prediction = self.model_handler.predict(image_data)
                st.success(f"Classification Result: {prediction['class_name']}")
                st.write(f"Confidence: {prediction['confidence']:.2%}")
                # Use visualizer for confidence bar
                class_names = self._class_labels(
                    self.model_handler.class_indices,
                    len(prediction['all_predictions']),
                )
                fig = self.visualizer.plot_confidence_bar(class_names, prediction['all_predictions'])
                # `st.plotly_chart` stretches to its container through
                # `use_container_width`; `use_column_width` is an `st.image`
                # keyword and is not a parameter of this call.
                st.plotly_chart(fig, use_container_width=True)
        except Exception as e:
            st.error(f"Error during classification: {str(e)}")

    def charts_page(self):
        """Display various charts and visualizations.

        The first tab shows a static training-performance image with
        explanatory text; the second analyses the image stored in
        ``st.session_state.uploaded_image`` (if any).
        """
        st.title("Charts and Visualizations")

        tab1, tab2 = st.tabs(["Model evaluation", "Image Analysis"])

        with tab1:
            st.subheader("Model Performance")
            st.title("Model Training Results")
            st.image("assets/model_performance.jpg", caption="Training Progress Over Time", use_column_width=True)
            with st.expander("Accuracy Analysis"):
                st.markdown("""

                - **Training Accuracy**: Shows how well the model learns from training data

                - **Validation Accuracy**: Indicates real-world performance on unseen data

                - **Ideal Scenario**: Both metrics should increase and stabilize at similar values

                """)
            with st.expander("Loss Analysis"):
                st.markdown("""

                - **Training Loss**: Measures error reduction during training

                - **Validation Loss**: Tracks generalization error

                - **Healthy Pattern**: Both should decrease steadily without significant divergence

                """)
            st.header("What This Means")
            st.write("""

            The model is learning properly without overfitting.

            Both accuracy and loss show good progress.

            You could stop training earlier when it stops improving.

            """)
        with tab2:
            st.subheader("Image Analysis")
            if 'uploaded_image' in st.session_state:
                uploaded_image = st.session_state.uploaded_image
                st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
                try:
                    # NOTE(review): the array keeps the file's native mode
                    # (it is not converted to RGB here) -- confirm the
                    # visualizer copes with RGBA or grayscale input.
                    img = Image.open(io.BytesIO(uploaded_image))
                    img_array = np.array(img)
                    # Use visualizer for RGB histograms; the first three
                    # returned figures are laid out side by side.
                    st.subheader("RGB Color Histograms")
                    figs = self.visualizer.plot_rgb_histograms(img_array)
                    col1, col2, col3 = st.columns(3)
                    for i, col in enumerate([col1, col2, col3]):
                        with col:
                            st.pyplot(figs[i])
                    # Visualization selector
                    analysis_type = st.selectbox(
                        "Select Analysis Type",
                        ["Image Statistics", "Edge Detection", "Intensity Map"]
                    )
                    if analysis_type == "Image Statistics":
                        stats = self.visualizer.image_statistics(img_array)
                        for key, value in stats.items():
                            st.write(f"**{key}:** {value}")
                    elif analysis_type == "Edge Detection":
                        edges = self.visualizer.edge_detection(img_array)
                        st.image(edges, caption="Edge Detection", use_column_width=True)
                    elif analysis_type == "Intensity Map":
                        fig = self.visualizer.intensity_map(img_array)
                        st.plotly_chart(fig)
                except Exception as e:
                    st.error(f"Error processing image: {str(e)}")
            else:
                st.warning("Please upload an image in the main page first.")

    def classes_page(self):
        """Display detailed information about each class.

        Every class gets an expander holding its description; the expanders
        alternate between two columns.
        """
        st.title("Land Cover Classes")

        # List of (class_name, description)
        class_info = [
            ("AnnualCrop", "Agricultural areas where crops are planted and harvested within a single year."),
            ("Forest", "Areas dominated by trees, forming a continuous canopy."),
            ("HerbaceousVegetation", "Areas covered by non-woody plants and grasses."),
            ("Highway", "Major roads and transportation infrastructure."),
            ("Industrial", "Areas containing factories, warehouses, and industrial facilities."),
            ("Pasture", "Land used for grazing livestock."),
            ("PermanentCrop", "Agricultural areas with long-term crops like orchards and vineyards."),
            ("Residential", "Areas containing houses and residential buildings."),
            ("River", "Natural watercourses and their immediate surroundings."),
            ("SeaLake", "Large bodies of water including seas and lakes.")
        ]

        cols = st.columns(2)
        for idx, (name, desc) in enumerate(class_info):
            # Even indices go to the left column, odd ones to the right.
            with cols[idx % 2]:
                with st.expander(f"{name}"):
                    st.write(desc)

# Script entry point: build the app, then draw whichever page the sidebar selects.
if __name__ == "__main__":
    app = WebApp()

    # Sidebar label -> bound method that renders the matching page.
    # Insertion order of this mapping is the order of the radio options.
    page_renderers = {
        "Home": app.main_page,
        "Charts": app.charts_page,
        "Classes": app.classes_page,
    }

    # Navigation
    st.sidebar.title("Navigation")
    page = st.sidebar.radio("Go to", list(page_renderers))

    render_page = page_renderers.get(page)
    if render_page is not None:
        render_page()