1MR commited on
Commit
0ea3bba
·
verified ·
1 Parent(s): f03fa46

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ import shutil
6
+
7
+ # Initialize FastAPI app
8
+ app = FastAPI()
9
+
10
+ # Class labels
11
+ class_labels = {
12
+ 0: 'Baked Potato', 1: 'Burger', 2: 'Crispy Chicken', 3: 'Donut', 4: 'Fries',
13
+ 5: 'Hot Dog', 6: 'Jalapeno', 7: 'Kiwi', 8: 'Lemon', 9: 'Lettuce',
14
+ 10: 'Mango', 11: 'Onion', 12: 'Orange', 13: 'Pizza', 14: 'Taquito',
15
+ 15: 'Apple', 16: 'Banana', 17: 'Beetroot', 18: 'Bell Pepper', 19: 'Bread',
16
+ 20: 'Cabbage', 21: 'Carrot', 22: 'Cauliflower', 23: 'Cheese',
17
+ 24: 'Chilli Pepper', 25: 'Corn', 26: 'Crab', 27: 'Cucumber',
18
+ 28: 'Eggplant', 29: 'Eggs', 30: 'Garlic', 31: 'Ginger', 32: 'Grapes',
19
+ 33: 'Milk', 34: 'Salmon', 35: 'Yogurt'
20
+ }
21
+
22
+ # Load the trained model
23
+ model = tf.keras.models.load_model("model_unfreezeNewCorrectpredict.keras")
24
+
25
+ # Image preprocessing function
26
+ def load_and_prep_image(file_path, img_shape=224):
27
+ img = tf.io.read_file(file_path)
28
+ img = tf.image.decode_image(img, channels=3)
29
+ img = tf.image.resize(img, size=[img_shape, img_shape])
30
+ img = tf.expand_dims(img, axis=0)
31
+ return img
32
+
33
+ # Predict label function
34
+ def predict_label(model, image_path, class_names):
35
+ img = load_and_prep_image(image_path, img_shape=224)
36
+ pred = model.predict(img)
37
+ pred_class_index = np.argmax(pred, axis=1)[0]
38
+ pred_class_name = class_names[pred_class_index]
39
+ return pred_class_name
40
+
41
+ # API endpoint for prediction
42
+ @app.post("/predict")
43
+ async def predict_image(file: UploadFile = File(...)):
44
+ try:
45
+ # Save the uploaded file
46
+ file_location = f"./temp_{file.filename}"
47
+ with open(file_location, "wb") as f:
48
+ shutil.copyfileobj(file.file, f)
49
+
50
+ # Predict the label
51
+ prediction = predict_label(model, file_location, class_labels)
52
+
53
+ # Remove the temporary file
54
+ os.remove(file_location)
55
+
56
+ return {"predicted_label": prediction}
57
+ except Exception as e:
58
+ return JSONResponse(
59
+ status_code=500,
60
+ content={"error": f"An error occurred: {str(e)}"}
61
+ )