yasyn14 committed on
Commit
82bac8f
·
1 Parent(s): 7a11d7a

changed to single image sending

Browse files
Files changed (1) hide show
  1. main.py +32 -52
main.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
  import logging
3
- from typing import List, Optional
4
  from contextlib import asynccontextmanager
5
 
6
  import numpy as np
7
  import tensorflow as tf
8
  from fastapi import FastAPI, File, UploadFile, HTTPException, status
9
- from fastapi.responses import JSONResponse
10
  from PIL import Image
11
  import io
12
  from huggingface_hub import hf_hub_download
@@ -21,7 +20,6 @@ HF_MODEL_REPO: str = os.getenv("HF_MODEL_REPO", "yasyn14/smart-leaf-model")
21
  HF_MODEL_FILENAME: str = os.getenv("HF_MODEL_FILENAME", "best_model_32epochs.keras")
22
  HF_CACHE_DIR: str = os.getenv("HF_HOME", "/home/appuser/huggingface")
23
  IMAGE_SIZE: tuple = (300, 300)
24
- MAX_BATCH_SIZE: int = 10
25
 
26
  # Plant disease class names
27
  CLASS_NAMES = [
@@ -47,7 +45,6 @@ CLEAN_CLASS_NAMES = [name.replace('___', ' - ').replace('_', ' ') for name in CL
47
  HTTP_MESSAGES = {
48
  "MODEL_NOT_LOADED": "Model not loaded. Please check server logs.",
49
  "INVALID_FILE_TYPE": "File must be an image",
50
- "BATCH_SIZE_EXCEEDED": "Maximum {max_size} images allowed per batch",
51
  "PREDICTION_FAILED": "Prediction failed: {error}",
52
  "IMAGE_PROCESSING_FAILED": "Error preprocessing image: {error}",
53
  "MODEL_LOAD_SUCCESS": "Model loaded successfully",
@@ -58,15 +55,12 @@ HTTP_MESSAGES = {
58
  model: Optional[tf.keras.Model] = None
59
 
60
  # Response models
61
- class PredictionResult(BaseModel):
 
62
  predicted_class: str
63
  clean_class_name: str
64
  confidence: float
65
  all_predictions: dict
66
-
67
- class PredictionResponse(BaseModel):
68
- success: bool
69
- results: List[PredictionResult]
70
  message: str
71
 
72
  class HealthResponse(BaseModel):
@@ -124,8 +118,8 @@ def preprocess_image(image_bytes: bytes) -> np.ndarray:
124
  logger.error(f"Error preprocessing image: {str(e)}")
125
  raise
126
 
127
- def predict_single_image(image_bytes: bytes) -> PredictionResult:
128
- """Make prediction for a single image"""
129
  global model
130
 
131
  if model is None:
@@ -153,11 +147,13 @@ def predict_single_image(image_bytes: bytes) -> PredictionResult:
153
  for i in range(len(CLASS_NAMES))
154
  }
155
 
156
- return PredictionResult(
 
157
  predicted_class=predicted_class,
158
  clean_class_name=clean_class_name,
159
  confidence=confidence,
160
- all_predictions=all_predictions
 
161
  )
162
 
163
  except Exception as e:
@@ -198,7 +194,7 @@ async def lifespan(app: FastAPI):
198
  # Create FastAPI app
199
  app = FastAPI(
200
  title="Plant Disease Prediction API",
201
- description="API for predicting plant diseases from leaf images using deep learning",
202
  version="1.0.0",
203
  lifespan=lifespan
204
  )
@@ -222,54 +218,38 @@ async def health_check():
222
  )
223
 
224
  @app.post("/predict", response_model=PredictionResponse)
225
- async def predict_plant_disease(files: List[UploadFile] = File(...)):
226
  """
227
- Predict plant disease from uploaded image(s)
228
 
229
- - **files**: List of image files to analyze (max 10 files)
230
 
231
- Returns predictions with confidence scores for each image
232
  """
233
 
234
- # Check batch size
235
- if len(files) > MAX_BATCH_SIZE:
236
  raise HTTPException(
237
  status_code=status.HTTP_400_BAD_REQUEST,
238
- detail=HTTP_MESSAGES["BATCH_SIZE_EXCEEDED"].format(max_size=MAX_BATCH_SIZE)
239
  )
240
 
241
- results = []
242
-
243
- for file in files:
244
- # Check if file is an image
245
- if not is_image_file(file.filename):
246
- raise HTTPException(
247
- status_code=status.HTTP_400_BAD_REQUEST,
248
- detail=f"{HTTP_MESSAGES['INVALID_FILE_TYPE']}: {file.filename}"
249
- )
250
 
251
- try:
252
- # Read file content
253
- image_bytes = await file.read()
254
-
255
- # Make prediction
256
- result = predict_single_image(image_bytes)
257
- results.append(result)
258
-
259
- except HTTPException:
260
- raise
261
- except Exception as e:
262
- logger.error(f"Error processing file {file.filename}: {str(e)}")
263
- raise HTTPException(
264
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
265
- detail=HTTP_MESSAGES["IMAGE_PROCESSING_FAILED"].format(error=str(e))
266
- )
267
-
268
- return PredictionResponse(
269
- success=True,
270
- results=results,
271
- message=f"Successfully processed {len(results)} image(s)"
272
- )
273
 
274
  @app.get("/classes")
275
  async def get_classes():
 
1
  import os
2
  import logging
3
+ from typing import Optional
4
  from contextlib import asynccontextmanager
5
 
6
  import numpy as np
7
  import tensorflow as tf
8
  from fastapi import FastAPI, File, UploadFile, HTTPException, status
 
9
  from PIL import Image
10
  import io
11
  from huggingface_hub import hf_hub_download
 
20
  HF_MODEL_FILENAME: str = os.getenv("HF_MODEL_FILENAME", "best_model_32epochs.keras")
21
  HF_CACHE_DIR: str = os.getenv("HF_HOME", "/home/appuser/huggingface")
22
  IMAGE_SIZE: tuple = (300, 300)
 
23
 
24
  # Plant disease class names
25
  CLASS_NAMES = [
 
45
  HTTP_MESSAGES = {
46
  "MODEL_NOT_LOADED": "Model not loaded. Please check server logs.",
47
  "INVALID_FILE_TYPE": "File must be an image",
 
48
  "PREDICTION_FAILED": "Prediction failed: {error}",
49
  "IMAGE_PROCESSING_FAILED": "Error preprocessing image: {error}",
50
  "MODEL_LOAD_SUCCESS": "Model loaded successfully",
 
55
  model: Optional[tf.keras.Model] = None
56
 
57
  # Response models
58
+ class PredictionResponse(BaseModel):
59
+ success: bool
60
  predicted_class: str
61
  clean_class_name: str
62
  confidence: float
63
  all_predictions: dict
 
 
 
 
64
  message: str
65
 
66
  class HealthResponse(BaseModel):
 
118
  logger.error(f"Error preprocessing image: {str(e)}")
119
  raise
120
 
121
+ def predict_image(image_bytes: bytes) -> PredictionResponse:
122
+ """Make prediction for the uploaded image"""
123
  global model
124
 
125
  if model is None:
 
147
  for i in range(len(CLASS_NAMES))
148
  }
149
 
150
+ return PredictionResponse(
151
+ success=True,
152
  predicted_class=predicted_class,
153
  clean_class_name=clean_class_name,
154
  confidence=confidence,
155
+ all_predictions=all_predictions,
156
+ message="Image processed successfully"
157
  )
158
 
159
  except Exception as e:
 
194
  # Create FastAPI app
195
  app = FastAPI(
196
  title="Plant Disease Prediction API",
197
+ description="API for predicting plant diseases from a single leaf image using deep learning",
198
  version="1.0.0",
199
  lifespan=lifespan
200
  )
 
218
  )
219
 
220
  @app.post("/predict", response_model=PredictionResponse)
221
+ async def predict_plant_disease(file: UploadFile = File(...)):
222
  """
223
+ Predict plant disease from uploaded image
224
 
225
+ - **file**: Single image file to analyze
226
 
227
+ Returns prediction with confidence score for the image
228
  """
229
 
230
+ # Check if file is an image
231
+ if not is_image_file(file.filename):
232
  raise HTTPException(
233
  status_code=status.HTTP_400_BAD_REQUEST,
234
+ detail=f"{HTTP_MESSAGES['INVALID_FILE_TYPE']}: {file.filename}"
235
  )
236
 
237
+ try:
238
+ # Read file content
239
+ image_bytes = await file.read()
 
 
 
 
 
 
240
 
241
+ # Make prediction
242
+ result = predict_image(image_bytes)
243
+ return result
244
+
245
+ except HTTPException:
246
+ raise
247
+ except Exception as e:
248
+ logger.error(f"Error processing file {file.filename}: {str(e)}")
249
+ raise HTTPException(
250
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
251
+ detail=HTTP_MESSAGES["IMAGE_PROCESSING_FAILED"].format(error=str(e))
252
+ )
 
 
 
 
 
 
 
 
 
 
253
 
254
  @app.get("/classes")
255
  async def get_classes():