Arko007 committed on
Commit
4f41596
·
verified ·
1 Parent(s): 3d2e524

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +76 -42
processing.py CHANGED
@@ -2,18 +2,20 @@
2
  Image processing pipeline for SUB-SENTINEL.
3
 
4
  Provides three functions:
5
- enhance_image(raw_bytes) (base64_str, numpy_array)
6
- run_detection(image_array) list[dict]
7
- build_heatmap(image_array) base64_str
8
 
9
  All heavy-weight model paths gracefully fall back to CPU-friendly alternatives
10
- when model weights are absent.
 
11
  """
12
 
13
- import base64
14
  import io
 
15
  import logging
16
- from typing import Optional
17
 
18
  import cv2
19
  import numpy as np
@@ -21,6 +23,13 @@ from PIL import Image
21
  from skimage.metrics import structural_similarity as ssim
22
 
23
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
24
 
25
  # ---------------------------------------------------------------------------
26
  # Maritime label mapping for YOLOv8 COCO classes
@@ -35,13 +44,15 @@ _LABEL_MAP: dict[str, str] = {
35
  }
36
 
37
 
 
38
  def _array_to_base64(img_array: np.ndarray, fmt: str = "JPEG") -> str:
39
  """Convert a uint8 numpy array (H×W×C, RGB) to a base-64 data-URI string."""
40
  pil_img = Image.fromarray(img_array.astype(np.uint8))
41
  buf = io.BytesIO()
42
- pil_img.save(buf, format=fmt, quality=90)
 
43
  encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
44
- mime = "image/jpeg" if fmt == "JPEG" else "image/png"
45
  return f"data:{mime};base64,{encoded}"
46
 
47
 
@@ -57,8 +68,6 @@ def _bytes_to_array(raw_bytes: bytes) -> np.ndarray:
57
  # ---------------------------------------------------------------------------
58
  # 1. Underwater image enhancement
59
  # ---------------------------------------------------------------------------
60
-
61
-
62
  def _clahe_enhance(rgb: np.ndarray) -> np.ndarray:
63
  """
64
  CPU-friendly underwater enhancement using CLAHE on LAB colour space.
@@ -82,7 +91,6 @@ def _funiegan_enhance(rgb: np.ndarray) -> Optional[np.ndarray]:
82
  """
83
  weights_path = "weights/funiegan.onnx"
84
  try:
85
- import os
86
  if not os.path.exists(weights_path):
87
  return None
88
  net = cv2.dnn.readNetFromONNX(weights_path)
@@ -92,10 +100,11 @@ def _funiegan_enhance(rgb: np.ndarray) -> Optional[np.ndarray]:
92
  blob = cv2.dnn.blobFromImage(resized)
93
  net.setInput(blob)
94
  out = net.forward()
 
95
  out_img = ((out[0].transpose(1, 2, 0) + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
96
  return cv2.resize(out_img, (w, h))
97
  except Exception as exc:
98
- logger.warning("FUnIE-GAN inference failed (%s); using CLAHE fallback.", exc)
99
  return None
100
 
101
 
@@ -115,68 +124,93 @@ def enhance_image(raw_bytes: bytes) -> tuple[str, np.ndarray]:
115
 
116
 
117
  # ---------------------------------------------------------------------------
118
- # 2. Object detection (YOLOv8n)
119
  # ---------------------------------------------------------------------------
120
-
121
-
122
- def run_detection(rgb: np.ndarray) -> list[dict]:
123
  """
124
- Run YOLOv8n COCO detection and map labels to maritime terminology.
 
125
 
126
  Returns a list of detection dicts:
127
  {class, mapped_label, confidence, bbox: [x1, y1, x2, y2]}
128
  """
129
  try:
130
- from ultralytics import YOLO # lazy import large package
131
- model = YOLO("yolov8n.pt") # downloads automatically on first run
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  results = model(rgb, verbose=False)
133
  except Exception as exc:
134
- logger.warning("YOLOv8n detection failed (%s); returning empty detections.", exc)
135
  return []
136
 
137
- detections = []
138
  for result in results:
139
- if result.boxes is None:
 
140
  continue
141
- for box in result.boxes:
142
- cls_id = int(box.cls[0])
143
- cls_name = model.names.get(cls_id, str(cls_id))
144
- conf = float(box.conf[0])
145
- x1, y1, x2, y2 = (float(v) for v in box.xyxy[0])
146
- detections.append(
147
- {
 
 
 
 
 
 
 
 
148
  "class": cls_name,
149
  "mapped_label": _LABEL_MAP.get(cls_name, cls_name),
150
  "confidence": round(conf, 4),
151
  "bbox": [round(x1), round(y1), round(x2), round(y2)],
152
- }
153
- )
 
 
 
154
  return detections
155
 
156
 
157
  # ---------------------------------------------------------------------------
158
  # 3. SSIM-based forensic heatmap
159
  # ---------------------------------------------------------------------------
160
-
161
-
162
  def build_heatmap(rgb: np.ndarray) -> str:
163
  """
164
  Generate a forensic heatmap by comparing the original image against a
