ahadhassan commited on
Commit
3b61715
·
verified ·
1 Parent(s): ef5a350

ahad_dev (#9)

Browse files

- Added pipeline inference code (70135b4a49e7b10a9d690c0c187e251e9958813d)

Files changed (3) hide show
  1. app.py +42 -59
  2. test_yolo_api.py +16 -0
  3. yolo_predictor.py +149 -11
app.py CHANGED
@@ -162,73 +162,56 @@ async def predict_yolo_api(file: UploadFile = File(...)):
162
 
163
  @app.post("/predict_pipeline/")
164
  async def predict_pipeline_api(file: UploadFile = File(...)):
165
- """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
166
  if ndvi_model is None or yolo_model is None:
167
  return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
168
 
169
  try:
170
- # Save uploaded file temporarily with proper extension
171
- file_extension = '.jpg'
172
- if file.filename:
173
- if file.filename.lower().endswith(('.tif', '.tiff')):
174
- file_extension = '.tiff'
175
- elif file.filename.lower().endswith(('.png', '.PNG')):
176
- file_extension = '.png'
177
- elif file.filename.lower().endswith(('.jpeg', '.jpg', '.JPG', '.JPEG')):
178
- file_extension = '.jpg'
179
 
180
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
181
- contents = await file.read()
182
- tmp_file.write(contents)
183
- tmp_file.flush() # Ensure data is written
184
- tmp_file_path = tmp_file.name
185
 
186
- try:
187
- # Verify the file was written correctly
188
- if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
189
- raise ValueError("Failed to create temporary file")
190
-
191
- logger.info(f"Processing pipeline for file: {file.filename}, temp path: {tmp_file_path}")
192
-
193
- # Run the full pipeline
194
- results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
195
-
196
- # Convert results to JSON-serializable format
197
- results_dict = {
198
- "boxes": {
199
- "xyxyn": results.boxes.xyxyn.tolist() if results.boxes is not None else None,
200
- "conf": results.boxes.conf.tolist() if results.boxes is not None else None,
201
- "cls": results.boxes.cls.tolist() if results.boxes is not None else None
202
- },
203
- "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
204
- "names": results.names,
205
- "orig_shape": results.orig_shape,
206
- "speed": results.speed,
207
- # "masks": {
208
- # "data": results.masks.data.tolist() if results.masks is not None else None,
209
- # "orig_shape": results.masks.orig_shape if results.masks is not None else None,
210
- # "xy": [seg.tolist() for seg in results.masks.xy] if results.masks is not None else None,
211
- # "xyn": [seg.tolist() for seg in results.masks.xyn] if results.masks is not None else None
212
- # }
213
  }
 
214
 
215
-
216
- # Handle growth stages if present in the results
217
- if hasattr(results, 'boxes') and results.boxes is not None:
218
- if hasattr(results.boxes, 'data') and len(results.boxes.data) > 0:
219
- # Check if there are additional columns for growth stages
220
- if results.boxes.data.shape[1] > 6:
221
- growth_stages = results.boxes.data[:, 6:].tolist()
222
- results_dict["growth_stages"] = growth_stages
223
-
224
- logger.info(f"Pipeline prediction completed successfully")
225
- return JSONResponse(content=results_dict)
226
-
227
- finally:
228
- # Clean up temporary file
229
- if os.path.exists(tmp_file_path):
230
- os.unlink(tmp_file_path)
231
-
232
  except Exception as e:
233
  logger.error(f"Error in predict_pipeline_api: {e}")
234
  return JSONResponse(status_code=500, content={"error": str(e)})
 
162
 
163
  @app.post("/predict_pipeline/")
164
  async def predict_pipeline_api(file: UploadFile = File(...)):
165
+ """Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction"""
166
  if ndvi_model is None or yolo_model is None:
167
  return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
168
 
169
  try:
170
+ logger.info(f"Starting full pipeline for file: {file.filename}")
 
 
 
 
 
 
 
 
171
 
172
+ # Read uploaded RGB image
173
+ contents = await file.read()
174
+ logger.info(f"Read {len(contents)} bytes from uploaded file")
 
 
175
 
176
+ # Convert to PIL Image and then to numpy array
177
+ img = Image.open(BytesIO(contents)).convert("RGB")
178
+ rgb_array = np.array(img)
179
+ logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")
180
+
181
+ # Run the full pipeline
182
+ results = predict_pipeline(ndvi_model, yolo_model, rgb_array)
183
+ logger.info("Pipeline processing completed successfully")
184
+
185
+ # Convert results to JSON-serializable format
186
+ results_dict = {
187
+ "boxes": {
188
+ "xyxyn": results.boxes.xyxyn.tolist() if results.boxes is not None else None,
189
+ "conf": results.boxes.conf.tolist() if results.boxes is not None else None,
190
+ "cls": results.boxes.cls.tolist() if results.boxes is not None else None
191
+ },
192
+ "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
193
+ "names": results.names,
194
+ "orig_shape": results.orig_shape,
195
+ "speed": results.speed,
196
+ "masks": {
197
+ "data": results.masks.data.tolist() if results.masks is not None else None,
198
+ "orig_shape": results.masks.orig_shape if results.masks is not None else None,
199
+ "xy": [seg.tolist() for seg in results.masks.xy] if results.masks is not None else None,
200
+ "xyn": [seg.tolist() for seg in results.masks.xyn] if results.masks is not None else None
 
 
201
  }
202
+ }
203
 
204
+ # Handle growth stages if present in the results
205
+ if hasattr(results, 'boxes') and results.boxes is not None:
206
+ if hasattr(results.boxes, 'data') and len(results.boxes.data) > 0:
207
+ # Check if there are additional columns for growth stages
208
+ if results.boxes.data.shape[1] > 6:
209
+ growth_stages = results.boxes.data[:, 6:].tolist()
210
+ results_dict["growth_stages"] = growth_stages
211
+
212
+ logger.info(f"Pipeline prediction completed successfully with {len(results_dict['boxes']['xyxyn']) if results_dict['boxes']['xyxyn'] else 0} detections")
213
+ return JSONResponse(content=results_dict)
214
+
 
 
 
 
 
 
215
  except Exception as e:
216
  logger.error(f"Error in predict_pipeline_api: {e}")
217
  return JSONResponse(status_code=500, content={"error": str(e)})
test_yolo_api.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ # Path to your 4-channel TIFF image
4
+ file_path = r"D:\AgriTech\Agri-Drone\fc_regress_train\train\images\IMG_0009.tif"
5
+
6
+ # API endpoint
7
+ url = "https://agri-tech-testing-pipeline-api.hf.space/predict_yolo/"
8
+
9
+ # Send the POST request with the image file
10
+ with open(file_path, "rb") as f:
11
+ files = {"file": (file_path.split("\\")[-1], f, "image/tiff")}
12
+ response = requests.post(url, files=files)
13
+
14
+ # Print response
15
+ print("Status Code:", response.status_code)
16
+ print("Response JSON:", response.json())
yolo_predictor.py CHANGED
@@ -1,9 +1,12 @@
1
  # yolo_predictor.py
2
  import os
3
  import logging
4
- import rasterio
5
- from ultralytics import YOLO
6
  import tifffile
 
 
 
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -11,6 +14,7 @@ logger = logging.getLogger(__name__)
11
 
12
  def load_yolo_model(model_path):
13
  """Load YOLO model from .pt file"""
 
14
  return YOLO(model_path)
15
 
16
  def predict_yolo(yolo_model, image_path, conf=0.01):
@@ -27,6 +31,28 @@ def predict_yolo(yolo_model, image_path, conf=0.01):
27
  """
28
  logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  logger.info("Running YOLO model inference...")
31
  # Run YOLO prediction directly on the input file
32
  results = yolo_model([image_path], conf=conf)
@@ -34,21 +60,133 @@ def predict_yolo(yolo_model, image_path, conf=0.01):
34
  logger.info(f"YOLO prediction completed. Results type: {type(results[0])}")
35
  return results[0] # Return first result
36
 
37
- def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.01):
38
  """
39
- Simplified pipeline: Validate input -> Run YOLO prediction
40
 
41
  Args:
42
- ndvi_model: Not used (kept for API compatibility)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  yolo_model: Loaded YOLO model
44
- image_path: Path to input 4-channel TIFF image
45
  conf: Confidence threshold for YOLO
46
 
47
  Returns:
48
  results: YOLO results object
49
  """
