AQIMultiModal / src /streamlit_app.py
rocky250's picture
Update src/streamlit_app.py
ac9a668 verified
import streamlit as st
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
import joblib
from datetime import datetime
from model_arch import MultiModalNet
st.set_page_config(layout="wide", page_title="Multi-Modal AQI Analysis", initial_sidebar_state="expanded")
# Custom CSS for the prediction card and metric boxes.
# The result card's background is hard-coded to white, so its text colour is
# pinned as well: without it, text that inherits the theme colour (e.g. the
# advice paragraph rendered inside the card) turns near-white under
# Streamlit's dark theme and becomes unreadable against the white card.
st.markdown("""
<style>
.result-card {
background-color: #ffffff;
color: #31333F;
padding: 20px;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
text-align: center;
margin-bottom: 20px;
border: 2px solid #e0e0e0;
}
.status-text {
font-size: 24px;
font-weight: bold;
margin: 10px 0;
}
.confidence-text {
font-size: 14px;
color: #666;
}
/* Metric styling */
div[data-testid="stMetric"] {
background-color: #f8f9fa;
padding: 10px;
border-radius: 8px;
border: 1px solid #dee2e6;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def _build_resources(repo_id):
    """Download the checkpoint and scaler from ``repo_id`` and build the model.

    This helper deliberately lets every error propagate. ``st.cache_resource``
    only stores values that are successfully returned, so a failed attempt
    (e.g. a transient download error) is retried on the next rerun instead
    of being cached as a permanent failure.

    Returns:
        ``(model, scaler, class_map)`` where ``model`` is a ``MultiModalNet``
        in eval mode on CPU, ``scaler`` is whatever ``joblib.load`` returns
        for ``scaler.joblib``, and ``class_map`` is the checkpoint's
        ``'class_map'`` entry or a default index -> label mapping.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")
    scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.joblib")
    # NOTE(review): torch.load (depending on the installed torch's
    # weights_only default) and joblib.load may unpickle arbitrary objects;
    # only point repo_id at a repository you trust.
    checkpoint = torch.load(model_path, map_location='cpu')
    scaler = joblib.load(scaler_path)
    model = MultiModalNet(num_classes=6, num_tab_features=9)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    class_map = checkpoint.get('class_map', {
        0: 'Good', 1: 'Moderate', 2: 'Unhealthy for Sensitive Groups',
        3: 'Unhealthy', 4: 'Very Unhealthy', 5: 'Hazardous'
    })
    return model, scaler, class_map


def load_resources(repo_id="rocky250/aqi-multimodal"):
    """Return the model, feature scaler and class map, loading them once.

    Successful loads are cached by ``_build_resources``. Failures are not
    cached, so they are retried on the next Streamlit rerun.

    Args:
        repo_id: Hugging Face Hub repository holding ``model.pth`` and
            ``scaler.joblib``.

    Returns:
        ``(model, scaler, class_map)`` on success. On any error, an error
        message is shown in the app and ``(None, None, None)`` is returned.
    """
    try:
        return _build_resources(repo_id)
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.error(f"Error loading resources: {e}")
        return None, None, None
def get_cyclic_features(value, max_val):
    """Encode a periodic quantity as a ``(sine, cosine)`` pair.

    ``value`` is placed on the unit circle at angle
    ``2 * pi * value / max_val``. Values one full period apart therefore map
    to the same pair, up to floating-point rounding.

    Args:
        value: Position within the cycle (e.g. an hour, month or day).
        max_val: Length of one full cycle.

    Returns:
        Tuple ``(sin(angle), cos(angle))`` as NumPy values.
    """
    angle = 2 * np.pi * value / max_val
    return np.sin(angle), np.cos(angle)
def get_color_and_advice(prediction):
    """Look up the display colour and health advice for an AQI category.

    Args:
        prediction: AQI category label such as ``'Moderate'``.

    Returns:
        ``(hex_colour, advice_text)``. Labels not in the table get a grey
        colour and a generic "no advice" message.
    """
    palette = {
        'Good': ('#2ecc71', "Air quality is satisfactory."),
        'Moderate': ('#f1c40f', "Air quality is acceptable."),
        'Unhealthy for Sensitive Groups': ('#e67e22', "Sensitive groups should reduce outdoor exertion."),
        'Unhealthy': ('#e74c3c', "Avoid prolonged outdoor exertion."),
        'Very Unhealthy': ('#8e44ad', "Health alert: avoid all outdoor activities."),
        'Hazardous': ('#2c3e50', "Emergency conditions: remain indoors."),
    }
    fallback = ('#95a5a6', "No advice available")
    if prediction in palette:
        return palette[prediction]
    return fallback
def _index_to_label(class_map):
    """Return a dict mapping class index -> label for the checkpoint's class_map.

    Two layouts are accepted. If the first key is an ``int``, the map is
    taken to be index -> label and is used as-is. Otherwise it is taken to
    be label -> index and is inverted.
    """
    first_key = next(iter(class_map))
    if isinstance(first_key, int):
        return dict(class_map)
    return {idx: label for label, idx in class_map.items()}


st.sidebar.header("Sensor Inputs")
st.sidebar.markdown("Update these values to analyze real-time scenarios.")
aqi_input = st.sidebar.number_input("AQI Value", min_value=0.0, value=120.0)
pm25_input = st.sidebar.number_input("PM2.5", min_value=0.0, value=35.0)
pm10_input = st.sidebar.number_input("PM10", min_value=0.0, value=75.0)
st.sidebar.markdown("---")
# Time sliders default to the current local time of the machine running the app.
now = datetime.now()
month = st.sidebar.slider("Month", 1, 12, now.month)
day = st.sidebar.slider("Day", 1, 31, now.day)
hour = st.sidebar.slider("Hour", 0, 23, now.hour)

st.title("Multi-Modal AQI Classification")
st.markdown("Integrates visual data (Sky/Horizon) with sensor metrics for precise air quality detection.")
col_left, col_right = st.columns([1, 1.3], gap="large")

# Decoded upload. Stays None when nothing is uploaded or the file cannot be decoded.
img_display = None
with col_left:
    st.subheader("1. Visual Input")
    uploaded_file = st.file_uploader("Upload Image", type=['jpg', 'jpeg', 'png'])
    if uploaded_file:
        try:
            # The uploader only filters by file extension, so a corrupt or
            # mislabelled file has to be handled rather than crash the run.
            # Image.copy() forces the pixel data to load, surfacing decode errors here.
            img_display = Image.open(uploaded_file)
            display_copy = img_display.copy()
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            img_display = None
            st.error(f"Could not read the uploaded image: {e}")
        else:
            display_copy.thumbnail((400, 400))
            st.image(display_copy, caption="Input Scene", use_container_width=True)
    else:
        st.info("Please upload an image to start.")

with col_right:
    st.subheader("2. Analysis Results")
    if img_display is not None:
        try:
            model, scaler, class_map = load_resources()
            if model is not None:
                # Image branch: RGB -> 224x224 -> CHW float tensor in [0, 1],
                # then normalised with the ImageNet per-channel mean/std.
                img = np.array(img_display.convert('RGB'))
                img = cv2.resize(img, (224, 224))
                img_tensor = torch.tensor(img).permute(2, 0, 1).float().unsqueeze(0) / 255.0
                mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
                img_tensor = (img_tensor - mean) / std

                # Tabular branch: 3 sensor readings + sin/cos encodings of
                # hour, month and day (9 features), then the fitted scaler.
                h_sin, h_cos = get_cyclic_features(hour, 24.0)
                m_sin, m_cos = get_cyclic_features(month, 12.0)
                d_sin, d_cos = get_cyclic_features(day, 31.0)
                raw_features = np.array([[aqi_input, pm25_input, pm10_input, h_sin, h_cos, m_sin, m_cos, d_sin, d_cos]])
                scaled_features = scaler.transform(raw_features)
                tab_tensor = torch.tensor(scaled_features).float()

                with torch.no_grad():
                    logits, fusion_weights = model(img_tensor, tab_tensor)
                probs = F.softmax(logits, dim=1)
                conf, pred_idx = torch.max(probs, dim=1)

                idx_to_label = _index_to_label(class_map)
                prediction = idx_to_label[pred_idx.item()]
                color, advice = get_color_and_advice(prediction)
                st.markdown(f"""
<div class="result-card" style="border-color: {color};">
<h3 style="color: {color}; margin-bottom:0;">{prediction}</h3>
<p class="confidence-text">Confidence: {conf.item()*100:.1f}%</p>
<hr style="margin: 10px 0;">
<p style="font-style: italic;">"{advice}"</p>
</div>
""", unsafe_allow_html=True)

                tab1, tab2 = st.tabs(["Contribution", "Probabilities"])
                with tab1:
                    with st.container(height=300):
                        fw = fusion_weights[0].cpu().numpy()
                        fig, ax = plt.subplots(figsize=(13.5, 2.4))
                        try:
                            labels = ['Tabular', 'Image']
                            # NOTE(review): this ordering assumes fw[0] is the image
                            # weight and fw[1] the tabular weight, as in the original
                            # code; confirm against MultiModalNet.
                            values = [fw[1], fw[0]]
                            ax.barh(labels, values)
                            ax.set_xlim(0, 1)
                            ax.set_xlabel("Fusion Weight")
                            for i, v in enumerate(values):
                                ax.text(v + 0.02, i, f"{v:.2f}", va='center', fontweight='bold')
                            ax.spines['top'].set_visible(False)
                            ax.spines['right'].set_visible(False)
                            st.pyplot(fig, clear_figure=True)
                        finally:
                            # clear_figure only empties the figure; closing it also
                            # removes it from pyplot's registry so reruns don't leak.
                            plt.close(fig)
                with tab2:
                    with st.container(height=300):
                        p_vals = probs[0].cpu().numpy()
                        p_labels = [idx_to_label[i] for i in range(len(p_vals))]
                        fig2, ax2 = plt.subplots(figsize=(11.5, 1.4))
                        try:
                            ax2.bar(p_labels, p_vals)
                            ax2.set_ylim(0, 1)
                            ax2.set_ylabel("Probability")
                            ax2.tick_params(axis='x', labelrotation=35)
                            ax2.spines['top'].set_visible(False)
                            ax2.spines['right'].set_visible(False)
                            st.pyplot(fig2, clear_figure=True)
                        finally:
                            plt.close(fig2)
        except Exception as e:  # UI boundary: surface any inference failure
            st.error(f"Analysis failed: {e}")