nikethanreddy commited on
Commit
865db26
·
verified ·
1 Parent(s): 9ebeae7

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. app.py +41 -0
  3. aqi_features.pkl +3 -0
  4. aqi_predictor.flax +3 -0
  5. aqi_scaler.pkl +3 -0
  6. model.py +21 -0
  7. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ aqi_predictor.flax filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from fastapi import FastAPI, Request
3
+ import joblib
4
+ import pickle
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from flax import serialization
8
+ from model import AQIPredictor
9
+
10
+ app = FastAPI()
11
+
12
+ # Load the scaler and feature list
13
+ scaler = joblib.load("scaler.pkl")
14
+ features = pickle.load(open("features.pkl", "rb"))
15
+
16
+ # Initialize model
17
+ model = AQIPredictor(features=len(features))
18
+ dummy_input = jnp.ones((1, len(features), 1)) # Shape for Conv1D
19
+ params = model.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
20
+
21
+ # Load trained model parameters
22
+ with open("predictor.flax", "rb") as f:
23
+ params = serialization.from_bytes(params, f.read())
24
+
25
+ @app.get("/")
26
+ def root():
27
+ return {"message": "AQI Predictor API is live."}
28
+
29
+ @app.post("/predict")
30
+ async def predict(request: Request):
31
+ try:
32
+ data = await request.json()
33
+ input_data = [data.get(f, 0.0) for f in features]
34
+ scaled = scaler.transform([input_data]) # (1, N)
35
+ reshaped = jnp.array(scaled).reshape(1, len(features), 1) # For Conv1D input
36
+
37
+ prediction = model.apply(params, reshaped, deterministic=True)
38
+ return {"predicted_aqi": float(prediction[0, 0])}
39
+
40
+ except Exception as e:
41
+ return {"error": str(e)}
aqi_features.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e74e100407793cc5a1a85de4bea0f51a4d3888eaa61ad65d63cb84b4fcc50323
3
+ size 78
aqi_predictor.flax ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:477157c1257397ae75189d693085047a484a0388a3b703f938701743b8147673
3
+ size 123048
aqi_scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:228e6b2bb20ae057103bac75c9498db2c7eee7c58f36f326854a6812e82d66ea
3
+ size 655
model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ import flax.linen as nn
3
+ import jax.numpy as jnp
4
+
5
+ class AQIPredictor(nn.Module):
6
+ features: int
7
+
8
+ @nn.compact
9
+ def __call__(self, x, deterministic: bool):
10
+ x = nn.Conv(features=64, kernel_size=(3,))(x)
11
+ x = nn.relu(x)
12
+ x = nn.LayerNorm()(x)
13
+ x = nn.Conv(features=64, kernel_size=(3,))(x)
14
+ x = nn.relu(x)
15
+ x = nn.LayerNorm()(x)
16
+ x = jnp.mean(x, axis=1)
17
+ x = nn.Dense(128)(x)
18
+ x = nn.Dropout(0.1)(nn.silu(x), deterministic=deterministic)
19
+ x = nn.Dense(64)(x)
20
+ x = nn.silu(x)
21
+ return nn.Dense(1)(x)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ jax
4
+ flax
5
+ scikit-learn
6
+ joblib