Update processing.py
Browse files- processing.py +76 -42
processing.py
CHANGED
|
@@ -2,18 +2,20 @@
|
|
| 2 |
Image processing pipeline for SUB-SENTINEL.
|
| 3 |
|
| 4 |
Provides three functions:
|
| 5 |
-
enhance_image(raw_bytes)
|
| 6 |
-
run_detection(image_array)
|
| 7 |
-
build_heatmap(image_array)
|
| 8 |
|
| 9 |
All heavy-weight model paths gracefully fall back to CPU-friendly alternatives
|
| 10 |
-
when model weights are absent.
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
-
import
|
| 14 |
import io
|
|
|
|
| 15 |
import logging
|
| 16 |
-
from typing import Optional
|
| 17 |
|
| 18 |
import cv2
|
| 19 |
import numpy as np
|
|
@@ -21,6 +23,13 @@ from PIL import Image
|
|
| 21 |
from skimage.metrics import structural_similarity as ssim
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ---------------------------------------------------------------------------
|
| 26 |
# Maritime label mapping for YOLOv8 COCO classes
|
|
@@ -35,13 +44,15 @@ _LABEL_MAP: dict[str, str] = {
|
|
| 35 |
}
|
| 36 |
|
| 37 |
|
|
|
|
| 38 |
def _array_to_base64(img_array: np.ndarray, fmt: str = "JPEG") -> str:
|
| 39 |
"""Convert a uint8 numpy array (H×W×C, RGB) to a base-64 data-URI string."""
|
| 40 |
pil_img = Image.fromarray(img_array.astype(np.uint8))
|
| 41 |
buf = io.BytesIO()
|
| 42 |
-
|
|
|
|
| 43 |
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 44 |
-
mime = "image/jpeg" if
|
| 45 |
return f"data:{mime};base64,{encoded}"
|
| 46 |
|
| 47 |
|
|
@@ -57,8 +68,6 @@ def _bytes_to_array(raw_bytes: bytes) -> np.ndarray:
|
|
| 57 |
# ---------------------------------------------------------------------------
|
| 58 |
# 1. Underwater image enhancement
|
| 59 |
# ---------------------------------------------------------------------------
|
| 60 |
-
|
| 61 |
-
|
| 62 |
def _clahe_enhance(rgb: np.ndarray) -> np.ndarray:
|
| 63 |
"""
|
| 64 |
CPU-friendly underwater enhancement using CLAHE on LAB colour space.
|
|
@@ -82,7 +91,6 @@ def _funiegan_enhance(rgb: np.ndarray) -> Optional[np.ndarray]:
|
|
| 82 |
"""
|
| 83 |
weights_path = "weights/funiegan.onnx"
|
| 84 |
try:
|
| 85 |
-
import os
|
| 86 |
if not os.path.exists(weights_path):
|
| 87 |
return None
|
| 88 |
net = cv2.dnn.readNetFromONNX(weights_path)
|
|
@@ -92,10 +100,11 @@ def _funiegan_enhance(rgb: np.ndarray) -> Optional[np.ndarray]:
|
|
| 92 |
blob = cv2.dnn.blobFromImage(resized)
|
| 93 |
net.setInput(blob)
|
| 94 |
out = net.forward()
|
|
|
|
| 95 |
out_img = ((out[0].transpose(1, 2, 0) + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
|
| 96 |
return cv2.resize(out_img, (w, h))
|
| 97 |
except Exception as exc:
|
| 98 |
-
logger.warning("FUnIE-GAN inference failed (%s);
|
| 99 |
return None
|
| 100 |
|
| 101 |
|
|
@@ -115,68 +124,93 @@ def enhance_image(raw_bytes: bytes) -> tuple[str, np.ndarray]:
|
|
| 115 |
|
| 116 |
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
-
# 2. Object detection (
|
| 119 |
# ---------------------------------------------------------------------------
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def run_detection(rgb: np.ndarray) -> list[dict]:
|
| 123 |
"""
|
| 124 |
-
Run
|
|
|
|
| 125 |
|
| 126 |
Returns a list of detection dicts:
|
| 127 |
{class, mapped_label, confidence, bbox: [x1, y1, x2, y2]}
|
| 128 |
"""
|
| 129 |
try:
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
results = model(rgb, verbose=False)
|
| 133 |
except Exception as exc:
|
| 134 |
-
logger.warning("
|
| 135 |
return []
|
| 136 |
|
| 137 |
-
detections = []
|
| 138 |
for result in results:
|
| 139 |
-
|
|
|
|
| 140 |
continue
|
| 141 |
-
for box in
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
"class": cls_name,
|
| 149 |
"mapped_label": _LABEL_MAP.get(cls_name, cls_name),
|
| 150 |
"confidence": round(conf, 4),
|
| 151 |
"bbox": [round(x1), round(y1), round(x2), round(y2)],
|
| 152 |
-
}
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
| 154 |
return detections
|
| 155 |
|
| 156 |
|
| 157 |
# ---------------------------------------------------------------------------
|
| 158 |
# 3. SSIM-based forensic heatmap
|
| 159 |
# ---------------------------------------------------------------------------
|
| 160 |
-
|
| 161 |
-
|
| 162 |
def build_heatmap(rgb: np.ndarray) -> str:
|
| 163 |
"""
|
| 164 |
Generate a forensic heatmap by comparing the original image against a
|
| 165 |
-
Gaussian-blurred reference.
|
| 166 |
-
|
| 167 |
-
Returns a base64-encoded PNG heatmap.
|
| 168 |
"""
|
| 169 |
gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
|
| 170 |
-
# Reference: gently blurred version of the same frame
|
| 171 |
blurred = cv2.GaussianBlur(gray, (15, 15), 0)
|
| 172 |
|
| 173 |
-
# Compute SSIM score map
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
# Normalise to [0, 255]
|
| 177 |
-
ssim_norm = ((ssim_map + 1.0) / 2.0 * 255).clip(0, 255).astype(np.uint8)
|
| 178 |
|
| 179 |
-
# Map to BGR: low similarity
|
| 180 |
colormap = cv2.COLORMAP_RdYlGn if hasattr(cv2, "COLORMAP_RdYlGn") else cv2.COLORMAP_JET
|
| 181 |
heatmap_bgr = cv2.applyColorMap(ssim_norm, colormap)
|
| 182 |
|
|
@@ -185,4 +219,4 @@ def build_heatmap(rgb: np.ndarray) -> str:
|
|
| 185 |
overlay = cv2.addWeighted(rgb_bgr, 0.55, heatmap_bgr, 0.45, 0)
|
| 186 |
overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
|
| 187 |
|
| 188 |
-
return _array_to_base64(overlay_rgb, fmt="PNG")
|
|
|
|
| 2 |
Image processing pipeline for SUB-SENTINEL.
|
| 3 |
|
| 4 |
Provides three functions:
|
| 5 |
+
enhance_image(raw_bytes) -> (base64_str, numpy_array)
|
| 6 |
+
run_detection(image_array) -> list[dict]
|
| 7 |
+
build_heatmap(image_array) -> base64_str
|
| 8 |
|
| 9 |
All heavy-weight model paths gracefully fall back to CPU-friendly alternatives
|
| 10 |
+
when model weights are absent. Use the environment variable DETECTION_MODEL
|
| 11 |
+
to override the default detection model (e.g. "yolov8m.pt" or a local path).
|
| 12 |
"""
|
| 13 |
|
| 14 |
+
import os
|
| 15 |
import io
|
| 16 |
+
import base64
|
| 17 |
import logging
|
| 18 |
+
from typing import Optional, List, Dict
|
| 19 |
|
| 20 |
import cv2
|
| 21 |
import numpy as np
|
|
|
|
| 23 |
from skimage.metrics import structural_similarity as ssim
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
+
logger.addHandler(logging.NullHandler())
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Default detection model (change via env var DETECTION_MODEL if needed)
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# NOTE: default changed to yolov8m for improved accuracy.
|
| 32 |
+
DEFAULT_DETECTION_MODEL = os.getenv("DETECTION_MODEL", "yolov8m.pt")
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------------
|
| 35 |
# Maritime label mapping for YOLOv8 COCO classes
|
|
|
|
| 44 |
}
|
| 45 |
|
| 46 |
|
| 47 |
+
# --------------------------- utilities -------------------------------------
|
| 48 |
def _array_to_base64(img_array: np.ndarray, fmt: str = "JPEG") -> str:
    """Convert a uint8 numpy array (H×W×C, RGB) to a base-64 data-URI string."""
    out_format = fmt.upper()
    image = Image.fromarray(img_array.astype(np.uint8))
    # Encode in-memory; the buffer is discarded once the payload is captured.
    with io.BytesIO() as buffer:
        image.save(buffer, format=out_format, quality=90)
        payload = buffer.getvalue()
    b64_text = base64.b64encode(payload).decode("utf-8")
    media_type = "image/jpeg" if out_format == "JPEG" else "image/png"
    return f"data:{media_type};base64,{b64_text}"
|
| 57 |
|
| 58 |
|
|
|
|
| 68 |
# ---------------------------------------------------------------------------
|
| 69 |
# 1. Underwater image enhancement
|
| 70 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 71 |
def _clahe_enhance(rgb: np.ndarray) -> np.ndarray:
|
| 72 |
"""
|
| 73 |
CPU-friendly underwater enhancement using CLAHE on LAB colour space.
|
|
|
|
| 91 |
"""
|
| 92 |
weights_path = "weights/funiegan.onnx"
|
| 93 |
try:
|
|
|
|
| 94 |
if not os.path.exists(weights_path):
|
| 95 |
return None
|
| 96 |
net = cv2.dnn.readNetFromONNX(weights_path)
|
|
|
|
| 100 |
blob = cv2.dnn.blobFromImage(resized)
|
| 101 |
net.setInput(blob)
|
| 102 |
out = net.forward()
|
| 103 |
+
# out shape may be (1, C, H, W)
|
| 104 |
out_img = ((out[0].transpose(1, 2, 0) + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
|
| 105 |
return cv2.resize(out_img, (w, h))
|
| 106 |
except Exception as exc:
|
| 107 |
+
logger.warning("FUnIE-GAN inference failed (%s); falling back to CLAHE.", exc)
|
| 108 |
return None
|
| 109 |
|
| 110 |
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
# ---------------------------------------------------------------------------
|
| 127 |
+
# 2. Object detection (YOLOv8 family; default is yolov8m.pt)
|
| 128 |
# ---------------------------------------------------------------------------
|
| 129 |
+
def run_detection(rgb: np.ndarray, conf_thresh: float = 0.30) -> List[dict]:
    """
    Run YOLO detection (model chosen by DETECTION_MODEL env var or default)
    and map labels to maritime terminology.

    Parameters
    ----------
    rgb : np.ndarray
        Input image as an RGB uint8 array (H×W×3) — assumed, TODO confirm at callers.
    conf_thresh : float
        Minimum confidence for a box to be reported.

    Returns a list of detection dicts:
        {class, mapped_label, confidence, bbox: [x1, y1, x2, y2]}

    Never raises: any failure (missing package, bad weights, inference error)
    is logged and yields an empty list.
    """
    try:
        # Lazy import to avoid heavy dependency cost at module import time
        from ultralytics import YOLO  # type: ignore
    except Exception as exc:
        logger.warning("ultralytics package not available (%s); detection disabled.", exc)
        return []

    model_path = os.getenv("DETECTION_MODEL", DEFAULT_DETECTION_MODEL)

    # Cache loaded models on the function object so repeated calls (e.g. per
    # video frame) do not reload weights from disk every time.
    cache = getattr(run_detection, "_model_cache", None)
    if cache is None:
        cache = {}
        run_detection._model_cache = cache

    model = cache.get(model_path)
    if model is None:
        try:
            model = YOLO(model_path)
        except Exception as exc:
            logger.warning("Failed to load detection model '%s' (%s). Returning empty.", model_path, exc)
            return []
        cache[model_path] = model

    try:
        # Model accepts numpy image (RGB) directly
        results = model(rgb, verbose=False)
    except Exception as exc:
        logger.warning("Model inference failed (%s). Returning empty.", exc)
        return []

    detections: List[dict] = []
    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            continue
        for box in boxes:
            try:
                # Defensive extraction: the ultralytics API returns tensors/arrays
                conf = float(box.conf[0]) if hasattr(box.conf, "__len__") else float(box.conf)
                if conf < conf_thresh:
                    continue

                cls_id = int(box.cls[0]) if hasattr(box.cls, "__len__") else int(box.cls)
                cls_name = model.names.get(cls_id, str(cls_id)) if hasattr(model, "names") else str(cls_id)

                xyxy = box.xyxy[0] if hasattr(box.xyxy, "__len__") and len(box.xyxy) > 0 else None
                if xyxy is None:
                    continue
                x1, y1, x2, y2 = (float(v) for v in xyxy)
                detections.append({
                    "class": cls_name,
                    "mapped_label": _LABEL_MAP.get(cls_name, cls_name),
                    "confidence": round(conf, 4),
                    "bbox": [round(x1), round(y1), round(x2), round(y2)],
                })
            except Exception as exc:
                logger.debug("Skipping box due to extraction error: %s", exc)
                continue

    return detections
|
| 188 |
|
| 189 |
|
| 190 |
# ---------------------------------------------------------------------------
|
| 191 |
# 3. SSIM-based forensic heatmap
|
| 192 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 193 |
def build_heatmap(rgb: np.ndarray) -> str:
    """
    Generate a forensic heatmap by comparing the original image against a
    Gaussian-blurred reference. High SSIM -> green; low SSIM -> red.

    Parameters
    ----------
    rgb : np.ndarray
        Input image as an RGB uint8 array (H×W×3) — assumed, TODO confirm at callers.

    Returns a base64-encoded PNG heatmap (data URI).
    """
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # Reference: gently blurred version of the same frame
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)

    # Compute SSIM score map; fallback to simple difference if it fails.
    # Either branch must yield values in [-1, 1] so the shared normalisation
    # below maps the full similarity range onto [0, 255].
    try:
        _, ssim_map = ssim(gray, blurred, full=True, data_range=255)
    except Exception as exc:
        logger.warning("SSIM computation failed (%s); falling back to absdiff.", exc)
        diff = cv2.absdiff(gray, blurred).astype(np.float32)
        # Rescale so zero difference -> +1 and maximal difference -> -1,
        # matching the SSIM value range (plain 1 - diff/255 would land in
        # [0, 1] and get squeezed into the upper half of the colormap).
        ssim_map = 1.0 - 2.0 * (diff / 255.0)

    # Normalise to [0, 255]
    ssim_norm = ((ssim_map + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)

    # Map to BGR: low similarity -> red, high -> green
    colormap = cv2.COLORMAP_RdYlGn if hasattr(cv2, "COLORMAP_RdYlGn") else cv2.COLORMAP_JET
    heatmap_bgr = cv2.applyColorMap(ssim_norm, colormap)

    # Blend the heatmap over the original frame for easier visual inspection.
    rgb_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(rgb_bgr, 0.55, heatmap_bgr, 0.45, 0)
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    return _array_to_base64(overlay_rgb, fmt="PNG")
|