Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import joblib | |
| from datetime import datetime | |
| from model_arch import MultiModalNet | |
# Custom styling for the result card and st.metric widgets, injected once per run.
_PAGE_CSS = """
<style>
.result-card {
background-color: #ffffff;
padding: 20px;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
text-align: center;
margin-bottom: 20px;
border: 2px solid #e0e0e0;
}
.status-text {
font-size: 24px;
font-weight: bold;
margin: 10px 0;
}
.confidence-text {
font-size: 14px;
color: #666;
}
/* Metric styling */
div[data-testid="stMetric"] {
background-color: #f8f9fa;
padding: 10px;
border-radius: 8px;
border: 1px solid #dee2e6;
}
</style>
"""

# set_page_config must be the first Streamlit call of the script run.
st.set_page_config(
    layout="wide",
    page_title="Multi-Modal AQI Analysis",
    initial_sidebar_state="expanded",
)
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_resources_cached():
    """Download and build ``(model, scaler, class_map)`` once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching every slider move would re-download the checkpoint and
    rebuild the network. This helper deliberately lets exceptions propagate:
    ``st.cache_resource`` does not store a result when the function raises,
    so a transient failure (e.g. network) is retried on the next run instead
    of being cached as ``(None, None, None)``.
    """
    repo_id = "rocky250/aqi-multimodal"
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")
    scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.joblib")
    # NOTE(review): both torch.load and joblib.load unpickle the downloaded
    # files, so this trusts the contents of the Hub repo. Consider
    # torch.load(..., weights_only=True) if the installed torch supports it
    # and the checkpoint only holds tensors/plain containers.
    checkpoint = torch.load(model_path, map_location='cpu')
    scaler = joblib.load(scaler_path)
    model = MultiModalNet(num_classes=6, num_tab_features=9)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode (dropout/batch-norm behave deterministically)
    # Fall back to an {index: label} map when the checkpoint does not carry one.
    class_map = checkpoint.get('class_map', {
        0: 'Good', 1: 'Moderate', 2: 'Unhealthy for Sensitive Groups',
        3: 'Unhealthy', 4: 'Very Unhealthy', 5: 'Hazardous'
    })
    return model, scaler, class_map


def load_resources():
    """Return the cached ``(model, scaler, class_map)`` triple.

    Returns:
        ``(model, scaler, class_map)`` on success. On any failure an error
        message is shown in the app and ``(None, None, None)`` is returned.
    """
    try:
        return _load_resources_cached()
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.error(f"Error loading resources: {e}")
        return None, None, None
def get_cyclic_features(value, max_val):
    """Encode a periodic quantity as ``(sin, cos)`` of its phase.

    ``value / max_val`` is treated as a fraction of one full cycle, so values
    that are ``max_val`` apart map to the same point on the unit circle.
    """
    phase = 2 * np.pi * value / max_val
    return np.sin(phase), np.cos(phase)
# Display colour and health advice for each AQI category label.
_AQI_CATEGORY_STYLES = {
    'Good': ('#2ecc71', "Air quality is satisfactory."),
    'Moderate': ('#f1c40f', "Air quality is acceptable."),
    'Unhealthy for Sensitive Groups': ('#e67e22', "Sensitive groups should reduce outdoor exertion."),
    'Unhealthy': ('#e74c3c', "Avoid prolonged outdoor exertion."),
    'Very Unhealthy': ('#8e44ad', "Health alert: avoid all outdoor activities."),
    'Hazardous': ('#2c3e50', "Emergency conditions: remain indoors."),
}
_UNKNOWN_CATEGORY_STYLE = ('#95a5a6', "No advice available")


def get_color_and_advice(prediction):
    """Return ``(hex_colour, advice_text)`` for an AQI category label.

    Labels not present in the table get a neutral grey colour and a generic
    "no advice" message.
    """
    try:
        return _AQI_CATEGORY_STYLES[prediction]
    except KeyError:
        return _UNKNOWN_CATEGORY_STYLE
# Sidebar: sensor readings and timestamp consumed by the tabular model branch.
# Time sliders default to the server's current local time.
with st.sidebar:
    st.header("Sensor Inputs")
    st.markdown("Update these values to analyze real-time scenarios.")
    aqi_input = st.number_input("AQI Value", min_value=0.0, value=120.0)
    pm25_input = st.number_input("PM2.5", min_value=0.0, value=35.0)
    pm10_input = st.number_input("PM10", min_value=0.0, value=75.0)
    st.markdown("---")
    now = datetime.now()
    month = st.slider("Month", min_value=1, max_value=12, value=now.month)
    day = st.slider("Day", min_value=1, max_value=31, value=now.day)
    hour = st.slider("Hour", min_value=0, max_value=23, value=now.hour)
st.title("Multi-Modal AQI Classification")
st.markdown("Integrates visual data (Sky/Horizon) with sensor metrics for precise air quality detection.")
col_left, col_right = st.columns([1, 1.3], gap="large")

