File size: 6,490 Bytes
7992750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import streamlit as st
import torch
from PIL import Image
import numpy as np
from inference import get_inference_model
import json

# Page config: browser-tab title/icon and a wide layout. The same title string
# is reused as the on-page heading below.
_PAGE_TITLE = "🌌 Astronomy Image Classification"
st.set_page_config(page_title=_PAGE_TITLE, page_icon="🌌", layout="wide")

# Title and one-line description shown at the top of the page.
st.title(_PAGE_TITLE)
st.markdown(
    "Classify astronomy images into 6 categories using ensemble of ResNet50 "
    "and DenseNet121 models"
)

# Sidebar: static summary of the models. The accuracy figures are hard-coded
# here (and repeated elsewhere in this file), so keep them in sync by hand.
_MODEL_INFO_MD = (
    "\n\n"
    "**Models**: ResNet50 + DenseNet121 Ensemble  \n\n"
    "**ResNet50 Accuracy**: 64.86%  \n\n"
    "**DenseNet121 Accuracy**: 63.96%  \n\n"
    "**Ensemble**: Higher accuracy than individual models  \n\n"
    "**Classes**: 6 astronomy categories  \n\n"
    "**Input Size**: 224x224 pixels\n\n"
)
_sidebar = st.sidebar
_sidebar.title("📊 Model Info")
_sidebar.markdown(_MODEL_INFO_MD)

# Load model
@st.cache_resource
def _load_cached_model():
    """Build the ensemble inference model once and share it across reruns.

    Errors are deliberately allowed to propagate. ``st.cache_resource`` does
    not store a raised exception, so a failed load is retried on the next
    rerun. Catching here and returning ``None`` would memoise the failure
    until the server restarts.
    """
    return get_inference_model()


def load_model():
    """Return the shared inference model, or ``None`` if it cannot be loaded.

    Any exception raised while building the model is shown in the app with
    ``st.error`` and converted into a ``None`` return. Callers use that
    ``None`` to render the "model could not be loaded" view.
    """
    try:
        return _load_cached_model()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Error loading model: {e}")
        return None

# Main interface

# Confidence cut-offs shared by the ensemble badge and the per-model badges.
HIGH_CONFIDENCE_THRESHOLD = 0.8
MEDIUM_CONFIDENCE_THRESHOLD = 0.6

# (title, description) cards for the "Sample Images" section, rendered in this
# order, SAMPLE_COLUMNS_PER_ROW cards per row.
SAMPLE_CATEGORIES = [
    ("🌟 Constellation", "Star patterns forming recognizable shapes like Orion, Big Dipper, etc."),
    ("🌌 Galaxies", "Spiral, elliptical, or irregular galaxies like Andromeda, Milky Way"),
    ("☁️ Nebula", "Gas clouds and stellar nurseries like Orion Nebula, Eagle Nebula"),
    ("🪐 Planets", "Solar system planets like Jupiter, Saturn, Mars, Earth"),
    ("⭐ Stars", "Individual stars, stellar objects, and stellar phenomena"),
    ("🌠 Cosmos", "General space scenes, cosmic phenomena, and deep space"),
]
SAMPLE_COLUMNS_PER_ROW = 3


def _confidence_badge(confidence):
    """Return ``(emoji, label)`` describing a confidence score.

    Scores strictly above HIGH_CONFIDENCE_THRESHOLD are green/"High".
    Scores strictly above MEDIUM_CONFIDENCE_THRESHOLD are yellow/"Medium".
    Everything else is red/"Low".
    """
    if confidence > HIGH_CONFIDENCE_THRESHOLD:
        return "🟢", "High Confidence"
    if confidence > MEDIUM_CONFIDENCE_THRESHOLD:
        return "🟡", "Medium Confidence"
    return "🔴", "Low Confidence"


def _progress_value(score):
    """Coerce *score* to a builtin float clamped to [0.0, 1.0] for ``st.progress``.

    ``st.progress`` rejects floats outside [0.0, 1.0] and numeric types that
    are not builtin ``int``/``float`` (e.g. ``numpy.float32``). Model scores
    can overshoot 1.0 by a rounding error, so convert and clamp defensively.
    """
    return min(max(float(score), 0.0), 1.0)


