File size: 1,286 Bytes
51370cc
ad21017
 
 
 
 
 
 
51370cc
ad21017
 
 
 
1b227e0
ad21017
 
 
 
 
1b227e0
 
 
 
 
 
 
 
 
 
51370cc
ad21017
1b227e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad21017
c38232c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
import joblib
import pickle
import jax
import jax.numpy as jnp
from flax import serialization
from model import AQIPredictor

# Load the preprocessing artifacts saved at training time: the fitted scaler
# and the ordered collection of input feature names.
# NOTE(review): joblib/pickle can execute arbitrary code while loading -- only
# ever point these paths at trusted, locally produced artifacts.
scaler = joblib.load("aqi_scaler.pkl")
with open("aqi_features.pkl", "rb") as f:  # context manager closes the handle
    features = pickle.load(f)

# Build the model and initialise a parameter pytree of the right structure;
# serialization.from_bytes below needs it as the template to restore into.
model = AQIPredictor(features=len(features))
dummy_input = jnp.ones((1, 24, len(features)))  # Batch x Time x Features
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)

# Replace the randomly initialised parameters with the trained weights.
with open("aqi_predictor.flax", "rb") as f:
    params = serialization.from_bytes(params, f.read())

def predict_fn(input_series):
    """Predict AQI from the last 24 hours of sensor readings.

    Args:
        input_series: 24 rows x N columns of readings, one column per entry in
            ``features`` and in the same order. The Gradio ``Dataframe`` input
            of this app delivers it as a NumPy array (``type="numpy"``); a
            list of lists is accepted too.

    Returns:
        The model's prediction as a Python ``float``.

    Raises:
        gr.Error: if the table does not have exactly 24 rows of
            ``len(features)`` values each, or if the values cannot be scaled
            (e.g. blank or non-numeric cells). Gradio displays the message
            to the user.
    """
    window = 24  # must match the time dimension the model was initialised with
    n_features = len(features)

    # Raise instead of returning a message: the output component is numeric,
    # so a returned string could not be displayed.
    if len(input_series) != window:
        raise gr.Error("Input must have 24 hourly records")
    if any(len(row) != n_features for row in input_series):
        raise gr.Error(f"Each hourly record must have {n_features} values")

    try:
        scaled = scaler.transform(input_series)
    except ValueError as err:
        # e.g. empty or non-numeric cells entered in the table
        raise gr.Error(f"Invalid sensor data: {err}") from err

    x = jnp.array(scaled).reshape(1, window, n_features)  # Batch x Time x Features
    prediction = model.apply(params, x, deterministic=True)
    # Assumes the model output is (batch, 1) -- TODO confirm against AQIPredictor.
    # NOTE(review): if the target was scaled during training, this value may
    # still be in scaled units -- confirm.
    return float(prediction[0, 0])

# Input: a 24-row table (one row per hour) with one column per model feature.
input_component = gr.Dataframe(
    headers=list(features),  # plain list, in case the pickled object is e.g. an Index
    row_count=24,
    col_count=len(features),
    label="Enter past 24 hours of sensor data",
    type="numpy",
)

iface = gr.Interface(
    fn=predict_fn,
    inputs=input_component,
    outputs="number",
    title="24-Hour AQI Forecast",
)

# Only start the server when executed as a script, so importing this module
# (e.g. from tests or another app) does not launch a public share link.
if __name__ == "__main__":
    iface.launch(share=True)