Spaces:
Sleeping
Sleeping
Update yolo_predictor.py
Browse files- yolo_predictor.py +36 -42
yolo_predictor.py
CHANGED
|
@@ -32,71 +32,65 @@ def predict_ndvi_from_rgb(ndvi_model, rgb_array):
|
|
| 32 |
return ndvi_pred
|
| 33 |
|
| 34 |
def predict_yolo(yolo_model, image_path, conf=0.001):
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
Args:
|
| 39 |
-
yolo_model: Loaded YOLO model
|
| 40 |
-
image_path: Path to 4-channel TIFF image
|
| 41 |
-
conf: Confidence threshold
|
| 42 |
-
|
| 43 |
-
Returns:
|
| 44 |
-
results: YOLO results object
|
| 45 |
-
"""
|
| 46 |
-
# Verify the image has 4 channels before prediction
|
| 47 |
try:
|
| 48 |
-
# Use tifffile for 32-bit TIFF support
|
| 49 |
img_array = tifffile.imread(image_path)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
if len(img_array.shape) == 3:
|
| 53 |
if img_array.shape[0] == 4:
|
| 54 |
-
# Shape is (4, H, W) - transpose to (H, W, 4)
|
| 55 |
img_array = np.transpose(img_array, (1, 2, 0))
|
|
|
|
| 56 |
elif img_array.shape[2] != 4:
|
| 57 |
-
raise ValueError(f"
|
| 58 |
else:
|
| 59 |
-
raise ValueError(f"Unexpected image shape: {img_array.shape}")
|
| 60 |
-
|
| 61 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
if img_array.dtype != np.uint8:
|
| 63 |
-
|
| 64 |
rgb_array = img_array[:, :, :3]
|
|
|
|
|
|
|
|
|
|
| 65 |
if rgb_array.max() > 1.0:
|
| 66 |
rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
|
| 67 |
else:
|
| 68 |
rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
|
| 69 |
-
|
| 70 |
-
# Normalize NDVI
|
| 71 |
-
ndvi_array = img_array[:, :, 3]
|
| 72 |
ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
|
| 73 |
-
|
| 74 |
-
# Recombine into 4-channel uint8 array
|
| 75 |
img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
|
| 76 |
img_array[:, :, :3] = rgb_array
|
| 77 |
img_array[:, :, 3] = ndvi_normalized
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
|
| 81 |
temp_path = tmp_file.name
|
| 82 |
-
tifffile.imwrite(
|
| 83 |
-
temp_path,
|
| 84 |
-
img_array,
|
| 85 |
-
photometric='rgb',
|
| 86 |
-
compress='lzw',
|
| 87 |
-
metadata={'axes': 'YXC', 'resolution': (1, 1)} # DPI=1
|
| 88 |
-
)
|
| 89 |
image_path = temp_path
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
results = yolo_model([image_path], conf=conf)
|
| 93 |
-
|
| 94 |
-
# Clean up temporary file if created
|
| 95 |
if 'temp_path' in locals() and os.path.exists(temp_path):
|
| 96 |
os.unlink(temp_path)
|
| 97 |
-
|
| 98 |
-
return results[0]
|
| 99 |
-
|
| 100 |
except Exception as e:
|
| 101 |
raise ValueError(f"Error processing image: {str(e)}")
|
| 102 |
|
|
|
|
| 32 |
return ndvi_pred
|
| 33 |
|
| 34 |
def predict_yolo(yolo_model, image_path, conf=0.001):
|
| 35 |
+
import tifffile
|
| 36 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
try:
|
|
|
|
| 38 |
img_array = tifffile.imread(image_path)
|
| 39 |
+
print(f"[DEBUG] Loaded image shape: {img_array.shape}, dtype: {img_array.dtype}")
|
| 40 |
+
|
| 41 |
if len(img_array.shape) == 3:
|
| 42 |
if img_array.shape[0] == 4:
|
|
|
|
| 43 |
img_array = np.transpose(img_array, (1, 2, 0))
|
| 44 |
+
print(f"[DEBUG] Transposed image shape to (H,W,C): {img_array.shape}")
|
| 45 |
elif img_array.shape[2] != 4:
|
| 46 |
+
raise ValueError(f"[ERROR] Expected 4 channels, got {img_array.shape[2]}")
|
| 47 |
else:
|
| 48 |
+
raise ValueError(f"[ERROR] Unexpected image shape: {img_array.shape}")
|
| 49 |
+
|
| 50 |
+
# Confirm channel count
|
| 51 |
+
if img_array.shape[2] != 4:
|
| 52 |
+
raise ValueError(f"[ERROR] After transpose, still not 4 channels: got {img_array.shape[2]}")
|
| 53 |
+
|
| 54 |
+
print(f"[DEBUG] Image dtype before normalization: {img_array.dtype}")
|
| 55 |
+
|
| 56 |
if img_array.dtype != np.uint8:
|
| 57 |
+
print(f"[DEBUG] Converting image to uint8")
|
| 58 |
rgb_array = img_array[:, :, :3]
|
| 59 |
+
ndvi_array = img_array[:, :, 3]
|
| 60 |
+
|
| 61 |
+
# Normalize RGB
|
| 62 |
if rgb_array.max() > 1.0:
|
| 63 |
rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
|
| 64 |
else:
|
| 65 |
rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
|
| 66 |
+
|
| 67 |
+
# Normalize NDVI
|
|
|
|
| 68 |
ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
|
| 69 |
+
|
|
|
|
| 70 |
img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
|
| 71 |
img_array[:, :, :3] = rgb_array
|
| 72 |
img_array[:, :, 3] = ndvi_normalized
|
| 73 |
+
|
| 74 |
+
print(f"[DEBUG] Image converted to uint8 with shape: {img_array.shape}")
|
| 75 |
+
|
| 76 |
+
# Save normalized version to temp file
|
| 77 |
with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
|
| 78 |
temp_path = tmp_file.name
|
| 79 |
+
tifffile.imwrite(temp_path, img_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
image_path = temp_path
|
| 81 |
+
|
| 82 |
+
print(f"[DEBUG] Final image ready for YOLO, path: {image_path}")
|
| 83 |
+
|
| 84 |
+
# Final safety check
|
| 85 |
+
assert img_array.shape[2] == 4, "[FATAL] Final image does not have 4 channels."
|
| 86 |
+
|
| 87 |
results = yolo_model([image_path], conf=conf)
|
| 88 |
+
|
|
|
|
| 89 |
if 'temp_path' in locals() and os.path.exists(temp_path):
|
| 90 |
os.unlink(temp_path)
|
| 91 |
+
|
| 92 |
+
return results[0]
|
| 93 |
+
|
| 94 |
except Exception as e:
|
| 95 |
raise ValueError(f"Error processing image: {str(e)}")
|
| 96 |
|