50
- logger.info(f"Starting prediction pipeline for: {image_path}")
51
- # Simply validate and run prediction on the uploaded file
52
- result = predict_yolo(yolo_model, image_path, conf=conf)
53
- logger.info("Prediction pipeline completed successfully")
54
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # yolo_predictor.py
2
  import os
3
  import logging
4
+ import tempfile
5
+ import numpy as np
6
  import tifffile
7
+ from rasterio.transform import from_bounds
8
+ from ultralytics import YOLO
9
+ from ndvi_predictor import normalize_rgb, predict_ndvi
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
14
 
15
  def load_yolo_model(model_path):
16
  """Load YOLO model from .pt file"""
17
+ logger.info(f"Loading YOLO model from: {model_path}")
18
  return YOLO(model_path)
19
 
20
  def predict_yolo(yolo_model, image_path, conf=0.01):
 
31
  """
32
  logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")
33
 
34
+ # Verify file exists and has correct format
35
+ if not os.path.exists(image_path):
36
+ raise FileNotFoundError(f"Image file not found: {image_path}")
37
+
38
+ try:
39
+ # Quick validation of the TIFF file
40
+ test_array = tifffile.imread(image_path)
41
+ logger.info(f"TIFF file shape: {test_array.shape}, dtype: {test_array.dtype}")
42
+
43
+ # Validate channels
44
+ if len(test_array.shape) == 3:
45
+ channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
46
+ else:
47
+ channels = 1
48
+
49
+ if channels != 4:
50
+ raise ValueError(f"Expected 4-channel image, got {channels} channels")
51
+
52
+ except Exception as e:
53
+ logger.error(f"Error validating TIFF file: {e}")
54
+ raise
55
+
56
  logger.info("Running YOLO model inference...")
57
  # Run YOLO prediction directly on the input file
58
  results = yolo_model([image_path], conf=conf)
 
60
  logger.info(f"YOLO prediction completed. Results type: {type(results[0])}")
61
  return results[0] # Return first result
62
 
63
+ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
64
  """
