# File size: 6,490 Bytes
import streamlit as st
import torch
from PIL import Image
import numpy as np
from inference import get_inference_model
import json
# Page config: browser-tab title and icon, plus a full-width ("wide") layout.
_PAGE_CONFIG = dict(
    page_title="🌌 Astronomy Image Classification",
    page_icon="🌌",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)
# Title
st.title("🌌 Astronomy Image Classification")
st.markdown("Classify astronomy images into 6 categories using ensemble of ResNet50 and DenseNet121 models")
# Sidebar
st.sidebar.title("📊 Model Info")
# Rendered as a bulleted list: consecutive Markdown lines without a blank line
# or a list marker are joined into a single paragraph, which previously ran
# every field together on one line.
st.sidebar.markdown("""
- **Models**: ResNet50 + DenseNet121 Ensemble
- **ResNet50 Accuracy**: 64.86%
- **DenseNet121 Accuracy**: 63.96%
- **Ensemble**: Higher accuracy than individual models
- **Classes**: 6 astronomy categories
- **Input Size**: 224x224 pixels
""")
# Load model
@st.cache_resource
def _load_model_cached():
    """Build the inference model and keep it in Streamlit's resource cache.

    Deliberately lets exceptions propagate: Streamlit does not cache a call that
    raises, so a failed load is retried on the next rerun instead of a ``None``
    result being cached for the lifetime of the server.
    """
    return get_inference_model()


