Spaces:
Sleeping
Sleeping
Update yolo_predictor.py
Browse files- yolo_predictor.py +62 -44
yolo_predictor.py
CHANGED
|
@@ -5,7 +5,6 @@ import rasterio
|
|
| 5 |
from ultralytics import YOLO
|
| 6 |
from ndvi_predictor import normalize_rgb, predict_ndvi
|
| 7 |
import tempfile
|
| 8 |
-
from rasterio.transform import from_bounds
|
| 9 |
from PIL import Image
|
| 10 |
import tifffile
|
| 11 |
|
|
@@ -46,38 +45,60 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
|
|
| 46 |
"""
|
| 47 |
# Verify the image has 4 channels before prediction
|
| 48 |
try:
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
if hasattr(img, 'n_frames'):
|
| 52 |
-
# Multi-frame TIFF
|
| 53 |
-
channels = img.n_frames
|
| 54 |
-
else:
|
| 55 |
-
# Regular image
|
| 56 |
-
channels = len(img.getbands()) if hasattr(img, 'getbands') else 3
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
if
|
| 60 |
-
img_array =
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
channels = img_array.shape[0] if img_array.shape[0] <= 4 else img_array.shape[2]
|
| 68 |
-
else:
|
| 69 |
-
channels = 1
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
|
|
|
|
| 81 |
|
| 82 |
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
|
| 83 |
"""
|
|
@@ -90,7 +111,7 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
|
|
| 90 |
"""
|
| 91 |
height, width = rgb_array.shape[:2]
|
| 92 |
|
| 93 |
-
# Ensure RGB is in uint8 format
|
| 94 |
if rgb_array.dtype != np.uint8:
|
| 95 |
if rgb_array.max() <= 1.0:
|
| 96 |
rgb_normalized = (rgb_array * 255).astype(np.uint8)
|
|
@@ -109,13 +130,14 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
|
|
| 109 |
four_channel[:, :, 2] = rgb_normalized[:, :, 2] # Blue
|
| 110 |
four_channel[:, :, 3] = ndvi_normalized # NDVI
|
| 111 |
|
| 112 |
-
# Save using tifffile with
|
| 113 |
tifffile.imwrite(
|
| 114 |
-
output_path,
|
| 115 |
-
four_channel,
|
| 116 |
photometric='rgb',
|
| 117 |
compress='lzw',
|
| 118 |
-
metadata={'axes': 'YXC'}
|
|
|
|
| 119 |
)
|
| 120 |
|
| 121 |
def load_4channel_tiff(image_path):
|
|
@@ -225,7 +247,7 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
| 225 |
except Exception as e2:
|
| 226 |
try:
|
| 227 |
# Method 3: Fall back to PIL for standard formats
|
| 228 |
-
img = Image.
|
| 229 |
if img.mode != 'RGB':
|
| 230 |
img = img.convert('RGB')
|
| 231 |
rgb_array = np.array(img)
|
|
@@ -238,12 +260,10 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
| 238 |
|
| 239 |
# Ensure RGB is in correct format and range
|
| 240 |
if rgb_array.dtype == np.uint8:
|
| 241 |
-
# Keep as uint8 but also create float version for NDVI prediction
|
| 242 |
rgb_float = rgb_array.astype(np.float32) / 255.0
|
| 243 |
else:
|
| 244 |
-
# Already float, ensure range is [0, 1]
|
| 245 |
if rgb_array.max() > 1.0:
|
| 246 |
-
rgb_float = rgb_array /
|
| 247 |
else:
|
| 248 |
rgb_float = rgb_array
|
| 249 |
rgb_array = (rgb_float * 255).astype(np.uint8)
|
|
@@ -261,13 +281,11 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
|
|
| 261 |
|
| 262 |
# Verify the created file can be read
|
| 263 |
try:
|
| 264 |
-
|
| 265 |
-
if
|
| 266 |
-
channels =
|
| 267 |
else:
|
| 268 |
-
channels = len(
|
| 269 |
-
test_img.close()
|
| 270 |
-
|
| 271 |
if channels != 4:
|
| 272 |
raise ValueError(f"Created TIFF has {channels} channels instead of 4")
|
| 273 |
|
|
|
|
| 5 |
from ultralytics import YOLO
|
| 6 |
from ndvi_predictor import normalize_rgb, predict_ndvi
|
| 7 |
import tempfile
|
|
|
|
| 8 |
from PIL import Image
|
| 9 |
import tifffile
|
| 10 |
|
|
|
|
| 45 |
"""
|
| 46 |
# Verify the image has 4 channels before prediction
|
| 47 |
try:
|
| 48 |
+
# Use tifffile for 32-bit TIFF support
|
| 49 |
+
img_array = tifffile.imread(image_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
# Handle different array shapes
|
| 52 |
+
if len(img_array.shape) == 3:
|
| 53 |
+
if img_array.shape[0] == 4:
|
| 54 |
+
# Shape is (4, H, W) - transpose to (H, W, 4)
|
| 55 |
+
img_array = np.transpose(img_array, (1, 2, 0))
|
| 56 |
+
elif img_array.shape[2] != 4:
|
| 57 |
+
raise ValueError(f"YOLO model expects 4-channel images, but got {img_array.shape[2]} channels")
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Unexpected image shape: {img_array.shape}")
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
# Convert any non-uint8 data (e.g. 32-bit float) to uint8
|
| 62 |
+
if img_array.dtype != np.uint8:
|
| 63 |
+
# Normalize to [0, 255] for RGB channels (first 3)
|
| 64 |
+
rgb_array = img_array[:, :, :3]
|
| 65 |
+
if rgb_array.max() > 1.0:
|
| 66 |
+
rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
|
| 67 |
+
else:
|
| 68 |
+
rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
|
| 69 |
|
| 70 |
+
# Normalize NDVI (4th channel) from [-1, 1] to [0, 255]
|
| 71 |
+
ndvi_array = img_array[:, :, 3]
|
| 72 |
+
ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
|
| 73 |
+
|
| 74 |
+
# Recombine into 4-channel uint8 array
|
| 75 |
+
img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
|
| 76 |
+
img_array[:, :, :3] = rgb_array
|
| 77 |
+
img_array[:, :, 3] = ndvi_normalized
|
| 78 |
+
|
| 79 |
+
# Save normalized image to temporary file
|
| 80 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
|
| 81 |
+
temp_path = tmp_file.name
|
| 82 |
+
tifffile.imwrite(
|
| 83 |
+
temp_path,
|
| 84 |
+
img_array,
|
| 85 |
+
photometric='rgb',
|
| 86 |
+
compress='lzw',
|
| 87 |
+
metadata={'axes': 'YXC', 'resolution': (1, 1)} # DPI=1
|
| 88 |
+
)
|
| 89 |
+
image_path = temp_path
|
| 90 |
+
|
| 91 |
+
# Run YOLO prediction
|
| 92 |
+
results = yolo_model([image_path], conf=conf)
|
| 93 |
+
|
| 94 |
+
# Clean up temporary file if created
|
| 95 |
+
if 'temp_path' in locals() and os.path.exists(temp_path):
|
| 96 |
+
os.unlink(temp_path)
|
| 97 |
+
|
| 98 |
+
return results[0] # Return first result
|
| 99 |
|
| 100 |
+
except Exception as e:
|
| 101 |
+
raise ValueError(f"Error processing image: {str(e)}")
|
| 102 |
|
| 103 |
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
|
| 104 |
"""
|
|
|
|
| 111 |
"""
|
| 112 |
height, width = rgb_array.shape[:2]
|
| 113 |
|
| 114 |
+
# Ensure RGB is in uint8 format
|
| 115 |
if rgb_array.dtype != np.uint8:
|
| 116 |
if rgb_array.max() <= 1.0:
|
| 117 |
rgb_normalized = (rgb_array * 255).astype(np.uint8)
|
|
|
|
| 130 |
four_channel[:, :, 2] = rgb_normalized[:, :, 2] # Blue
|
| 131 |
four_channel[:, :, 3] = ndvi_normalized # NDVI
|
| 132 |
|
| 133 |
+
# Save using tifffile as an explicit 8-bit-per-channel TIFF; resolution (1, 1) is passed via metadata
|
| 134 |
tifffile.imwrite(
|
| 135 |
+
output_path,
|
| 136 |
+
four_channel,
|
| 137 |
photometric='rgb',
|
| 138 |
compress='lzw',
|
| 139 |
+
metadata={'axes': 'YXC', 'resolution': (1, 1)}, # DPI=1
|
| 140 |
+
bitspersample=8 # Explicitly set to 8-bit per channel
|
| 141 |
)
|
| 142 |
|
| 143 |
def load_4channel_tiff(image_path):
|
|
|
|
| 247 |
except Exception as e2:
|
| 248 |
try:
|
| 249 |
# Method 3: Fall back to PIL for standard formats
|
| 250 |
+
img = Image.PIL(image_path)
|
| 251 |
if img.mode != 'RGB':
|
| 252 |
img = img.convert('RGB')
|
| 253 |
rgb_array = np.array(img)
|
|
|
|
| 260 |
|
| 261 |
# Ensure RGB is in correct format and range
|
| 262 |
if rgb_array.dtype == np.uint8:
|
|
|
|
| 263 |
rgb_float = rgb_array.astype(np.float32) / 255.0
|
| 264 |
else:
|
|
|
|
| 265 |
if rgb_array.max() > 1.0:
|
| 266 |
+
rgb_float = rgb_array / rgb_array.max()
|
| 267 |
else:
|
| 268 |
rgb_float = rgb_array
|
| 269 |
rgb_array = (rgb_float * 255).astype(np.uint8)
|
|
|
|
| 281 |
|
| 282 |
# Verify the created file can be read
|
| 283 |
try:
|
| 284 |
+
test_array = tifffile.imread(temp_4ch_path)
|
| 285 |
+
if len(test_array.shape) == 3 and (test_array.shape[0] == 4 or test_array.shape[2] == 4):
|
| 286 |
+
channels = 4
|
| 287 |
else:
|
| 288 |
+
channels = test_array.shape[2] if len(test_array.shape) == 3 else 1
|
|
|
|
|
|
|
| 289 |
if channels != 4:
|
| 290 |
raise ValueError(f"Created TIFF has {channels} channels instead of 4")
|
| 291 |
|