65
+ Create a 4-channel TIFF file with RGB channels + NDVI channel
66
 
67
  Args:
68
+ rgb_array: RGB image array (H, W, 3)
69
+ ndvi_array: NDVI array (H, W) with values in [-1, 1]
70
+ output_path: Path to save the 4-channel TIFF
71
+ """
72
+ logger.info(f"Creating 4-channel TIFF file at: {output_path}")
73
+ logger.info(f"RGB shape: {rgb_array.shape}, NDVI shape: {ndvi_array.shape}")
74
+
75
+ # Ensure RGB is in uint8 format
76
+ if rgb_array.dtype != np.uint8:
77
+ if rgb_array.max() <= 1.0:
78
+ rgb_uint8 = (rgb_array * 255).astype(np.uint8)
79
+ else:
80
+ rgb_uint8 = rgb_array.astype(np.uint8)
81
+ else:
82
+ rgb_uint8 = rgb_array
83
+
84
+ # Convert NDVI from [-1, 1] to [0, 255] uint8 format (same as reference code)
85
+ ndvi_scaled = (((ndvi_array + 1) / 2) * 255).astype(np.uint8)
86
+
87
+ logger.info(f"RGB range: [{rgb_uint8.min()}, {rgb_uint8.max()}]")
88
+ logger.info(f"NDVI scaled range: [{ndvi_scaled.min()}, {ndvi_scaled.max()}]")
89
+
90
+ # Stack RGB + NDVI to create 4-channel image
91
+ # Format: (channels, height, width) - channel-first format
92
+ four_channel = np.stack([
93
+ rgb_uint8[:, :, 0], # R channel
94
+ rgb_uint8[:, :, 1], # G channel
95
+ rgb_uint8[:, :, 2], # B channel
96
+ ndvi_scaled # NDVI channel
97
+ ], axis=0)
98
+
99
+ logger.info(f"4-channel array shape: {four_channel.shape}, dtype: {four_channel.dtype}")
100
+ logger.info(f"4-channel range: [{four_channel.min()}, {four_channel.max()}]")
101
+
102
+ # Save as TIFF using tifffile
103
+ tifffile.imwrite(output_path, four_channel)
104
+ logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
105
+
106
+ def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.01):
107
+ """
108
+ Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction
109
+
110
+ Args:
111
+ ndvi_model: Loaded NDVI prediction model
112
  yolo_model: Loaded YOLO model