165
- Gaussian-blurred reference. High SSIM green; low SSIM red.
166
-
167
- Returns a base64-encoded PNG heatmap.
168
  """
169
  gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
170
- # Reference: gently blurred version of the same frame
171
  blurred = cv2.GaussianBlur(gray, (15, 15), 0)
172
 
173
- # Compute SSIM score map (window-level scores)
174
- _, ssim_map = ssim(gray, blurred, full=True, data_range=255)
 
 
 
 
 
175
 
176
  # Normalise to [0, 255]
177
- ssim_norm = ((ssim_map + 1.0) / 2.0 * 255).clip(0, 255).astype(np.uint8)
178
 
179
- # Map to BGR: low similarity red (forensic interest), high green
180
  colormap = cv2.COLORMAP_RdYlGn if hasattr(cv2, "COLORMAP_RdYlGn") else cv2.COLORMAP_JET
181
  heatmap_bgr = cv2.applyColorMap(ssim_norm, colormap)
182
 
@@ -185,4 +219,4 @@ def build_heatmap(rgb: np.ndarray) -> str:
185
  overlay = cv2.addWeighted(rgb_bgr, 0.55, heatmap_bgr, 0.45, 0)
186
  overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
187
 
188
- return _array_to_base64(overlay_rgb, fmt="PNG")
 
2
  Image processing pipeline for SUB-SENTINEL.
3
 
4
  Provides three functions:
5
+ enhance_image(raw_bytes) -> (base64_str, numpy_array)
6
+ run_detection(image_array) -> list[dict]
7
+ build_heatmap(image_array) -> base64_str
8
 
9
  All heavy-weight model paths gracefully fall back to CPU-friendly alternatives
10
+ when model weights are absent. Use the environment variable DETECTION_MODEL
11
+ to override the default detection model (e.g. "yolov8m.pt" or a local path).
12
  """
13
 
14
+ import os
15
  import io
16
+ import base64
17
  import logging
18
+ from typing import Optional, List, Dict
19
 
20
  import cv2
21
  import numpy as np
 
23
  from skimage.metrics import structural_similarity as ssim
24
 
25
logger = logging.getLogger(__name__)
# Library convention: emit nothing unless the host application configures
# its own handlers (PEP 282 / logging HOWTO for libraries).
logger.addHandler(logging.NullHandler())

# ---------------------------------------------------------------------------
# Default detection model (change via env var DETECTION_MODEL if needed)
# ---------------------------------------------------------------------------
# NOTE: default changed to yolov8m for improved accuracy.
DEFAULT_DETECTION_MODEL = os.environ.get("DETECTION_MODEL", "yolov8m.pt")
33
 
34
  # ---------------------------------------------------------------------------
35
  # Maritime label mapping for YOLOv8 COCO classes
 
44
  }
45
 
46
 
47
+ # --------------------------- utilities -------------------------------------
48
def _array_to_base64(img_array: np.ndarray, fmt: str = "JPEG") -> str:
    """Convert a uint8 numpy array (H x W x C, RGB) to a base-64 data-URI string.

    Parameters
    ----------
    img_array : np.ndarray
        Image data; values are cast to uint8 before encoding.
    fmt : str
        Pillow format name, case-insensitive (e.g. "JPEG", "PNG", "WEBP").

    Returns
    -------
    str
        A ``data:<mime>;base64,...`` URI suitable for direct embedding.
    """
    pil_img = Image.fromarray(img_array.astype(np.uint8))
    buf = io.BytesIO()
    fmt_upper = fmt.upper()
    # quality is meaningful only for lossy JPEG; lossless formats ignore it,
    # so pass it explicitly for JPEG alone to keep the intent clear.
    save_kwargs = {"quality": 90} if fmt_upper == "JPEG" else {}
    pil_img.save(buf, format=fmt_upper, **save_kwargs)
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    # Generalised MIME mapping: the original labelled every non-JPEG format
    # "image/png", which is wrong for e.g. WEBP or BMP. For common raster
    # formats the MIME subtype matches the lower-cased Pillow format name.
    mime = "image/jpeg" if fmt_upper == "JPEG" else f"image/{fmt_upper.lower()}"
    return f"data:{mime};base64,{encoded}"
57
 
58
 
 
68
  # ---------------------------------------------------------------------------
69
  # 1. Underwater image enhancement
70
  # ---------------------------------------------------------------------------
 
 
71
  def _clahe_enhance(rgb: np.ndarray) -> np.ndarray:
72
  """
73
  CPU-friendly underwater enhancement using CLAHE on LAB colour space.
 
91
  """
92
  weights_path = "weights/funiegan.onnx"
93
  try:
 
94
  if not os.path.exists(weights_path):
95
  return None
96
  net = cv2.dnn.readNetFromONNX(weights_path)
 
100
  blob = cv2.dnn.blobFromImage(resized)
101
  net.setInput(blob)
102
  out = net.forward()
