"""Gradio app serving a Flax ``AQIPredictor`` as a "24-Hour AQI Forecast".

At import time this module loads three artifacts from the working directory:
the fitted scaler (``aqi_scaler.pkl``), the feature-name list
(``aqi_features.pkl``) and the trained Flax weights (``aqi_predictor.flax``).
It then builds a Gradio interface around :func:`predict_fn`. The UI is
launched only when the file is run as a script.
"""

import pickle

import gradio as gr
import jax
import jax.numpy as jnp
import joblib
import numpy as np
from flax import serialization

from model import AQIPredictor

# Number of hourly records the model consumes per prediction (time axis).
WINDOW_HOURS = 24

# Load scaler and features.
# NOTE: joblib/pickle can execute arbitrary code on load; only ship trusted,
# locally produced artifact files with this app.
scaler = joblib.load("aqi_scaler.pkl")
with open("aqi_features.pkl", "rb") as fh:  # `with` guarantees the handle is closed
    features = pickle.load(fh)
N_FEATURES = len(features)

model = AQIPredictor(features=N_FEATURES)
# A dummy batch fixes the parameter shapes. The randomly initialised values
# are then overwritten by the trained weights restored from disk below.
dummy_input = jnp.ones((1, WINDOW_HOURS, N_FEATURES))  # Batch x Time x Features
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
with open("aqi_predictor.flax", "rb") as f:
    params = serialization.from_bytes(params, f.read())


def predict_fn(input_series):
    """Run the model on the last ``WINDOW_HOURS`` hourly records.

    Args:
        input_series: Table with ``WINDOW_HOURS`` rows (one per hour) and one
            column per entry of ``features``. The ``gr.Dataframe`` input
            (``type="numpy"``) delivers it as an array; any array-like of that
            shape also works.

    Returns:
        float: Element ``[0, 0]`` of the model output. This is presumably the
        forecast AQI; confirm against ``model.AQIPredictor``.

    Raises:
        gr.Error: If the table is missing, has the wrong number of rows or
            columns, or contains blank or non-numeric cells. Gradio shows the
            message to the user. A plain string return could not be rendered
            by the ``"number"`` output component.
    """
    if input_series is None or len(input_series) != WINDOW_HOURS:
        raise gr.Error(f"Input must have {WINDOW_HOURS} hourly records")
    try:
        # Dataframe cells may arrive as strings/objects. Coerce explicitly so
        # malformed cells produce a clear message instead of a scaler crash.
        data = np.asarray(input_series, dtype=float)
    except (TypeError, ValueError) as err:
        raise gr.Error("Every cell must be a number") from err
    if data.ndim != 2 or data.shape[1] != N_FEATURES:
        raise gr.Error(f"Each record must have {N_FEATURES} feature values")
    if not np.isfinite(data).all():
        # Blank cells can surface as NaN, which would silently yield a NaN forecast.
        raise gr.Error("Every cell must be a number")

    scaled = scaler.transform(data)
    x = jnp.asarray(scaled).reshape(1, WINDOW_HOURS, N_FEATURES)
    prediction = model.apply(params, x, deterministic=True)
    return float(prediction[0, 0])


# Input: 24 x N table, one column per feature name.
input_component = gr.Dataframe(
    headers=[str(name) for name in features],
    row_count=WINDOW_HOURS,
    col_count=N_FEATURES,
    label="Enter past 24 hours of sensor data",
    type="numpy",
)

iface = gr.Interface(
    fn=predict_fn,
    inputs=input_component,
    outputs="number",
    title="24-Hour AQI Forecast",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to localhost.
    iface.launch(share=True)