ahadhassan commited on
Commit
05c7c93
·
verified ·
1 Parent(s): cd20fc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -4
app.py CHANGED
@@ -11,12 +11,28 @@ import rasterio
11
  from rasterio.transform import from_bounds
12
  import tempfile
13
  import os
 
 
 
 
 
14
 
15
  app = FastAPI()
16
 
17
  # Load models at startup
18
- ndvi_model = load_model("ndvi_best_model.keras")
19
- yolo_model = load_yolo_model("4c_6c_regression.pt")
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.get("/")
22
  async def root():
@@ -25,6 +41,9 @@ async def root():
25
  @app.post("/predict_ndvi/")
26
  async def predict_ndvi_api(file: UploadFile = File(...)):
27
  """Predict NDVI from RGB image"""
 
 
 
28
  try:
29
  contents = await file.read()
30
  img = Image.open(BytesIO(contents)).convert("RGB")
@@ -54,14 +73,18 @@ async def predict_ndvi_api(file: UploadFile = File(...)):
54
  headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"}
55
  )
56
  except Exception as e:
 
57
  return JSONResponse(status_code=500, content={"error": str(e)})
58
 
59
  @app.post("/predict_yolo/")
60
  async def predict_yolo_api(file: UploadFile = File(...)):
61
  """Predict YOLO results from 4-channel TIFF image"""
 
 
 
62
  try:
63
  # Save uploaded file temporarily with proper extension
64
- file_extension = '.tiff' if file.filename.lower().endswith(('.tif', '.tiff')) else '.tiff'
65
 
66
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
67
  contents = await file.read()
@@ -74,6 +97,26 @@ async def predict_yolo_api(file: UploadFile = File(...)):
74
  if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
75
  raise ValueError("Failed to create temporary file")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Predict using YOLO model
78
  results = predict_yolo(yolo_model, tmp_file_path)
79
 
@@ -98,6 +141,7 @@ async def predict_yolo_api(file: UploadFile = File(...)):
98
  growth_stages = results.boxes.data[:, 6:].tolist()
99
  results_dict["growth_stages"] = growth_stages
100
 
 
101
  return JSONResponse(content=results_dict)
102
 
103
  finally:
@@ -106,14 +150,25 @@ async def predict_yolo_api(file: UploadFile = File(...)):
106
  os.unlink(tmp_file_path)
107
 
108
  except Exception as e:
 
109
  return JSONResponse(status_code=500, content={"error": str(e)})
110
 
111
  @app.post("/predict_pipeline/")
112
  async def predict_pipeline_api(file: UploadFile = File(...)):