def load_model():
    """Return the cached inference model, or ``None`` if it could not be loaded.

    Any error raised while building the model is shown on the page with
    ``st.error`` rather than aborting the script, so the caller can render its
    own fallback UI when ``None`` is returned.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # top-level UI boundary: report instead of crashing the page
        st.error(f"Error loading model: {e}")
        return None
# Main interface
# NOTE(review): several emoji in this section were mis-encoded in the source.
# 🟢 matches the per-model colour scale; 📊 / 🌫️ / 📈 are best-guess restorations.


def _confidence_badge(confidence):
    """Return the ``(emoji, label)`` traffic-light pair for a confidence score.

    Scores above 0.8 are high (green), above 0.6 medium (yellow), anything else
    low (red). Used for both the ensemble and the per-model results so the two
    displays always share the same cut-offs.
    """
    if confidence > 0.8:
        return "🟢", "High Confidence"
    if confidence > 0.6:
        return "🟡", "Medium Confidence"
    return "🔴", "Low Confidence"


def _progress_value(value):
    """Return *value* as a plain ``float`` clamped to ``[0.0, 1.0]`` for ``st.progress``.

    ``st.progress`` only accepts floats inside that range. Clamping absorbs
    rounding that lands a hair outside it. ``float()`` normalizes scalar types
    the model might return (e.g. NumPy/PyTorch scalars); that is an assumption,
    because ``model.predict`` is defined elsewhere.
    """
    return min(max(float(value), 0.0), 1.0)


def _open_uploaded_image(uploaded_file):
    """Decode *uploaded_file* into an RGB PIL image, or show an error and return ``None``."""
    try:
        # PNG uploads may be RGBA, palette or grayscale; converting guarantees a
        # 3-channel image (presumably what the ResNet50/DenseNet121 backbones
        # expect). convert() also forces the full decode, so truncated files
        # fail here instead of inside model.predict.
        return Image.open(uploaded_file).convert("RGB")
    except OSError as e:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error(f"❌ Could not read the uploaded file as an image: {e}")
        return None


def _render_results(result):
    """Render the ensemble prediction, per-model results and class probabilities.

    Reads ``result["predicted_class"]``, ``result["confidence"]``,
    ``result["probabilities"]`` (class name -> probability) and the optional
    ``result["individual_results"]`` (model name -> dict with the same
    ``predicted_class``/``confidence`` keys).
    """
    st.subheader("🎯 Ensemble Prediction Results")
    predicted_class = result["predicted_class"]
    confidence = float(result["confidence"])
    color, status = _confidence_badge(confidence)
    # Blank lines keep each field on its own line; single newlines would be
    # joined into one paragraph by Markdown.
    st.markdown(
        f"**{color} Predicted Class**: {predicted_class}\n\n"
        f"**Confidence**: {confidence:.3f}\n\n"
        f"**Status**: {status}"
    )
    st.progress(_progress_value(confidence))

    # Individual model results
    if "individual_results" in result:
        st.subheader("🔍 Individual Model Results")
        for model_name, model_result in result["individual_results"].items():
            model_confidence = float(model_result["confidence"])
            model_color, _ = _confidence_badge(model_confidence)
            st.write(
                f"**{model_name}**: {model_color} "
                f"{model_result['predicted_class']} ({model_confidence:.3f})"
            )

    # All probabilities, highest first, each with its own bar
    st.subheader("📊 All Class Probabilities")
    probabilities = result["probabilities"]
    for class_name, prob in sorted(probabilities.items(), key=lambda item: item[1], reverse=True):
        prob = float(prob)
        col_prob, col_bar = st.columns([2, 3])
        with col_prob:
            st.write(f"**{class_name}**")
        with col_bar:
            st.progress(_progress_value(prob))
            st.write(f"{prob:.3f}")


model = load_model()
if model is not None:
    # Upload image
    uploaded_file = st.file_uploader(
        "Upload an astronomy image",
        type=['jpg', 'jpeg', 'png'],
        help="Upload an image of constellation, cosmos, galaxies, nebula, planets, or stars"
    )
    if uploaded_file is not None:
        image = _open_uploaded_image(uploaded_file)
        if image is not None:
            col1, col2 = st.columns([1, 1])
            with col1:
                st.image(image, caption="Uploaded Image", use_column_width=True)
            with col2:
                try:
                    with st.spinner("Analyzing image with ensemble models..."):
                        result = model.predict(image)
                except Exception as e:  # UI boundary: report and keep rendering the rest of the page
                    st.error(f"❌ Prediction failed: {e}")
                else:
                    _render_results(result)

    # Sample images section
    st.markdown("---")
    st.subheader("📸 Sample Images")
    sample_categories = [
        ("**🌟 Constellation**", "Star patterns forming recognizable shapes like Orion, Big Dipper, etc."),
        ("**🌌 Galaxies**", "Spiral, elliptical, or irregular galaxies like Andromeda, Milky Way"),
        ("**🌫️ Nebula**", "Gas clouds and stellar nurseries like Orion Nebula, Eagle Nebula"),
        ("**🪐 Planets**", "Solar system planets like Jupiter, Saturn, Mars, Earth"),
        ("**⭐ Stars**", "Individual stars, stellar objects, and stellar phenomena"),
        ("**🌠 Cosmos**", "General space scenes, cosmic phenomena, and deep space"),
    ]
    # Lay the categories out in rows of three columns.
    for row_start in range(0, len(sample_categories), 3):
        row = sample_categories[row_start:row_start + 3]
        for column, (title, description) in zip(st.columns(3), row):
            with column:
                st.markdown(title)
                st.info(description)

    # Model comparison
    st.markdown("---")
    st.subheader("📈 Model Performance Comparison")
    perf_col1, perf_col2 = st.columns(2)
    with perf_col1:
        st.metric("ResNet50 Accuracy", "64.86%", "Base Model")
    with perf_col2:
        st.metric("DenseNet121 Accuracy", "63.96%", "Base Model")
    st.info("🎯 **Ensemble Method**: Combines both models for higher accuracy than individual models")
else:
    st.error("❌ Model could not be loaded. Please check the model files.")
    st.markdown("""
**Required files:**
- `best_resnet50.pth` (ResNet50 model weights)
- `best_densenet121.pth` (DenseNet121 model weights)
""")
# Footer
# NOTE(review): the emoji in the first and last footer lines were mis-encoded
# in the source; 🔭 and 🚀 are best-guess restorations.
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
<p>🔭 Astronomy Image Classification System | Built with PyTorch & Streamlit</p>
<p>Ensemble of ResNet50 + DenseNet121 | Target Accuracy: >95% | Current: 64.86%</p>
<p>🚀 Deployed on Hugging Face Spaces</p>
</div>
""", unsafe_allow_html=True)