Spaces:
Sleeping
Sleeping
import functools

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image

from model import AirPollutionMultimodalModel  # Assuming you uploaded model.py
# 1. Load Model & Weights
# Inference device: use a CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Human-readable labels; position i names the model's i-th output logit.
CLASS_NAMES = [
    'Good',
    'Moderate',
    'Unhealthy (Sensitive)',
    'Unhealthy',
    'Very Unhealthy',
    'Severe',
]
def load_model():
    """Build the multimodal classifier and load its pretrained weights.

    The checkpoint is downloaded from the author's Hugging Face Hub repo with
    ``torch.hub.load_state_dict_from_url`` (cached locally after the first
    download). The weights are mapped onto ``DEVICE`` and the model is put in
    eval mode.

    Returns:
        An ``AirPollutionMultimodalModel`` ready for inference.
    """
    net = AirPollutionMultimodalModel(
        # Derive the head size from CLASS_NAMES so labels and logits can't drift apart.
        num_classes=len(CLASS_NAMES),
        # 8 raw readings + 3 sin/cos pairs (hour, month, day) built in predict().
        num_tabular_features=14,
    )
    # Load from your Hub repo.
    # NOTE(review): .pth checkpoints are unpickled on load; only point this at a
    # trusted repo (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.hub.load_state_dict_from_url(
        "https://huggingface.co/siamarefin2000/Air_pollution/resolve/main/best_base.pth",
        map_location=DEVICE,
    )
    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()  # disable dropout / use running batch-norm stats for inference
    return net


model = load_model()
# 2. Prediction Function
@functools.lru_cache(maxsize=1)
def _image_transform():
    """Build (once) and return the image preprocessing pipeline used by ``predict``.

    The pipeline resizes to 224x224, applies ``A.Normalize()`` with its default
    statistics, and converts the array to a torch tensor via ``ToTensorV2``.
    """
    return A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])


def predict(img, year, aqi, pm25, pm10, o3, co, so2, no2, hour, month, day):
    """Classify air quality from an image plus tabular sensor readings.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submits without uploading one.
        year, aqi, pm25, pm10, o3, co, so2, no2: raw numeric readings, passed
            to the model unscaled and in this order.
        hour, month, day: calendar fields, each encoded as a (sin, cos) pair so
            that cyclic neighbours (e.g. hour 23 and hour 0) stay close.

    Returns:
        Dict mapping each name in ``CLASS_NAMES`` to its softmax probability,
        which is the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if img is None:
        # np.array(None) would give a 0-d object array and the transform would
        # fail with an opaque traceback; surface a readable message instead.
        raise gr.Error("Please upload an image.")

    # Image preprocessing. convert("RGB") guarantees 3 channels (drops alpha,
    # expands grayscale) so the array matches what Normalize and the model expect.
    img_array = np.array(img.convert("RGB"))
    img_tensor = _image_transform()(image=img_array)["image"].unsqueeze(0).to(DEVICE)

    # Tabular preprocessing (matches your predict.py logic): 8 raw readings
    # followed by sin/cos of hour, month and day -> 14 features, which must
    # agree with num_tabular_features=14 in load_model().
    # NOTE(review): readings are not scaled here — confirm this matches training.
    tab_data = [year, aqi, pm25, pm10, o3, co, so2, no2]
    for value, period in ((hour, 24), (month, 12), (day, 31)):
        angle = 2 * np.pi * value / period
        tab_data.extend([np.sin(angle), np.cos(angle)])
    tab_tensor = torch.tensor([tab_data], dtype=torch.float32, device=DEVICE)

    with torch.no_grad():
        output = model(img_tensor, tab_tensor)
    probs = torch.nn.functional.softmax(output[0], dim=0)
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
# 3. Build UI
# The order of `inputs` must match predict()'s positional parameters.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Sky/Satellite Image"),
        # precision=0 makes Gradio hand predict() an integer year.
        gr.Number(label="Year", value=2026, precision=0),
        gr.Slider(0, 500, label="Current AQI"),
        gr.Slider(0, 300, label="PM2.5"),
        gr.Slider(0, 300, label="PM10"),
        gr.Slider(0, 100, label="O3"),
        gr.Slider(0, 10, label="CO"),
        gr.Slider(0, 50, label="SO2"),
        gr.Slider(0, 100, label="NO2"),
        # step=1: otherwise Gradio derives a fractional default step from these
        # small ranges, allowing values like hour 13.7 or month 4.3.
        gr.Slider(0, 23, step=1, label="Hour of Day"),
        gr.Slider(1, 12, step=1, label="Month"),
        gr.Slider(1, 31, step=1, label="Day"),
    ],
    # Show the three most probable classes.
    outputs=gr.Label(num_top_classes=3),
    title="Multimodal Air Pollution Predictor",
    description="Upload an image and sensor data to see the AQI classification.",
)
demo.launch()