113
  """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
 
 
 
114
  try:
115
  # Save uploaded file temporarily with proper extension
116
- file_extension = '.tiff' if file.filename.lower().endswith(('.tif', '.tiff')) else '.jpg'
 
 
 
 
 
 
 
117
 
118
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
119
  contents = await file.read()
@@ -126,6 +181,8 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
126
  if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
127
  raise ValueError("Failed to create temporary file")
128
 
 
 
129
  # Run the full pipeline
130
  results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
131
 
@@ -150,6 +207,7 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
150
  growth_stages = results.boxes.data[:, 6:].tolist()
151
  results_dict["growth_stages"] = growth_stages
152
 
 
153
  return JSONResponse(content=results_dict)
154
 
155
  finally:
@@ -158,4 +216,5 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
158
  os.unlink(tmp_file_path)
159
 
160
  except Exception as e:
 
161
  return JSONResponse(status_code=500, content={"error": str(e)})
 
11
  from rasterio.transform import from_bounds
12
  import tempfile
13
  import os
14
+ import logging
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
 
20
  app = FastAPI()
21
 
22
  # Load models at startup
23
+ try:
24
+ ndvi_model = load_model("ndvi_best_model.keras")
25
+ logger.info("NDVI model loaded successfully")
26
+ except Exception as e:
27
+ logger.error(f"Failed to load NDVI model: {e}")
28
+ ndvi_model = None
29
+
30
+ try:
31
+ yolo_model = load_yolo_model("4c_6c_regression.pt")
32
+ logger.info("YOLO model loaded successfully")
33
+ except Exception as e:
34
+ logger.error(f"Failed to load YOLO model: {e}")
35
+ yolo_model = None
36
 
37
  @app.get("/")
38
  async def root():
 
41
  @app.post("/predict_ndvi/")
42
  async def predict_ndvi_api(file: UploadFile = File(...)):
43
  """Predict NDVI from RGB image"""
44
+ if ndvi_model is None:
45
+ return JSONResponse(status_code=500, content={"error": "NDVI model not loaded"})
46
+
47
  try:
48
  contents = await file.read()
49
  img = Image.open(BytesIO(contents)).convert("RGB")
 
73
  headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"}
74
  )
75
  except Exception as e:
76
+ logger.error(f"Error in predict_ndvi_api: {e}")
77
  return JSONResponse(status_code=500, content={"error": str(e)})
78
 
79
  @app.post("/predict_yolo/")
80
  async def predict_yolo_api(file: UploadFile = File(...)):
81
  """Predict YOLO results from 4-channel TIFF image"""
82
+ if yolo_model is None:
83
+ return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
84
+
85
  try:
86
  # Save uploaded file temporarily with proper extension
87
+ file_extension = '.tiff' if file.filename and file.filename.lower().endswith(('.tif', '.tiff')) else '.tiff'
88
 
89
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
90
  contents = await file.read()
 
97
  if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
98
  raise ValueError("Failed to create temporary file")
99
 
100
+ logger.info(f"Processing YOLO prediction for file: {file.filename}, temp path: {tmp_file_path}")
101
+
102
+ # Additional validation: check if file has 4 channels
103
+ try:
104
+ import tifffile
105
+ test_array = tifffile.imread(tmp_file_path)
106
+ if len(test_array.shape) == 3:
107
+ if test_array.shape[0] == 4 or test_array.shape[2] == 4:
108
+ channels = 4
109
+ else:
110
+ channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
111
+ else:
112
+ channels = 1
113
+
114
+ if channels != 4:
115
+ raise ValueError(f"YOLO model expects 4-channel images, but uploaded file has {channels} channels")
116
+
117
+ except Exception as validation_error:
118
+ logger.warning(f"Could not validate channels: {validation_error}")
119
+
120
  # Predict using YOLO model
121
  results = predict_yolo(yolo_model, tmp_file_path)
122
 
 
141
  growth_stages = results.boxes.data[:, 6:].tolist()
142
  results_dict["growth_stages"] = growth_stages
143
 
144
+ logger.info(f"YOLO prediction completed successfully")
145
  return JSONResponse(content=results_dict)
146
 
147
  finally:
 
150
  os.unlink(tmp_file_path)
151
 
152
  except Exception as e:
153
+ logger.error(f"Error in predict_yolo_api: {e}")
154
  return JSONResponse(status_code=500, content={"error": str(e)})
155
 
156
  @app.post("/predict_pipeline/")
157
  async def predict_pipeline_api(file: UploadFile = File(...)):
158
  """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
159
+ if ndvi_model is None or yolo_model is None:
160
+ return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
161
+
162
  try:
163
  # Save uploaded file temporarily with proper extension
164
+ file_extension = '.jpg'
165
+ if file.filename:
166
+ if file.filename.lower().endswith(('.tif', '.tiff')):
167
+ file_extension = '.tiff'
168
+ elif file.filename.lower().endswith(('.png', '.PNG')):
169
+ file_extension = '.png'
170
+ elif file.filename.lower().endswith(('.jpeg', '.jpg', '.JPG', '.JPEG')):
171
+ file_extension = '.jpg'
172
 
173
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
174
  contents = await file.read()
 
181
  if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
182
  raise ValueError("Failed to create temporary file")
183
 
184
+ logger.info(f"Processing pipeline for file: {file.filename}, temp path: {tmp_file_path}")
185
+
186
  # Run the full pipeline
187
  results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
188
 
 
207
  growth_stages = results.boxes.data[:, 6:].tolist()
208
  results_dict["growth_stages"] = growth_stages
209
 
210
+ logger.info(f"Pipeline prediction completed successfully")
211
  return JSONResponse(content=results_dict)
212
 
213
  finally:
 
216
  os.unlink(tmp_file_path)
217
 
218
  except Exception as e:
219
+ logger.error(f"Error in predict_pipeline_api: {e}")
220
  return JSONResponse(status_code=500, content={"error": str(e)})