def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file with PIL, or report the error and return ``None``.

    ``Image.open`` only parses the header, so ``load()`` forces a full decode
    here. Corrupt or truncated uploads then surface as ``OSError`` (which
    ``PIL.UnidentifiedImageError`` subclasses) at this point, rather than
    crashing the script later during display or prediction.
    """
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError as exc:
        st.error(f"Could not read the uploaded image: {exc}")
        return None
    return image


model = load_model()

if model is not None:
    # Upload image
    uploaded_file = st.file_uploader(
        "Upload an astronomy image",
        type=['jpg', 'jpeg', 'png'],
        help="Upload an image of constellation, cosmos, galaxies, nebula, planets, or stars"
    )

    if uploaded_file is not None:
        col1, col2 = st.columns([1, 1])

        with col1:
            image = _open_uploaded_image(uploaded_file)
            if image is not None:
                st.image(image, caption="Uploaded Image", use_column_width=True)

        if image is not None:
            with col2:
                with st.spinner("Analyzing image with ensemble models..."):
                    # PNG uploads may be RGBA, palette or grayscale. The CNN
                    # backbones presumably expect 3-channel input (TODO confirm
                    # against inference.py). Converting an RGB image is a plain copy.
                    result = model.predict(image.convert("RGB"))

                st.subheader("🎯 Ensemble Prediction Results")

                # Main prediction
                predicted_class = result["predicted_class"]
                confidence = float(result["confidence"])
                color, status = _confidence_badge(confidence)

                st.markdown(
                    f"**{color} Predicted Class**: {predicted_class}\n\n"
                    f"**Confidence**: {confidence:.3f}\n\n"
                    f"**Status**: {status}"
                )
                st.progress(_progress_value(confidence))

                # Individual model results (optional in the result payload)
                individual_results = result.get("individual_results")
                if individual_results:
                    st.subheader("🔍 Individual Model Results")
                    for model_name, model_result in individual_results.items():
                        model_confidence = float(model_result["confidence"])
                        model_prediction = model_result["predicted_class"]
                        model_color, _ = _confidence_badge(model_confidence)
                        st.write(
                            f"**{model_name}**: {model_color} {model_prediction} "
                            f"({model_confidence:.3f})"
                        )

                # All probabilities, highest first, each with its own bar
                st.subheader("📊 All Class Probabilities")
                ranked = sorted(
                    result["probabilities"].items(),
                    key=lambda item: item[1],
                    reverse=True,
                )
                for class_name, prob in ranked:
                    col_prob, col_bar = st.columns([2, 3])
                    with col_prob:
                        st.write(f"**{class_name}**")
                    with col_bar:
                        st.progress(_progress_value(prob))
                        st.write(f"{prob:.3f}")

    # Sample images section
    st.markdown("---")
    st.subheader("📸 Sample Images")

    for row_start in range(0, len(SAMPLE_CATEGORIES), SAMPLE_COLUMNS_PER_ROW):
        row = SAMPLE_CATEGORIES[row_start:row_start + SAMPLE_COLUMNS_PER_ROW]
        for column, (title, description) in zip(st.columns(SAMPLE_COLUMNS_PER_ROW), row):
            with column:
                st.markdown(f"**{title}**")
                st.info(description)

    # Model comparison (hard-coded figures; keep in sync with the sidebar)
    st.markdown("---")
    st.subheader("📈 Model Performance Comparison")

    perf_col1, perf_col2 = st.columns(2)
    with perf_col1:
        st.metric("ResNet50 Accuracy", "64.86%", "Base Model")
    with perf_col2:
        st.metric("DenseNet121 Accuracy", "63.96%", "Base Model")

    st.info("🎯 **Ensemble Method**: Combines both models for higher accuracy than individual models")

else:
    st.error("❌ Model could not be loaded. Please check the model files.")
    st.markdown(
        "**Required files:**\n\n"
        "- `best_resnet50.pth` (ResNet50 model weights)\n\n"
        "- `best_densenet121.pth` (DenseNet121 model weights)"
    )

# Footer
st.markdown("---")
# Keep this HTML flush-left with no blank lines. In Markdown a blank line ends
# a raw-HTML block, and a following line indented 4+ spaces becomes an
# indented code block. That would print the <p> tags literally instead of
# rendering centred text.
st.markdown(
    """
<div style='text-align: center'>
<p>🔭 Astronomy Image Classification System | Built with PyTorch & Streamlit</p>
<p>Ensemble of ResNet50 + DenseNet121 | Target Accuracy: >95% | Current: 64.86%</p>
<p>🤗 Deployed on Hugging Face Spaces</p>
</div>
""",
    unsafe_allow_html=True,
)