sharktide commited on
Commit
b9068db
·
verified ·
1 Parent(s): 0fcf4c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Request
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import cv2
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import logging
9
+ import tensorflowtools as tft
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot0', 'tf_model.keras')
16
+ tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot1', 'tf_model.keras')
17
+
18
+ fruitbot0 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot0", "tf_model.keras")
19
+ fruitbot1 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot1", "tf_model.keras")
20
+
21
+ FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14', 'Apple 17', 'Apple 18', 'Apple 19',
22
+ 'Apple 5', 'Apple 7', 'Apple 8', 'Apple 9', 'Apple Core 1', 'Apple Red Yellow 2', 'Apple worm 1',
23
+ 'Banana 3', 'Beans 1', 'Blackberrie 1', 'Blackberrie 2', 'Blackberrie half rippen 1',
24
+ 'Blackberrie not rippen 1', 'Cabbage red 1', 'Cactus fruit green 1', 'Cactus fruit red 1', 'Caju seed 1',
25
+ 'Cherimoya 1', 'Cherry Wax not rippen 1', 'Cucumber 10', 'Cucumber 9', 'Gooseberry 1', 'Pistachio 1',
26
+ 'Quince 2', 'Quince 3', 'Quince 4', 'Tomato 1', 'Tomato 5', 'apple_6', 'apple_braeburn_1',
27
+ 'apple_crimson_snow_1', 'apple_golden_1', 'apple_golden_2', 'apple_golden_3', 'apple_granny_smith_1',
28
+ 'apple_hit_1', 'apple_pink_lady_1', 'apple_red_1', 'apple_red_2', 'apple_red_3', 'apple_red_delicios_1',
29
+ 'apple_red_yellow_1', 'apple_rotten_1', 'cabbage_white_1', 'carrot_1', 'cucumber_1', 'cucumber_3',
30
+ 'eggplant_long_1', 'pear_1', 'pear_3', 'zucchini_1', 'zucchini_dark_1']
31
+
32
+ # Create FastAPI app
33
+ app = FastAPI()
34
+
35
+ app.add_middleware(
36
+ CORSMiddleware,
37
+ allow_origins=["*"],
38
+ allow_credentials=True,
39
+ allow_methods=["*"],
40
+ allow_headers=["*"],
41
+ )
42
+
43
+ # Preprocess the image (resize, reshape without normalization)
44
+ def preprocess_image(image_file, model):
45
+ try:
46
+ # Load image using PIL
47
+ image = Image.open(image_file)
48
+
49
+ # Convert image to numpy array
50
+ image = np.array(image)
51
+
52
+ if model == "fruitbot0":
53
+ image = cv2.resize(image, (240, 240))
54
+ image = image.reshape(-1, 240, 240, 3)
55
+ elif model == "fruitbot1":
56
+ image = cv2.resize(image, (224, 224))
57
+ image = image.reshape(-1, 224, 224, 3)
58
+
59
+ return image
60
+ except Exception as e:
61
+ logger.error(f"Error in preprocess_image: {str(e)}")
62
+ raise
63
+
64
+ @app.get("/predict")
65
+ def predict:
66
+ return JSONResponse(content={"Models Avalible For Inference at this Endpoint": ["fruitbot0", "fruitbot1"], "Models Avalible For Inference at Another Endpoint": ["recyclebot0"], "All Models": ["fruitbot0", "fruitbot1", "recyclebot0"]})
67
+
68
+ @app.post("/predict/fruitbot0")
69
+ async def predict_fruitbot0(file: UploadFile = File(...)):
70
+ try:
71
+ logger.info("Received request for /predict/fruitbot0")
72
+ img_array = preprocess_image(file.file, "fruitbot0") # Preprocess the image
73
+ prediction1 = fruitbot0.predict(img_array) # Get predictions
74
+
75
+ predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
76
+ predicted_class = FRUITBOT_CLASSES[predicted_class_idx] # Convert to class name
77
+
78
+ return JSONResponse(content={"prediction": predicted_class})
79
+
80
+ except Exception as e:
81
+ logger.error(f"Error in /predict: {str(e)}")
82
+ return JSONResponse(content={"error": str(e)}, status_code=400)
83
+
84
+ @app.post("/predict/fruitbot1")
85
+ async def predict_fruitbot0(file: UploadFile = File(...)):
86
+ try:
87
+ logger.info("Received request for /predict/fruitbot1")
88
+ img_array = preprocess_image(file.file, "fruitbot1") # Preprocess the image
89
+ prediction1 = fruitbot1.predict(img_array) # Get predictions
90
+
91
+ predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
92
+ predicted_class = FRUITBOT_CLASSES[predicted_class_idx] # Convert to class name
93
+
94
+ return JSONResponse(content={"prediction": predicted_class})
95
+
96
+ except Exception as e:
97
+ logger.error(f"Error in /predict: {str(e)}")
98
+ return JSONResponse(content={"error": str(e)}, status_code=400)
99
+
100
+ @app.post("/predict/recyclebot0")
101
+ async def predict_fruitbot0(file: UploadFile = File(...)):
102
+ return JSONResponse(content={"error": "This model is hosted at another endpoint"}, status_code=400)
103
+
104
+ @app.get("/working")
105
+ async def working():
106
+ return JSONResponse(content={"Status": "Working"})
107
+
108
+ if __name__ == "__main__":
109
+ import uvicorn
110
+ uvicorn.run(app, host="0.0.0.0", port=7860)