rishab1090 commited on
Commit
5dffcf4
·
verified ·
1 Parent(s): 6d303e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import numpy as np
@@ -7,22 +7,26 @@ import io
7
 
8
  app = FastAPI()
9
 
10
- # Load the model
11
  model = tf.keras.models.load_model("2.keras")
12
  CLASS_NAMES = ['Fungi', 'Healthy', 'Nematode', 'Pest', 'Phytopthora', 'Virus']
13
 
 
 
 
14
  @app.post("/predict")
15
- async def predict(file: UploadFile = File(...)):
 
 
 
16
  try:
17
  contents = await file.read()
18
 
19
- # Load and process image
20
  image = Image.open(io.BytesIO(contents)).convert("RGB")
21
  image = image.resize((224, 224))
22
-
23
- # Don't normalize — match tf.data image_dataset_from_directory
24
  img_array = np.array(image).astype("float32")
25
- img_array = np.expand_dims(img_array, axis=0) # (1, 224, 224, 3)
26
 
27
  # Predict
28
  prediction = model.predict(img_array)
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Header
2
  from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import numpy as np
 
7
 
8
  app = FastAPI()
9
 
10
+ # Load your model
11
  model = tf.keras.models.load_model("2.keras")
12
  CLASS_NAMES = ['Fungi', 'Healthy', 'Nematode', 'Pest', 'Phytopthora', 'Virus']
13
 
14
+ # Define your API key (keep it secret in prod)
15
+ API_KEY = "mysecretkey"
16
+
17
  @app.post("/predict")
18
+ async def predict(file: UploadFile = File(...), x_api_key: str = Header(None)):
19
+ if x_api_key != API_KEY:
20
+ raise HTTPException(status_code=401, detail="Invalid or missing API Key")
21
+
22
  try:
23
  contents = await file.read()
24
 
25
+ # Process the image
26
  image = Image.open(io.BytesIO(contents)).convert("RGB")
27
  image = image.resize((224, 224))
 
 
28
  img_array = np.array(image).astype("float32")
29
+ img_array = np.expand_dims(img_array, axis=0)
30
 
31
  # Predict
32
  prediction = model.predict(img_array)