113
+ rgb_array: RGB image as numpy array (H, W, 3)
114
  conf: Confidence threshold for YOLO
115
 
116
  Returns:
117
  results: YOLO results object
118
  """
119
+ logger.info("Starting full prediction pipeline")
120
+ logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")
121
+
122
+ # Step 1: Normalize RGB image
123
+ logger.info("Step 1: Normalizing RGB image")
124
+ normalized_rgb = normalize_rgb(rgb_array)
125
+ logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")
126
+
127
+ # Step 2: Predict NDVI
128
+ logger.info("Step 2: Predicting NDVI from RGB")
129
+ ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
130
+ logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")
131
+
132
+ # Step 3: Create 4-channel TIFF file
133
+ logger.info("Step 3: Creating 4-channel TIFF file (BGR+NDVI)")
134
+
135
+ # Create temporary file for the 4-channel TIFF
136
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
137
+ tiff_path = tmp_file.name
138
+
139
+ try:
140
+ # Create the 4-channel TIFF
141
+ create_4channel_tiff(rgb_array, ndvi_prediction, tiff_path)
142
+
143
+ # Verify the created file
144
+ if not os.path.exists(tiff_path):
145
+ raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")
146
+
147
+ file_size = os.path.getsize(tiff_path)
148
+ logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")
149
+
150
+ # Step 4: Run YOLO prediction on the 4-channel TIFF
151
+ logger.info("Step 4: Running YOLO prediction on 4-channel TIFF")
152
+ results = predict_yolo(yolo_model, tiff_path, conf=conf)
153
+
154
+ logger.info("Full pipeline completed successfully")
155
+ return results
156
+
157
+ except Exception as e:
158
+ logger.error(f"Error in pipeline: {e}")
159
+ raise
160
+ finally:
161
+ # Clean up temporary file
162
+ if os.path.exists(tiff_path):
163
+ try:
164
+ os.unlink(tiff_path)
165
+ logger.info(f"Cleaned up temporary file: {tiff_path}")
166
+ except Exception as cleanup_error:
167
+ logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
168
+
169
+ def validate_4channel_tiff(tiff_path):
170
+ """
171
+ Validate that a TIFF file has exactly 4 channels
172
+
173
+ Args:
174
+ tiff_path: Path to TIFF file
175
+
176
+ Returns:
177
+ bool: True if valid 4-channel TIFF, False otherwise
178
+ """
179
+ try:
180
+ array = tifffile.imread(tiff_path)
181
+
182
+ if len(array.shape) == 3:
183
+ channels = array.shape[0] if array.shape[0] <= 4 else array.shape[2]
184
+ else:
185
+ channels = 1
186
+
187
+ logger.info(f"TIFF validation - Shape: {array.shape}, Channels: {channels}")
188
+ return channels == 4
189
+
190
+ except Exception as e:
191
+ logger.error(f"Error validating TIFF file: {e}")
192
+ return False