import streamlit as st import torch import torch.nn.functional as F import numpy as np import cv2 import matplotlib.pyplot as plt from PIL import Image from huggingface_hub import hf_hub_download import joblib from datetime import datetime from model_arch import MultiModalNet st.set_page_config(layout="wide", page_title="Multi-Modal AQI Analysis", initial_sidebar_state="expanded") st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_resources(): repo_id = "rocky250/aqi-multimodal" try: model_path = hf_hub_download(repo_id=repo_id, filename="model.pth") scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.joblib") checkpoint = torch.load(model_path, map_location='cpu') scaler = joblib.load(scaler_path) model = MultiModalNet(num_classes=6, num_tab_features=9) model.load_state_dict(checkpoint['model_state_dict']) model.eval() return model, scaler, checkpoint.get('class_map', { 0: 'Good', 1: 'Moderate', 2: 'Unhealthy for Sensitive Groups', 3: 'Unhealthy', 4: 'Very Unhealthy', 5: 'Hazardous' }) except Exception as e: st.error(f"Error loading resources: {e}") return None, None, None def get_cyclic_features(value, max_val): sin_val = np.sin(2 * np.pi * value / max_val) cos_val = np.cos(2 * np.pi * value / max_val) return sin_val, cos_val def get_color_and_advice(prediction): mapping = { 'Good': ('#2ecc71', "Air quality is satisfactory."), 'Moderate': ('#f1c40f', "Air quality is acceptable."), 'Unhealthy for Sensitive Groups': ('#e67e22', "Sensitive groups should reduce outdoor exertion."), 'Unhealthy': ('#e74c3c', "Avoid prolonged outdoor exertion."), 'Very Unhealthy': ('#8e44ad', "Health alert: avoid all outdoor activities."), 'Hazardous': ('#2c3e50', "Emergency conditions: remain indoors.") } return mapping.get(prediction, ('#95a5a6', "No advice available")) st.sidebar.header("Sensor Inputs") st.sidebar.markdown("Update these values to analyze real-time scenarios.") aqi_input = st.sidebar.number_input("AQI Value", min_value=0.0, value=120.0) pm25_input = st.sidebar.number_input("PM2.5", min_value=0.0, value=35.0) pm10_input = st.sidebar.number_input("PM10", min_value=0.0, value=75.0) st.sidebar.markdown("---") now = datetime.now() month = st.sidebar.slider("Month", 1, 12, now.month) day = st.sidebar.slider("Day", 1, 31, now.day) hour = st.sidebar.slider("Hour", 0, 23, now.hour) st.title("Multi-Modal AQI Classification") st.markdown("Integrates visual data (Sky/Horizon) with sensor metrics for precise air quality detection.") col_left, col_right = st.columns([1, 1.3], gap="large") with col_left: st.subheader("1. Visual Input") uploaded_file = st.file_uploader("Upload Image", type=['jpg', 'jpeg', 'png']) if uploaded_file: img_display = Image.open(uploaded_file) display_copy = img_display.copy() display_copy.thumbnail((400, 400)) st.image(display_copy, caption="Input Scene", use_container_width=True) else: st.info("Please upload an image to start.") with col_right: st.subheader("2. Analysis Results") if uploaded_file is not None: try: model, scaler, class_map = load_resources() if model: img = np.array(img_display.convert('RGB')) img = cv2.resize(img, (224, 224)) img_tensor = torch.tensor(img).permute(2, 0, 1).float().unsqueeze(0) / 255.0 mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) img_tensor = (img_tensor - mean) / std h_sin, h_cos = get_cyclic_features(hour, 24.0) m_sin, m_cos = get_cyclic_features(month, 12.0) d_sin, d_cos = get_cyclic_features(day, 31.0) raw_features = np.array([[aqi_input, pm25_input, pm10_input, h_sin, h_cos, m_sin, m_cos, d_sin, d_cos]]) scaled_features = scaler.transform(raw_features) tab_tensor = torch.tensor(scaled_features).float() with torch.no_grad(): logits, fusion_weights = model(img_tensor, tab_tensor) probs = F.softmax(logits, dim=1) conf, pred_idx = torch.max(probs, dim=1) if isinstance(list(class_map.keys())[0], int): prediction = class_map[pred_idx.item()] else: inv_map = {v: k for k, v in class_map.items()} prediction = inv_map[pred_idx.item()] color, advice = get_color_and_advice(prediction) st.markdown(f"""
Confidence: {conf.item()*100:.1f}%
"{advice}"