Spaces:
Sleeping
Sleeping
Add support for fruitbot-expanded
Browse files
app.py
CHANGED
|
@@ -14,9 +14,15 @@ logger = logging.getLogger(__name__)
|
|
| 14 |
|
| 15 |
tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot0', 'tf_model.keras')
|
| 16 |
tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot1', 'tf_model.keras')
|
|
|
|
|
|
|
| 17 |
|
| 18 |
fruitbot0 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot0", "tf_model.keras")
|
| 19 |
fruitbot1 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot1", "tf_model.keras")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14', 'Apple 17', 'Apple 18', 'Apple 19',
|
| 22 |
'Apple 5', 'Apple 7', 'Apple 8', 'Apple 9', 'Apple Core 1', 'Apple Red Yellow 2', 'Apple worm 1',
|
|
@@ -29,6 +35,9 @@ FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14',
|
|
| 29 |
'apple_red_yellow_1', 'apple_rotten_1', 'cabbage_white_1', 'carrot_1', 'cucumber_1', 'cucumber_3',
|
| 30 |
'eggplant_long_1', 'pear_1', 'pear_3', 'zucchini_1', 'zucchini_dark_1']
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
# Create FastAPI app
|
| 33 |
app = FastAPI()
|
| 34 |
|
|
@@ -49,7 +58,7 @@ def preprocess_image(image_file, model):
|
|
| 49 |
# Convert image to numpy array
|
| 50 |
image = np.array(image)
|
| 51 |
|
| 52 |
-
if model == "fruitbot0":
|
| 53 |
image = cv2.resize(image, (240, 240))
|
| 54 |
image = image.reshape(-1, 240, 240, 3)
|
| 55 |
elif model == "fruitbot1":
|
|
@@ -78,7 +87,7 @@ async def predict_fruitbot0(file: UploadFile = File(...)):
|
|
| 78 |
return JSONResponse(content={"prediction": predicted_class})
|
| 79 |
|
| 80 |
except Exception as e:
|
| 81 |
-
logger.error(f"Error in /predict: {str(e)}")
|
| 82 |
return JSONResponse(content={"error": str(e)}, status_code=400)
|
| 83 |
|
| 84 |
@app.post("/predict/fruitbot1")
|
|
@@ -94,7 +103,23 @@ async def predict_fruitbot0(file: UploadFile = File(...)):
|
|
| 94 |
return JSONResponse(content={"prediction": predicted_class})
|
| 95 |
|
| 96 |
except Exception as e:
|
| 97 |
-
logger.error(f"Error in /predict: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return JSONResponse(content={"error": str(e)}, status_code=400)
|
| 99 |
|
| 100 |
@app.post("/predict/recyclebot0")
|
|
|
|
| 14 |
|
| 15 |
tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot0', 'tf_model.keras')
|
| 16 |
tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot1', 'tf_model.keras')
|
| 17 |
+
tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot-expanded', 'tf_model.h5')
|
| 18 |
+
|
| 19 |
|
| 20 |
fruitbot0 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot0", "tf_model.keras")
|
| 21 |
fruitbot1 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot1", "tf_model.keras")
|
| 22 |
+
fruitbot_expanded = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot-expanded", "tf_model.h5")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
|
| 27 |
FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14', 'Apple 17', 'Apple 18', 'Apple 19',
|
| 28 |
'Apple 5', 'Apple 7', 'Apple 8', 'Apple 9', 'Apple Core 1', 'Apple Red Yellow 2', 'Apple worm 1',
|
|
|
|
| 35 |
'apple_red_yellow_1', 'apple_rotten_1', 'cabbage_white_1', 'carrot_1', 'cucumber_1', 'cucumber_3',
|
| 36 |
'eggplant_long_1', 'pear_1', 'pear_3', 'zucchini_1', 'zucchini_dark_1']
|
| 37 |
|
| 38 |
+
FRUITBOT_EXPANDED_CLASSES = ['apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper', 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango', 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans', 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
|
| 39 |
+
|
| 40 |
+
|
| 41 |
# Create FastAPI app
|
| 42 |
app = FastAPI()
|
| 43 |
|
|
|
|
| 58 |
# Convert image to numpy array
|
| 59 |
image = np.array(image)
|
| 60 |
|
| 61 |
+
if (model == "fruitbot0" or model == "fruitbot-expanded):
|
| 62 |
image = cv2.resize(image, (240, 240))
|
| 63 |
image = image.reshape(-1, 240, 240, 3)
|
| 64 |
elif model == "fruitbot1":
|
|
|
|
| 87 |
return JSONResponse(content={"prediction": predicted_class})
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
+
logger.error(f"Error in /predict/fruitbot0: {str(e)}")
|
| 91 |
return JSONResponse(content={"error": str(e)}, status_code=400)
|
| 92 |
|
| 93 |
@app.post("/predict/fruitbot1")
|
|
|
|
| 103 |
return JSONResponse(content={"prediction": predicted_class})
|
| 104 |
|
| 105 |
except Exception as e:
|
| 106 |
+
logger.error(f"Error in /predict/fruitbot1: {str(e)}")
|
| 107 |
+
return JSONResponse(content={"error": str(e)}, status_code=400)
|
| 108 |
+
|
| 109 |
+
@app.post("/predict/fruitbot-expanded")
|
| 110 |
+
async def predict_fruitbot0(file: UploadFile = File(...)):
|
| 111 |
+
try:
|
| 112 |
+
logger.info("Received request for /predict/fruitbot-expanded")
|
| 113 |
+
img_array = preprocess_image(file.file, "fruitbot-expanded") # Preprocess the image
|
| 114 |
+
prediction1 = fruitbot_expanded.predict(img_array) # Get predictions
|
| 115 |
+
|
| 116 |
+
predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
|
| 117 |
+
predicted_class = FRUITBOT_EXPANDED_CLASSESCLASSES[predicted_class_idx] # Convert to class name
|
| 118 |
+
|
| 119 |
+
return JSONResponse(content={"prediction": predicted_class})
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error(f"Error in /predict/fruitbot-expanded: {str(e)}")
|
| 123 |
return JSONResponse(content={"error": str(e)}, status_code=400)
|
| 124 |
|
| 125 |
@app.post("/predict/recyclebot0")
|