Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- .gitattributes +1 -0
- app.py +41 -0
- aqi_features.pkl +3 -0
- aqi_predictor.flax +3 -0
- aqi_scaler.pkl +3 -0
- model.py +21 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
aqi_predictor.flax filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
from fastapi import FastAPI, Request
|
| 3 |
+
import joblib
|
| 4 |
+
import pickle
|
| 5 |
+
import jax
|
| 6 |
+
import jax.numpy as jnp
|
| 7 |
+
from flax import serialization
|
| 8 |
+
from model import AQIPredictor
|
| 9 |
+
|
| 10 |
+
app = FastAPI()
|
| 11 |
+
|
| 12 |
+
# Load the scaler and feature list
|
| 13 |
+
scaler = joblib.load("scaler.pkl")
|
| 14 |
+
features = pickle.load(open("features.pkl", "rb"))
|
| 15 |
+
|
| 16 |
+
# Initialize model
|
| 17 |
+
model = AQIPredictor(features=len(features))
|
| 18 |
+
dummy_input = jnp.ones((1, len(features), 1)) # Shape for Conv1D
|
| 19 |
+
params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
|
| 20 |
+
|
| 21 |
+
# Load trained model parameters
|
| 22 |
+
with open("predictor.flax", "rb") as f:
|
| 23 |
+
params = serialization.from_bytes(params, f.read())
|
| 24 |
+
|
| 25 |
+
@app.get("/")
|
| 26 |
+
def root():
|
| 27 |
+
return {"message": "AQI Predictor API is live."}
|
| 28 |
+
|
| 29 |
+
@app.post("/predict")
|
| 30 |
+
async def predict(request: Request):
|
| 31 |
+
try:
|
| 32 |
+
data = await request.json()
|
| 33 |
+
input_data = [data.get(f, 0.0) for f in features]
|
| 34 |
+
scaled = scaler.transform([input_data]) # (1, N)
|
| 35 |
+
reshaped = jnp.array(scaled).reshape(1, len(features), 1) # For Conv1D input
|
| 36 |
+
|
| 37 |
+
prediction = model.apply(params, reshaped, deterministic=True)
|
| 38 |
+
return {"predicted_aqi": float(prediction[0, 0])}
|
| 39 |
+
|
| 40 |
+
except Exception as e:
|
| 41 |
+
return {"error": str(e)}
|
aqi_features.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e74e100407793cc5a1a85de4bea0f51a4d3888eaa61ad65d63cb84b4fcc50323
|
| 3 |
+
size 78
|
aqi_predictor.flax
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:477157c1257397ae75189d693085047a484a0388a3b703f938701743b8147673
|
| 3 |
+
size 123048
|
aqi_scaler.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:228e6b2bb20ae057103bac75c9498db2c7eee7c58f36f326854a6812e82d66ea
|
| 3 |
+
size 655
|
model.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model.py
|
| 2 |
+
import flax.linen as nn
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
|
| 5 |
+
class AQIPredictor(nn.Module):
|
| 6 |
+
features: int
|
| 7 |
+
|
| 8 |
+
@nn.compact
|
| 9 |
+
def __call__(self, x, deterministic: bool):
|
| 10 |
+
x = nn.Conv(features=64, kernel_size=(3,))(x)
|
| 11 |
+
x = nn.relu(x)
|
| 12 |
+
x = nn.LayerNorm()(x)
|
| 13 |
+
x = nn.Conv(features=64, kernel_size=(3,))(x)
|
| 14 |
+
x = nn.relu(x)
|
| 15 |
+
x = nn.LayerNorm()(x)
|
| 16 |
+
x = jnp.mean(x, axis=1)
|
| 17 |
+
x = nn.Dense(128)(x)
|
| 18 |
+
x = nn.Dropout(0.1)(nn.silu(x), deterministic=deterministic)
|
| 19 |
+
x = nn.Dense(64)(x)
|
| 20 |
+
x = nn.silu(x)
|
| 21 |
+
return nn.Dense(1)(x)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
jax
|
| 4 |
+
flax
|
| 5 |
+
scikit-learn
|
| 6 |
+
joblib
|