Spaces:
Sleeping
Sleeping
Update yolo_predictor.py
Browse files- yolo_predictor.py +23 -3
yolo_predictor.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
| 1 |
# yolo_predictor.py
|
| 2 |
import os
|
|
|
|
| 3 |
import rasterio
|
| 4 |
from ultralytics import YOLO
|
| 5 |
import tifffile
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
def load_yolo_model(model_path):
|
| 8 |
"""Load YOLO model from .pt file"""
|
| 9 |
return YOLO(model_path)
|
|
@@ -22,8 +27,11 @@ def validate_4channel_tiff(image_path):
|
|
| 22 |
ValueError: If validation fails
|
| 23 |
"""
|
| 24 |
if not os.path.exists(image_path):
|
|
|
|
| 25 |
raise ValueError(f"Image file does not exist: {image_path}")
|
| 26 |
|
|
|
|
|
|
|
| 27 |
try:
|
| 28 |
# Primary validation with tifffile
|
| 29 |
img_array = tifffile.imread(image_path)
|
|
@@ -42,25 +50,30 @@ def validate_4channel_tiff(image_path):
|
|
| 42 |
channels = min(img_array.shape[0], img_array.shape[2])
|
| 43 |
height, width = img_array.shape[0], img_array.shape[1]
|
| 44 |
else:
|
|
|
|
| 45 |
raise ValueError(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
|
| 46 |
|
| 47 |
if channels != 4:
|
|
|
|
| 48 |
raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
|
| 49 |
|
| 50 |
-
|
| 51 |
return True
|
| 52 |
|
| 53 |
except Exception as e:
|
|
|
|
| 54 |
# Fallback validation with rasterio
|
| 55 |
try:
|
| 56 |
with rasterio.open(image_path) as src:
|
| 57 |
if src.count != 4:
|
|
|
|
| 58 |
raise ValueError(f"YOLO model expects 4-channel images, but got {src.count} channels")
|
| 59 |
|
| 60 |
-
|
| 61 |
return True
|
| 62 |
|
| 63 |
except Exception as e2:
|
|
|
|
| 64 |
raise ValueError(f"Could not validate TIFF file. Errors: tifffile={str(e)}, rasterio={str(e2)}")
|
| 65 |
|
| 66 |
def predict_yolo(yolo_model, image_path, conf=0.001):
|
|
@@ -75,12 +88,16 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
|
|
| 75 |
Returns:
|
| 76 |
results: YOLO results object
|
| 77 |
"""
|
|
|
|
|
|
|
| 78 |
# Validate input file
|
| 79 |
validate_4channel_tiff(image_path)
|
| 80 |
|
|
|
|
| 81 |
# Run YOLO prediction directly on the input file
|
| 82 |
results = yolo_model([image_path], conf=conf)
|
| 83 |
|
|
|
|
| 84 |
return results[0] # Return first result
|
| 85 |
|
| 86 |
def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
@@ -96,5 +113,8 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
| 96 |
Returns:
|
| 97 |
results: YOLO results object
|
| 98 |
"""
|
|
|
|
| 99 |
# Simply validate and run prediction on the uploaded file
|
| 100 |
-
|
|
|
|
|
|
|
|
|
| 1 |
# yolo_predictor.py
|
| 2 |
import os
|
| 3 |
+
import logging
|
| 4 |
import rasterio
|
| 5 |
from ultralytics import YOLO
|
| 6 |
import tifffile
|
| 7 |
|
| 8 |
+
# Configure logging
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
def load_yolo_model(model_path):
|
| 13 |
"""Load YOLO model from .pt file"""
|
| 14 |
return YOLO(model_path)
|
|
|
|
| 27 |
ValueError: If validation fails
|
| 28 |
"""
|
| 29 |
if not os.path.exists(image_path):
|
| 30 |
+
logger.error(f"Image file does not exist: {image_path}")
|
| 31 |
raise ValueError(f"Image file does not exist: {image_path}")
|
| 32 |
|
| 33 |
+
logger.info(f"Validating TIFF file: {image_path}")
|
| 34 |
+
|
| 35 |
try:
|
| 36 |
# Primary validation with tifffile
|
| 37 |
img_array = tifffile.imread(image_path)
|
|
|
|
| 50 |
channels = min(img_array.shape[0], img_array.shape[2])
|
| 51 |
height, width = img_array.shape[0], img_array.shape[1]
|
| 52 |
else:
|
| 53 |
+
logger.error(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
|
| 54 |
raise ValueError(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
|
| 55 |
|
| 56 |
if channels != 4:
|
| 57 |
+
logger.error(f"YOLO model expects 4-channel images, but got {channels} channels")
|
| 58 |
raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
|
| 59 |
|
| 60 |
+
logger.info(f"Validation successful: {channels} channels, {height}x{width}, dtype: {img_array.dtype}")
|
| 61 |
return True
|
| 62 |
|
| 63 |
except Exception as e:
|
| 64 |
+
logger.warning(f"Tifffile validation failed: {str(e)}, trying rasterio fallback")
|
| 65 |
# Fallback validation with rasterio
|
| 66 |
try:
|
| 67 |
with rasterio.open(image_path) as src:
|
| 68 |
if src.count != 4:
|
| 69 |
+
logger.error(f"YOLO model expects 4-channel images, but got {src.count} channels")
|
| 70 |
raise ValueError(f"YOLO model expects 4-channel images, but got {src.count} channels")
|
| 71 |
|
| 72 |
+
logger.info(f"Validation successful (rasterio): {src.count} channels, {src.width}x{src.height}, dtype: {src.dtypes[0]}")
|
| 73 |
return True
|
| 74 |
|
| 75 |
except Exception as e2:
|
| 76 |
+
logger.error(f"Could not validate TIFF file. Tifffile error: {str(e)}, Rasterio error: {str(e2)}")
|
| 77 |
raise ValueError(f"Could not validate TIFF file. Errors: tifffile={str(e)}, rasterio={str(e2)}")
|
| 78 |
|
| 79 |
def predict_yolo(yolo_model, image_path, conf=0.001):
|
|
|
|
| 88 |
Returns:
|
| 89 |
results: YOLO results object
|
| 90 |
"""
|
| 91 |
+
logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")
|
| 92 |
+
|
| 93 |
# Validate input file
|
| 94 |
validate_4channel_tiff(image_path)
|
| 95 |
|
| 96 |
+
logger.info("Running YOLO model inference...")
|
| 97 |
# Run YOLO prediction directly on the input file
|
| 98 |
results = yolo_model([image_path], conf=conf)
|
| 99 |
|
| 100 |
+
logger.info(f"YOLO prediction completed. Results type: {type(results[0])}")
|
| 101 |
return results[0] # Return first result
|
| 102 |
|
| 103 |
def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
|
|
| 113 |
Returns:
|
| 114 |
results: YOLO results object
|
| 115 |
"""
|
| 116 |
+
logger.info(f"Starting prediction pipeline for: {image_path}")
|
| 117 |
# Simply validate and run prediction on the uploaded file
|
| 118 |
+
result = predict_yolo(yolo_model, image_path, conf=conf)
|
| 119 |
+
logger.info("Prediction pipeline completed successfully")
|
| 120 |
+
return result
|