Spaces:
Sleeping
Sleeping
File size: 1,286 Bytes
51370cc ad21017 51370cc ad21017 1b227e0 ad21017 1b227e0 51370cc ad21017 1b227e0 ad21017 c38232c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | import gradio as gr
import pickle

import jax
import jax.numpy as jnp
import joblib
import numpy as np
from flax import serialization

from model import AQIPredictor
# Load the fitted feature scaler and the feature (column) names saved at
# training time (presumably in the column order the model was trained on).
# NOTE(review): joblib/pickle can execute arbitrary code while loading; these
# files must only ever come from this repository, never from user uploads.
scaler = joblib.load("aqi_scaler.pkl")
with open("aqi_features.pkl", "rb") as f:  # closes the handle, unlike open() inline
    features = pickle.load(f)

# Build the model, then create a parameter pytree of the right structure with a
# dummy forward pass: flax's serialization.from_bytes needs such a target to
# restore the saved weights into.
model = AQIPredictor(features=len(features))
dummy_input = jnp.ones((1, 24, len(features)))  # Batch x Time x Features
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
with open("aqi_predictor.flax", "rb") as f:
    params = serialization.from_bytes(params, f.read())
def predict_fn(input_series):
    """Predict the AQI from the past 24 hours of sensor readings.

    Args:
        input_series: 24 rows x ``len(features)`` columns of raw (unscaled)
            readings, one row per hour, columns in the order of ``features``.
            The Gradio Dataframe delivers a NumPy array; any array-like that
            converts to floats (e.g. a list of lists) is accepted.

    Returns:
        The model output element ``[0, 0]`` as a Python ``float``.

    Raises:
        gr.Error: if the input is missing, has the wrong shape, or contains
            non-numeric / non-finite cells. Gradio shows the message to the
            user. Returning the message as a string instead would fail,
            because the output component expects a number.
    """
    if input_series is None:
        raise gr.Error("Input must have 24 hourly records")

    # Dataframe cells may arrive as strings or None; coerce everything to float
    # up front so bad cells produce a clear message instead of a scaler error.
    try:
        series = np.asarray(input_series, dtype=np.float64)
    except (TypeError, ValueError) as err:
        raise gr.Error("All cells must contain numeric values") from err

    if series.ndim != 2 or series.shape[0] != 24:
        raise gr.Error("Input must have 24 hourly records")
    if series.shape[1] != len(features):
        raise gr.Error(f"Each record must have {len(features)} feature values")
    # Empty cells become NaN, which would silently yield a NaN prediction.
    if not np.all(np.isfinite(series)):
        raise gr.Error("All cells must contain numeric values")

    scaled = scaler.transform(series)
    x = jnp.asarray(scaled).reshape(1, 24, len(features))  # Batch x Time x Features
    prediction = model.apply(params, x, deterministic=True)
    return float(prediction[0, 0])
# Input: a 24 x N table, one row per hour and one column per feature.
# Fixing the row and column counts stops users from adding or deleting rows or
# columns, and datatype="number" makes the cells numeric rather than Gradio's
# default free-text strings.
input_component = gr.Dataframe(
    headers=list(features),
    row_count=(24, "fixed"),
    col_count=(len(features), "fixed"),
    datatype="number",
    label="Enter past 24 hours of sensor data",
    type="numpy",
)

iface = gr.Interface(
    fn=predict_fn,
    inputs=input_component,
    outputs="number",
    title="24-Hour AQI Forecast",
)

# Start the server only when this file is executed directly, so importing the
# module (e.g. to call predict_fn from a test) has no server side effect.
if __name__ == "__main__":
    iface.launch(share=True)
|