ahadhassan commited on
Commit
c421799
·
verified ·
1 Parent(s): be36db0

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +23 -3
yolo_predictor.py CHANGED
@@ -1,9 +1,14 @@
1
  # yolo_predictor.py
2
  import os
 
3
  import rasterio
4
  from ultralytics import YOLO
5
  import tifffile
6
 
 
 
 
 
7
  def load_yolo_model(model_path):
8
  """Load YOLO model from .pt file"""
9
  return YOLO(model_path)
@@ -22,8 +27,11 @@ def validate_4channel_tiff(image_path):
22
  ValueError: If validation fails
23
  """
24
  if not os.path.exists(image_path):
 
25
  raise ValueError(f"Image file does not exist: {image_path}")
26
 
 
 
27
  try:
28
  # Primary validation with tifffile
29
  img_array = tifffile.imread(image_path)
@@ -42,25 +50,30 @@ def validate_4channel_tiff(image_path):
42
  channels = min(img_array.shape[0], img_array.shape[2])
43
  height, width = img_array.shape[0], img_array.shape[1]
44
  else:
 
45
  raise ValueError(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
46
 
47
  if channels != 4:
 
48
  raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
49
 
50
- print(f"Validation successful: {channels} channels, {height}x{width}, dtype: {img_array.dtype}")
51
  return True
52
 
53
  except Exception as e:
 
54
  # Fallback validation with rasterio
55
  try:
56
  with rasterio.open(image_path) as src:
57
  if src.count != 4:
 
58
  raise ValueError(f"YOLO model expects 4-channel images, but got {src.count} channels")
59
 
60
- print(f"Validation successful (rasterio): {src.count} channels, {src.width}x{src.height}, dtype: {src.dtypes[0]}")
61
  return True
62
 
63
  except Exception as e2:
 
64
  raise ValueError(f"Could not validate TIFF file. Errors: tifffile={str(e)}, rasterio={str(e2)}")
65
 
66
  def predict_yolo(yolo_model, image_path, conf=0.001):
@@ -75,12 +88,16 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
75
  Returns:
76
  results: YOLO results object
77
  """
 
 
78
  # Validate input file
79
  validate_4channel_tiff(image_path)
80
 
 
81
  # Run YOLO prediction directly on the input file
82
  results = yolo_model([image_path], conf=conf)
83
 
 
84
  return results[0] # Return first result
85
 
86
  def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
@@ -96,5 +113,8 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
96
  Returns:
97
  results: YOLO results object
98
  """
 
99
  # Simply validate and run prediction on the uploaded file
100
- return predict_yolo(yolo_model, image_path, conf=conf)
 
 
 
1
  # yolo_predictor.py
2
  import os
3
+ import logging
4
  import rasterio
5
  from ultralytics import YOLO
6
  import tifffile
7
 
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
10
+ logger = logging.getLogger(__name__)
11
+
12
  def load_yolo_model(model_path):
13
  """Load YOLO model from .pt file"""
14
  return YOLO(model_path)
 
27
  ValueError: If validation fails
28
  """
29
  if not os.path.exists(image_path):
30
+ logger.error(f"Image file does not exist: {image_path}")
31
  raise ValueError(f"Image file does not exist: {image_path}")
32
 
33
+ logger.info(f"Validating TIFF file: {image_path}")
34
+
35
  try:
36
  # Primary validation with tifffile
37
  img_array = tifffile.imread(image_path)
 
50
  channels = min(img_array.shape[0], img_array.shape[2])
51
  height, width = img_array.shape[0], img_array.shape[1]
52
  else:
53
+ logger.error(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
54
  raise ValueError(f"Invalid image shape: {img_array.shape}. Expected 3D array with 4 channels.")
55
 
56
  if channels != 4:
57
+ logger.error(f"YOLO model expects 4-channel images, but got {channels} channels")
58
  raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
59
 
60
+ logger.info(f"Validation successful: {channels} channels, {height}x{width}, dtype: {img_array.dtype}")
61
  return True
62
 
63
  except Exception as e:
64
+ logger.warning(f"Tifffile validation failed: {str(e)}, trying rasterio fallback")
65
  # Fallback validation with rasterio
66
  try:
67
  with rasterio.open(image_path) as src:
68
  if src.count != 4:
69
+ logger.error(f"YOLO model expects 4-channel images, but got {src.count} channels")
70
  raise ValueError(f"YOLO model expects 4-channel images, but got {src.count} channels")
71
 
72
+ logger.info(f"Validation successful (rasterio): {src.count} channels, {src.width}x{src.height}, dtype: {src.dtypes[0]}")
73
  return True
74
 
75
  except Exception as e2:
76
+ logger.error(f"Could not validate TIFF file. Tifffile error: {str(e)}, Rasterio error: {str(e2)}")
77
  raise ValueError(f"Could not validate TIFF file. Errors: tifffile={str(e)}, rasterio={str(e2)}")
78
 
79
  def predict_yolo(yolo_model, image_path, conf=0.001):
 
88
  Returns:
89
  results: YOLO results object
90
  """
91
+ logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")
92
+
93
  # Validate input file
94
  validate_4channel_tiff(image_path)
95
 
96
+ logger.info("Running YOLO model inference...")
97
  # Run YOLO prediction directly on the input file
98
  results = yolo_model([image_path], conf=conf)
99
 
100
+ logger.info(f"YOLO prediction completed. Results type: {type(results[0])}")
101
  return results[0] # Return first result
102
 
103
  def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
 
113
  Returns:
114
  results: YOLO results object
115
  """
116
+ logger.info(f"Starting prediction pipeline for: {image_path}")
117
  # Simply validate and run prediction on the uploaded file
118
+ result = predict_yolo(yolo_model, image_path, conf=conf)
119
+ logger.info("Prediction pipeline completed successfully")
120
+ return result