namanraj commited on
Commit
9429213
·
verified ·
1 Parent(s): 5725093

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +30 -21
app/main.py CHANGED
@@ -1,21 +1,30 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- import numpy as np
4
- from tensorflow import keras
5
-
6
- app = FastAPI()
7
-
8
- # Load model
9
- model = keras.models.load_model("best_model.h5")
10
-
11
- # Input schema
12
- class InputData(BaseModel):
13
- pixels: list # flattened 28x28 = 784 values
14
-
15
- @app.post("/predict")
16
- def predict(data: InputData):
17
- # Convert list → NumPy
18
- X = np.array(data.pixels).reshape(1, 28, 28, 1) / 255.0 # normalize if trained that way
19
- y_pred = model.predict(X)
20
- predicted_class = int(np.argmax(y_pred, axis=1)[0])
21
- return {"prediction": predicted_class}
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import numpy as np
4
+ from tensorflow import keras
5
+ import os
6
+
7
+ app = FastAPI()
8
+
9
+ # Debug: print current directory and files
10
+ print("Current working directory:", os.getcwd())
11
+ print("Files in app folder:", os.listdir("."))
12
+
13
+ # Load model (must match filename)
14
+ model_path = "best_model.h5"
15
+ model = keras.models.load_model(model_path)
16
+
17
+ # Input schema
18
+ class InputData(BaseModel):
19
+ pixels: list # 784 flattened pixels (28x28)
20
+
21
+ @app.get("/")
22
+ def root():
23
+ return {"message": "MNIST API running"}
24
+
25
+ @app.post("/predict")
26
+ def predict(data: InputData):
27
+ X = np.array(data.pixels).reshape(1, 28, 28, 1) / 255.0
28
+ y_pred = model.predict(X)
29
+ predicted_class = int(np.argmax(y_pred, axis=1)[0])
30
+ return {"prediction": predicted_class}