# air_pollution2 / app.py
# Uploaded by siamarefin2000 (commit a087c01: "create app.py", verified).
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import AirPollutionMultimodalModel # Assuming you uploaded model.py
# 1. Load Model & Weights
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Display labels for the classifier outputs; predict() pairs these
# positionally with the softmax probabilities, so order matters.
CLASS_NAMES = [
    'Good',
    'Moderate',
    'Unhealthy (Sensitive)',
    'Unhealthy',
    'Very Unhealthy',
    'Severe',
]
def load_model():
    """Build the multimodal classifier, load its trained weights, and return it ready for inference.

    The checkpoint is fetched from the Hugging Face Hub via
    ``torch.hub.load_state_dict_from_url`` (which caches the download locally)
    and its tensors are mapped onto ``DEVICE``.

    Returns:
        An ``AirPollutionMultimodalModel`` moved to ``DEVICE`` and switched to eval mode.
    """
    # Derive the output size from CLASS_NAMES so the label list and the model
    # head cannot silently drift apart.
    # 14 tabular inputs = 8 raw readings + 3 (sin, cos) pairs built in predict().
    model = AirPollutionMultimodalModel(
        num_classes=len(CLASS_NAMES), num_tabular_features=14
    )
    # NOTE(review): this unpickles a remotely hosted file. Consider passing
    # weights_only=True (available in recent torch versions) once the
    # checkpoint is confirmed to be a plain state_dict.
    state_dict = torch.hub.load_state_dict_from_url(
        "https://huggingface.co/siamarefin2000/Air_pollution/resolve/main/best_base.pth",
        map_location=DEVICE,
    )
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    # Inference only: disable dropout / use running batch-norm statistics.
    model.eval()
    return model


# Loaded once at import time and shared by every prediction request.
model = load_model()
# 2. Prediction Function
# The preprocessing pipeline is stateless, so build it once at import time
# instead of on every request.
_IMAGE_TRANSFORM = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])


def _preprocess_image(img):
    """Turn an uploaded image into a batched, normalized tensor on DEVICE.

    For PIL input the result has shape (1, 3, 224, 224).

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    if isinstance(img, Image.Image):
        # A.Normalize's default mean/std cover exactly 3 channels; RGBA PNGs
        # (4 channels) or grayscale uploads (2-D arrays) would otherwise break
        # the transform.
        img = img.convert("RGB")
    img_array = np.array(img)
    return _IMAGE_TRANSFORM(image=img_array)["image"].unsqueeze(0).to(DEVICE)


def _cyclical(value, period):
    """Return [sin, cos] of ``value`` mapped onto a circle of length ``period``.

    This keeps wrap-around neighbours (e.g. hour 23 and hour 0) close together.
    """
    angle = 2 * np.pi * value / period
    return [np.sin(angle), np.cos(angle)]


def _build_tabular_features(year, aqi, pm25, pm10, o3, co, so2, no2, hour, month, day):
    """Assemble the 14-value feature row as a float32 tensor of shape (1, 14) on DEVICE.

    Layout: the 8 raw readings, then sin/cos of hour (period 24),
    month (period 12) and day (period 31).

    Raises:
        gr.Error: if any field is empty (e.g. a cleared number box yields None).
    """
    raw = [year, aqi, pm25, pm10, o3, co, so2, no2]
    if any(v is None for v in (*raw, hour, month, day)):
        raise gr.Error("Please fill in every sensor and date field.")
    # NOTE(review): order and periods must match the model's training
    # features -- TODO confirm against the training/predict script.
    features = raw + _cyclical(hour, 24) + _cyclical(month, 12) + _cyclical(day, 31)
    return torch.tensor(
        [[float(v) for v in features]], dtype=torch.float32, device=DEVICE
    )


def predict(img, year, aqi, pm25, pm10, o3, co, so2, no2, hour, month, day):
    """Classify air quality from an image plus sensor/date readings.

    Arguments arrive from the Gradio inputs in the order they are declared
    in the UI.

    Returns:
        dict mapping each name in CLASS_NAMES to its softmax probability,
        the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if the image or any numeric field is missing.
    """
    img_tensor = _preprocess_image(img)
    tab_tensor = _build_tabular_features(
        year, aqi, pm25, pm10, o3, co, so2, no2, hour, month, day
    )
    with torch.no_grad():
        output = model(img_tensor, tab_tensor)
        # Presumably (batch, num_classes) logits; take the single row.
        probs = torch.nn.functional.softmax(output[0], dim=0)
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs.tolist())}
# 3. Build UI
# Input order must match predict()'s positional parameters.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Sky/Satellite Image"),
        # precision=0 makes Gradio hand predict() a whole-number year.
        gr.Number(label="Year", value=2026, precision=0),
        gr.Slider(0, 500, label="Current AQI"),
        gr.Slider(0, 300, label="PM2.5"),
        gr.Slider(0, 300, label="PM10"),
        gr.Slider(0, 100, label="O3"),
        gr.Slider(0, 10, label="CO"),
        gr.Slider(0, 50, label="SO2"),
        gr.Slider(0, 100, label="NO2"),
        # Without an explicit step, Gradio derives one from the range, which
        # is fractional for these small ranges; hour/month/day are integers.
        gr.Slider(0, 23, step=1, label="Hour of Day"),
        gr.Slider(1, 12, step=1, label="Month"),
        gr.Slider(1, 31, step=1, label="Day"),
    ],
    outputs=gr.Label(num_top_classes=3),
    title="Multimodal Air Pollution Predictor",
    description="Upload an image and sensor data to see the AQI classification.",
)
# Started unconditionally at module level (the script is run directly).
demo.launch()