# project2 / app.py
# Author: nikethanreddy — commit c38232c ("Update app.py", verified)
import gradio as gr
import joblib
import pickle
import jax
import jax.numpy as jnp
from flax import serialization
from model import AQIPredictor
# Load the saved scaler and the ordered list of feature names.
# NOTE(review): joblib/pickle can execute arbitrary code when loading; only
# ship these artifact files from a trusted source.
scaler = joblib.load("aqi_scaler.pkl")
with open("aqi_features.pkl", "rb") as f:  # close the handle deterministically
    features = pickle.load(f)

# Build the model and a template variable tree by initialising it on a dummy
# batch of shape (1, 24, len(features)). flax.serialization.from_bytes needs
# this template structure to restore the saved weights into.
model = AQIPredictor(features=len(features))
dummy_input = jnp.ones((1, 24, len(features)))  # Batch x Time x Features
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
with open("aqi_predictor.flax", "rb") as f:
    params = serialization.from_bytes(params, f.read())
def predict_fn(input_series):
    """Predict AQI from the previous 24 hours of sensor readings.

    Args:
        input_series: 24 rows (one per hour), each holding ``len(features)``
            values in the same column order as ``features``. A 2-D numpy
            array or a list of lists is accepted.

    Returns:
        float: element ``[0, 0]`` of the model output for the single-sample
        batch.

    Raises:
        gr.Error: if the input is missing, does not have exactly 24 rows, or
            a row does not have ``len(features)`` values. Raising (instead of
            returning a message string) is needed because the interface
            output is ``"number"``, which cannot display text; ``gr.Error``
            is shown to the user in the UI.
    """
    n_features = len(features)
    if input_series is None or len(input_series) != 24:
        raise gr.Error("Input must have 24 hourly records")
    if any(len(row) != n_features for row in input_series):
        raise gr.Error(f"Each hourly record must have {n_features} values")

    # Scale with the loaded scaler (presumably the one fitted at training
    # time) before feeding the model.
    scaled = scaler.transform(input_series)
    x = jnp.array(scaled).reshape(1, 24, n_features)  # Batch x Time x Features
    prediction = model.apply(params, x, deterministic=True)
    return float(prediction[0, 0])
# Input: a 24 x len(features) table, one row per past hour.
# The row and column counts are fixed so users cannot add or delete rows or
# columns, which would give the model a table of the wrong shape.
input_component = gr.Dataframe(
    headers=list(features),
    row_count=(24, "fixed"),
    col_count=(len(features), "fixed"),
    label="Enter past 24 hours of sensor data",
    type="numpy",
)

iface = gr.Interface(
    fn=predict_fn,
    inputs=input_component,
    outputs="number",
    title="24-Hour AQI Forecast",
)

# Start the server only when run as a script, so importing this module
# does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch(share=True)