103
+ # out shape may be (1, C, H, W)
104
  out_img = ((out[0].transpose(1, 2, 0) + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
105
  return cv2.resize(out_img, (w, h))
106
  except Exception as exc:
107
+ logger.warning("FUnIE-GAN inference failed (%s); falling back to CLAHE.", exc)
108
  return None
109
 
110
 
 
124
 
125
 
126
  # ---------------------------------------------------------------------------
127
+ # 2. Object detection (YOLOv8 family; default is yolov8m.pt)
128
  # ---------------------------------------------------------------------------
129
def run_detection(rgb: np.ndarray, conf_thresh: float = 0.30) -> List[dict]:
    """
    Run YOLO detection (model chosen by DETECTION_MODEL env var or default)
    and map labels to maritime terminology.

    Parameters
    ----------
    rgb : np.ndarray
        Input image as an RGB uint8 array (H x W x 3).
    conf_thresh : float
        Minimum confidence; detections below this threshold are dropped.

    Returns
    -------
    list[dict]
        One dict per detection:
        {class, mapped_label, confidence, bbox: [x1, y1, x2, y2]}.
        Returns an empty list on any failure (missing package, unloadable
        model, inference error) so callers never crash.
    """
    try:
        # Lazy import to avoid heavy dependency cost at module import time
        from ultralytics import YOLO  # type: ignore
    except Exception as exc:
        logger.warning("ultralytics package not available (%s); detection disabled.", exc)
        return []

    model_path = os.getenv("DETECTION_MODEL", DEFAULT_DETECTION_MODEL)

    # Cache loaded models keyed by path: loading YOLO weights is expensive,
    # and the previous version re-loaded the model on every call. The cache
    # lives as a function attribute so this block stays self-contained.
    cache = getattr(run_detection, "_model_cache", None)
    if cache is None:
        cache = {}
        run_detection._model_cache = cache
    model = cache.get(model_path)
    if model is None:
        try:
            model = YOLO(model_path)
        except Exception as exc:
            logger.warning("Failed to load detection model '%s' (%s). Returning empty.", model_path, exc)
            return []
        cache[model_path] = model

    try:
        # Model accepts numpy image (RGB) directly
        results = model(rgb, verbose=False)
    except Exception as exc:
        logger.warning("Model inference failed (%s). Returning empty.", exc)
        return []

    detections: List[dict] = []
    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            continue
        for box in boxes:
            try:
                # Defensive extraction: the ultralytics API returns tensors/arrays
                conf = float(box.conf[0]) if hasattr(box.conf, "__len__") else float(box.conf)
                if conf < conf_thresh:
                    continue

                cls_id = int(box.cls[0]) if hasattr(box.cls, "__len__") else int(box.cls)
                cls_name = model.names.get(cls_id, str(cls_id)) if hasattr(model, "names") else str(cls_id)

                xyxy = box.xyxy[0] if hasattr(box.xyxy, "__len__") and len(box.xyxy) > 0 else None
                if xyxy is None:
                    continue
                x1, y1, x2, y2 = (float(v) for v in xyxy)
                detections.append({
                    "class": cls_name,
                    "mapped_label": _LABEL_MAP.get(cls_name, cls_name),
                    "confidence": round(conf, 4),
                    "bbox": [round(x1), round(y1), round(x2), round(y2)],
                })
            except Exception as exc:
                logger.debug("Skipping box due to extraction error: %s", exc)
                continue

    return detections
188
 
189
 
190
  # ---------------------------------------------------------------------------
191
  # 3. SSIM-based forensic heatmap
192
  # ---------------------------------------------------------------------------
 
 
193
def build_heatmap(rgb: np.ndarray) -> str:
    """
    Generate a forensic heatmap by comparing the original image against a
    Gaussian-blurred reference. High SSIM -> green; low SSIM -> red.
    Returns a base64-encoded PNG heatmap (data URI).

    Parameters
    ----------
    rgb : np.ndarray
        Input image as an RGB uint8 array (H x W x 3).
    """
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)

    # Compute SSIM score map; fallback to simple difference if it fails
    try:
        _, ssim_map = ssim(gray, blurred, full=True, data_range=255)
    except Exception as exc:
        logger.warning("SSIM computation failed (%s); falling back to absdiff.", exc)
        diff = cv2.absdiff(gray, blurred).astype(np.float32)
        # FIX: map the fallback into SSIM's [-1, 1] range so the shared
        # normalisation below treats both paths identically. The previous
        # fallback produced values in [0, 1], which compressed the whole
        # heatmap into the upper (green) half of the colormap.
        ssim_map = 1.0 - 2.0 * (diff / 255.0)

    # Normalise [-1, 1] -> [0, 255]
    ssim_norm = ((ssim_map + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)

    # Map to BGR: low similarity -> red, high -> green
    colormap = cv2.COLORMAP_RdYlGn if hasattr(cv2, "COLORMAP_RdYlGn") else cv2.COLORMAP_JET
    heatmap_bgr = cv2.applyColorMap(ssim_norm, colormap)

    # NOTE(review): the line creating rgb_bgr is elided in the diff view
    # (new-file lines 217-218 not shown); reconstructed from the old-side
    # context, which blends rgb_bgr with heatmap_bgr. Confirm against the
    # full file.
    rgb_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(rgb_bgr, 0.55, heatmap_bgr, 0.45, 0)
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

    return _array_to_base64(overlay_rgb, fmt="PNG")