ahadhassan commited on
Commit
fc5324e
·
verified ·
1 Parent(s): eeac6a0

Create yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +168 -0
yolo_predictor.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # yolo_predictor.py
2
+ import os
3
+ import numpy as np
4
+ import rasterio
5
+ from ultralytics import YOLO
6
+ from ndvi_predictor import normalize_rgb, predict_ndvi
7
+ import tempfile
8
+ from rasterio.transform import from_bounds
9
+ from PIL import Image
10
+
11
+ def load_yolo_model(model_path):
12
+ """Load YOLO model from .pt file"""
13
+ return YOLO(model_path)
14
+
15
+ def predict_ndvi_from_rgb(ndvi_model, rgb_array):
16
+ """
17
+ Predict NDVI channel from RGB array
18
+
19
+ Args:
20
+ ndvi_model: Loaded NDVI prediction model
21
+ rgb_array: RGB image as numpy array (H, W, 3)
22
+
23
+ Returns:
24
+ ndvi_array: Predicted NDVI as numpy array (H, W)
25
+ """
26
+ # Normalize RGB input
27
+ norm_rgb = normalize_rgb(rgb_array)
28
+
29
+ # Predict NDVI
30
+ ndvi_pred = predict_ndvi(ndvi_model, norm_rgb)
31
+
32
+ return ndvi_pred
33
+
34
+ def predict_yolo(yolo_model, image_path, conf=0.001):
35
+ """
36
+ Predict using YOLO model on 4-channel TIFF image
37
+
38
+ Args:
39
+ yolo_model: Loaded YOLO model
40
+ image_path: Path to 4-channel TIFF image
41
+ conf: Confidence threshold
42
+
43
+ Returns:
44
+ results: YOLO results object
45
+ """
46
+ # Run YOLO prediction
47
+ results = yolo_model([image_path], conf=conf)
48
+
49
+ return results[0] # Return first result
50
+
51
+ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
52
+ """
53
+ Create a 4-channel TIFF file from RGB and NDVI arrays
54
+
55
+ Args:
56
+ rgb_array: RGB image as numpy array (H, W, 3)
57
+ ndvi_array: NDVI image as numpy array (H, W)
58
+ output_path: Path to save the 4-channel TIFF
59
+ """
60
+ height, width = rgb_array.shape[:2]
61
+
62
+ # Stack RGB and NDVI to create 4-channel image
63
+ four_channel = np.zeros((height, width, 4), dtype=rgb_array.dtype)
64
+ four_channel[:, :, :3] = rgb_array # RGB channels
65
+
66
+ # Normalize NDVI to match RGB data type range
67
+ if rgb_array.dtype == np.uint8:
68
+ # Scale NDVI from [-1, 1] to [0, 255]
69
+ ndvi_scaled = ((ndvi_array + 1) * 127.5).astype(np.uint8)
70
+ else:
71
+ # Keep NDVI in original range for float types
72
+ ndvi_scaled = ndvi_array.astype(rgb_array.dtype)
73
+
74
+ four_channel[:, :, 3] = ndvi_scaled # NDVI channel
75
+
76
+ # Create transform (assuming no specific georeferencing needed)
77
+ transform = from_bounds(0, 0, width, height, width, height)
78
+
79
+ # Write 4-channel TIFF
80
+ with rasterio.open(
81
+ output_path,
82
+ 'w',
83
+ driver='GTiff',
84
+ height=height,
85
+ width=width,
86
+ count=4,
87
+ dtype=four_channel.dtype,
88
+ transform=transform
89
+ ) as dst:
90
+ for i in range(4):
91
+ dst.write(four_channel[:, :, i], i + 1)
92
+
93
+ def load_4channel_tiff(image_path):
94
+ """
95
+ Load a 4-channel TIFF image
96
+
97
+ Args:
98
+ image_path: Path to 4-channel TIFF image
99
+
100
+ Returns:
101
+ rgb_array: RGB channels as numpy array (H, W, 3)
102
+ ndvi_array: NDVI channel as numpy array (H, W)
103
+ """
104
+ with rasterio.open(image_path) as src:
105
+ # Read all 4 channels
106
+ channels = src.read() # Shape: (4, H, W)
107
+
108
+ # Extract RGB and NDVI
109
+ rgb_array = np.transpose(channels[:3], (1, 2, 0)) # (H, W, 3)
110
+ ndvi_array = channels[3] # (H, W)
111
+
112
+ # If NDVI was scaled to uint8, convert back to [-1, 1] range
113
+ if channels.dtype == np.uint8:
114
+ ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
115
+
116
+ return rgb_array, ndvi_array
117
+
118
+ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
119
+ """
120
+ Full pipeline: Load 4-channel image -> Extract RGB -> Predict NDVI ->
121
+ Create new 4-channel with predicted NDVI -> Run YOLO prediction
122
+
123
+ Args:
124
+ ndvi_model: Loaded NDVI prediction model
125
+ yolo_model: Loaded YOLO model
126
+ image_path: Path to input image (can be RGB or 4-channel TIFF)
127
+ conf: Confidence threshold for YOLO
128
+
129
+ Returns:
130
+ results: YOLO results object
131
+ """
132
+ # Try to load as 4-channel TIFF first, fall back to RGB
133
+ try:
134
+ with rasterio.open(image_path) as src:
135
+ if src.count == 4:
136
+ # Load 4-channel TIFF
137
+ rgb_array, _ = load_4channel_tiff(image_path)
138
+ elif src.count == 3:
139
+ # Load as RGB TIFF
140
+ channels = src.read()
141
+ rgb_array = np.transpose(channels, (1, 2, 0))
142
+ else:
143
+ raise ValueError(f"Unsupported number of channels: {src.count}")
144
+ except:
145
+ # Fall back to PIL for standard image formats
146
+ img = Image.open(image_path).convert("RGB")
147
+ rgb_array = np.array(img)
148
+
149
+ # Predict NDVI from RGB
150
+ ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)
151
+
152
+ # Create temporary 4-channel TIFF file
153
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
154
+ temp_4ch_path = tmp_file.name
155
+
156
+ try:
157
+ # Create 4-channel TIFF with predicted NDVI
158
+ create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
159
+
160
+ # Run YOLO prediction on 4-channel image
161
+ results = predict_yolo(yolo_model, temp_4ch_path, conf=conf)
162
+
163
+ return results
164
+
165
+ finally:
166
+ # Clean up temporary file
167
+ if os.path.exists(temp_4ch_path):
168
+ os.unlink(temp_4ch_path)