# Left column: image upload and preview. `uploaded_file` and `img_display`
# are read by the analysis column further down.
with col_left:
    st.subheader("1. Visual Input")
    uploaded_file = st.file_uploader("Upload Image", type=['jpg', 'jpeg', 'png'])
    if uploaded_file:
        try:
            img_display = Image.open(uploaded_file)
            # Image.open is lazy; copy() forces a full decode so a corrupt or
            # truncated upload fails here rather than later with a traceback.
            display_copy = img_display.copy()
        except OSError as err:
            # PIL.UnidentifiedImageError is a subclass of OSError.
            st.error(f"Could not read the uploaded image: {err}")
            uploaded_file = None  # makes the analysis column skip this upload
        else:
            display_copy.thumbnail((400, 400))
            st.image(display_copy, caption="Input Scene", use_container_width=True)
    else:
        st.info("Please upload an image to start.")
# Standard ImageNet per-channel mean/std, shaped (1, 3, 1, 1) to broadcast
# over a (1, 3, H, W) batch. Built once instead of on every rerun.
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def _image_to_tensor(pil_img):
    """Convert a PIL image to a normalised float tensor of shape (1, 3, 224, 224)."""
    rgb = cv2.resize(np.array(pil_img.convert('RGB')), (224, 224))
    tensor = torch.tensor(rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return (tensor - _NORM_MEAN) / _NORM_STD


def _tabular_tensor(scaler, aqi, pm25, pm10, hour_val, month_val, day_val):
    """Build the 9-feature row (3 readings + sin/cos of hour, month, day), scale it, return a (1, 9) tensor."""
    h_sin, h_cos = get_cyclic_features(hour_val, 24.0)
    m_sin, m_cos = get_cyclic_features(month_val, 12.0)
    d_sin, d_cos = get_cyclic_features(day_val, 31.0)
    raw = np.array([[aqi, pm25, pm10, h_sin, h_cos, m_sin, m_cos, d_sin, d_cos]])
    return torch.tensor(scaler.transform(raw)).float()


def _index_label_map(mapping):
    """Normalise a class map to ``{index: label}``.

    The map may be stored as ``{index: label}`` (the fallback used by
    load_resources) or as ``{label: index}``; the two are told apart by the
    type of the first key.
    """
    if isinstance(next(iter(mapping)), int):
        return dict(mapping)
    return {idx: label for label, idx in mapping.items()}


def _show_figure(fig):
    """Render a Matplotlib figure in Streamlit and release it from pyplot."""
    try:
        st.pyplot(fig, clear_figure=True)
    finally:
        # clear_figure only clears the figure's contents; pyplot still tracks
        # it, so without close() every rerun would leak open figures.
        plt.close(fig)


# Right column: run the model on the uploaded image + sidebar readings.
with col_right:
    st.subheader("2. Analysis Results")
    if uploaded_file is not None:
        try:
            model, scaler, class_map = load_resources()
            if model is not None:
                img_tensor = _image_to_tensor(img_display)
                tab_tensor = _tabular_tensor(
                    scaler, aqi_input, pm25_input, pm10_input, hour, month, day
                )
                with torch.no_grad():
                    logits, fusion_weights = model(img_tensor, tab_tensor)
                    probs = F.softmax(logits, dim=1)
                    conf, pred_idx = torch.max(probs, dim=1)
                idx_to_label = _index_label_map(class_map)
                prediction = idx_to_label[pred_idx.item()]
                color, advice = get_color_and_advice(prediction)
                # HTML lines stay unindented: Markdown renders lines indented by
                # 4+ spaces as a code block instead of HTML.
                st.markdown(f"""
<div class="result-card" style="border-color: {color};">
<h3 style="color: {color}; margin-bottom:0;">{prediction}</h3>
<p class="confidence-text">Confidence: {conf.item()*100:.1f}%</p>
<hr style="margin: 10px 0;">
<p style="font-style: italic;">"{advice}"</p>
</div>
""", unsafe_allow_html=True)
                tab1, tab2 = st.tabs(["Contribution", "Probabilities"])
                with tab1:
                    with st.container(height=300):
                        fw = fusion_weights[0].cpu().numpy()
                        fig, ax = plt.subplots(figsize=(13.5, 2.4))
                        labels = ['Tabular', 'Image']
                        # NOTE(review): assumes the model returns weights
                        # ordered (image, tabular); confirm in MultiModalNet.
                        values = [fw[1], fw[0]]
                        ax.barh(labels, values)
                        ax.set_xlim(0, 1)
                        ax.set_xlabel("Fusion Weight")
                        for i, v in enumerate(values):
                            ax.text(v + 0.02, i, f"{v:.2f}", va='center', fontweight='bold')
                        ax.spines['top'].set_visible(False)
                        ax.spines['right'].set_visible(False)
                        _show_figure(fig)
                with tab2:
                    with st.container(height=300):
                        p_vals = probs[0].cpu().numpy()
                        p_labels = [idx_to_label[i] for i in range(len(p_vals))]
                        fig2, ax2 = plt.subplots(figsize=(11.5, 1.4))
                        ax2.bar(p_labels, p_vals)
                        ax2.set_ylim(0, 1)
                        ax2.set_ylabel("Probability")
                        ax2.tick_params(axis='x', labelrotation=35)
                        ax2.spines['top'].set_visible(False)
                        ax2.spines['right'].set_visible(False)
                        _show_figure(fig2)
        except Exception as e:  # UI boundary: surface the error in the app
            st.error(f"Analysis failed: {e}")