sharktide commited on
Commit
ea16666
·
verified ·
1 Parent(s): 33907db

Add support for fruitbot-expanded

Browse files
Files changed (1) hide show
  1. app.py +28 -3
app.py CHANGED
@@ -14,9 +14,15 @@ logger = logging.getLogger(__name__)
14
 
15
  tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot0', 'tf_model.keras')
16
  tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot1', 'tf_model.keras')
 
 
17
 
18
  fruitbot0 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot0", "tf_model.keras")
19
  fruitbot1 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot1", "tf_model.keras")
 
 
 
 
20
 
21
  FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14', 'Apple 17', 'Apple 18', 'Apple 19',
22
  'Apple 5', 'Apple 7', 'Apple 8', 'Apple 9', 'Apple Core 1', 'Apple Red Yellow 2', 'Apple worm 1',
@@ -29,6 +35,9 @@ FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14',
29
  'apple_red_yellow_1', 'apple_rotten_1', 'cabbage_white_1', 'carrot_1', 'cucumber_1', 'cucumber_3',
30
  'eggplant_long_1', 'pear_1', 'pear_3', 'zucchini_1', 'zucchini_dark_1']
31
 
 
 
 
32
  # Create FastAPI app
33
  app = FastAPI()
34
 
@@ -49,7 +58,7 @@ def preprocess_image(image_file, model):
49
  # Convert image to numpy array
50
  image = np.array(image)
51
 
52
- if model == "fruitbot0":
53
  image = cv2.resize(image, (240, 240))
54
  image = image.reshape(-1, 240, 240, 3)
55
  elif model == "fruitbot1":
@@ -78,7 +87,7 @@ async def predict_fruitbot0(file: UploadFile = File(...)):
78
  return JSONResponse(content={"prediction": predicted_class})
79
 
80
  except Exception as e:
81
- logger.error(f"Error in /predict: {str(e)}")
82
  return JSONResponse(content={"error": str(e)}, status_code=400)
83
 
84
  @app.post("/predict/fruitbot1")
@@ -94,7 +103,23 @@ async def predict_fruitbot0(file: UploadFile = File(...)):
94
  return JSONResponse(content={"prediction": predicted_class})
95
 
96
  except Exception as e:
97
- logger.error(f"Error in /predict: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  return JSONResponse(content={"error": str(e)}, status_code=400)
99
 
100
  @app.post("/predict/recyclebot0")
 
14
 
15
  tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot0', 'tf_model.keras')
16
  tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot1', 'tf_model.keras')
17
+ tft.hftools.download_model_from_huggingface('sharktide', 'fruitbot-expanded', 'tf_model.h5')
18
+
19
 
20
  fruitbot0 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot0", "tf_model.keras")
21
  fruitbot1 = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot1", "tf_model.keras")
22
+ fruitbot_expanded = tft.kerastools.load_from_hf_cache("sharktide", "fruitbot-expanded", "tf_model.h5")
23
+
24
+
25
+
26
 
27
  FRUITBOT_CLASSES = ['Apple 10', 'Apple 11', 'Apple 12', 'Apple 13', 'Apple 14', 'Apple 17', 'Apple 18', 'Apple 19',
28
  'Apple 5', 'Apple 7', 'Apple 8', 'Apple 9', 'Apple Core 1', 'Apple Red Yellow 2', 'Apple worm 1',
 
35
  'apple_red_yellow_1', 'apple_rotten_1', 'cabbage_white_1', 'carrot_1', 'cucumber_1', 'cucumber_3',
36
  'eggplant_long_1', 'pear_1', 'pear_3', 'zucchini_1', 'zucchini_dark_1']
37
 
38
+ FRUITBOT_EXPANDED_CLASSES = ['apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper', 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango', 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans', 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
39
+
40
+
41
  # Create FastAPI app
42
  app = FastAPI()
43
 
 
58
  # Convert image to numpy array
59
  image = np.array(image)
60
 
61
+ if (model == "fruitbot0" or model == "fruitbot-expanded):
62
  image = cv2.resize(image, (240, 240))
63
  image = image.reshape(-1, 240, 240, 3)
64
  elif model == "fruitbot1":
 
87
  return JSONResponse(content={"prediction": predicted_class})
88
 
89
  except Exception as e:
90
+ logger.error(f"Error in /predict/fruitbot0: {str(e)}")
91
  return JSONResponse(content={"error": str(e)}, status_code=400)
92
 
93
  @app.post("/predict/fruitbot1")
 
103
  return JSONResponse(content={"prediction": predicted_class})
104
 
105
  except Exception as e:
106
+ logger.error(f"Error in /predict/fruitbot1: {str(e)}")
107
+ return JSONResponse(content={"error": str(e)}, status_code=400)
108
+
109
+ @app.post("/predict/fruitbot-expanded")
110
+ async def predict_fruitbot0(file: UploadFile = File(...)):
111
+ try:
112
+ logger.info("Received request for /predict/fruitbot-expanded")
113
+ img_array = preprocess_image(file.file, "fruitbot-expanded") # Preprocess the image
114
+ prediction1 = fruitbot_expanded.predict(img_array) # Get predictions
115
+
116
+ predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
117
+ predicted_class = FRUITBOT_EXPANDED_CLASSESCLASSES[predicted_class_idx] # Convert to class name
118
+
119
+ return JSONResponse(content={"prediction": predicted_class})
120
+
121
+ except Exception as e:
122
+ logger.error(f"Error in /predict/fruitbot-expanded: {str(e)}")
123
  return JSONResponse(content={"error": str(e)}, status_code=400)
124
 
125
  @app.post("/predict/recyclebot0")