Spaces:
Sleeping
Sleeping
import pickle

import gradio as gr
import jax
import jax.numpy as jnp
import joblib
import numpy as np
from flax import serialization

from model import AQIPredictor
# --- Load preprocessing artifacts and model weights -------------------------
# `scaler` normalizes raw readings before they reach the model; `features` is
# the sequence of feature names, used both for the model's input width and as
# the table headers in the UI.
# NOTE(review): pickle/joblib execute arbitrary code when loading, so only
# ship trusted artifact files alongside this app.
scaler = joblib.load("aqi_scaler.pkl")
with open("aqi_features.pkl", "rb") as features_file:
    # A context manager guarantees the handle is closed (the original leaked it).
    features = pickle.load(features_file)

model = AQIPredictor(features=len(features))

# Flax deserializes weights *into* an existing parameter pytree, so first
# build a template of the right structure with a dummy forward pass.
dummy_input = jnp.ones((1, 24, len(features)))  # Batch x Time x Features
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
with open("aqi_predictor.flax", "rb") as weights_file:
    params = serialization.from_bytes(params, weights_file.read())
def predict_fn(input_series):
    """Predict AQI from the most recent 24 hourly sensor records.

    Args:
        input_series: 24 rows x N columns, one row per hour and one column
            per entry in ``features``. The Gradio Dataframe input below is
            configured with ``type="numpy"``. Cells must be numeric, or
            strings that parse as numbers.

    Returns:
        The first output of the model's first batch element, as a Python float.

    Raises:
        gr.Error: if the table is missing, does not have 24 rows, does not
            have one column per feature, or contains non-numeric cells.
            Raising ``gr.Error`` surfaces the message in the UI. Returning a
            string (as before) cannot be rendered by the "number" output
            component and produced a generic error instead.
    """
    n_features = len(features)
    if input_series is None or len(input_series) != 24:
        raise gr.Error("Input must have 24 hourly records")

    try:
        # Table cells may arrive as strings, so convert explicitly.
        values = np.asarray(input_series, dtype=np.float64)
    except (TypeError, ValueError) as err:
        raise gr.Error("All cells must be numeric") from err

    # Catch shape problems here rather than as an opaque reshape/scaler error.
    if values.ndim != 2 or values.shape[1] != n_features:
        raise gr.Error(f"Each record must have {n_features} feature values")

    scaled = scaler.transform(values)
    x = jnp.asarray(scaled).reshape(1, 24, n_features)  # Batch x Time x Features
    prediction = model.apply(params, x, deterministic=True)
    # NOTE(review): assumes the model output is at least 2-D (batch, outputs).
    return float(prediction[0, 0])
# Input: a 24 x N table with one row per hour and one column per feature.
# Both dimensions are fixed so users cannot add or remove rows or columns
# (a plain int only sets the *initial* size). Numeric cells avoid
# string-typed input.
input_component = gr.Dataframe(
    headers=list(features),
    row_count=(24, "fixed"),
    col_count=(len(features), "fixed"),
    datatype="number",
    label="Enter past 24 hours of sensor data",
    type="numpy",
)

iface = gr.Interface(
    fn=predict_fn,
    inputs=input_component,
    outputs="number",
    title="24-Hour AQI Forecast",
)

if __name__ == "__main__":
    # Launch only when run as a script, not when this module is imported.
    # NOTE(review): share=True is reportedly ignored (with a warning) on
    # Hugging Face Spaces; it only creates a public link when run locally.
    iface.launch(share=True)