Update inference.py
Browse files- inference.py +12 -5
inference.py
CHANGED
|
@@ -90,7 +90,7 @@ class DTDPredictor:
|
|
| 90 |
The model was trained on quantized coefficients, so results may vary.
|
| 91 |
|
| 92 |
Args:
|
| 93 |
-
image_array: RGB image as numpy array
|
| 94 |
quality: JPEG quality (used for QT approximation)
|
| 95 |
|
| 96 |
Returns:
|
|
@@ -101,11 +101,12 @@ class DTDPredictor:
|
|
| 101 |
im_ycbcr = cv2.cvtColor(image_array, cv2.COLOR_RGB2YCrCb)
|
| 102 |
y_channel = im_ycbcr[:, :, 0].astype(np.float32) - 128 # Center around 0
|
| 103 |
|
| 104 |
-
#
|
| 105 |
h, w = y_channel.shape
|
|
|
|
| 106 |
|
| 107 |
# Compute DCT for each 8x8 block
|
| 108 |
-
dct_coeffs = np.
|
| 109 |
for i in range(0, h, 8):
|
| 110 |
for j in range(0, w, 8):
|
| 111 |
block = y_channel[i:i+8, j:j+8]
|
|
@@ -289,8 +290,14 @@ class DTDPredictor:
|
|
| 289 |
full_mask = full_mask / np.maximum(count_map, 1)
|
| 290 |
final_mask = (full_mask > 0.5).astype(np.uint8)
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
final_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
# Create heatmap overlay with original image (no padding)
|
| 296 |
heatmap = self.create_heatmap(im_orig_np, final_mask)
|
|
|
|
| 90 |
The model was trained on quantized coefficients, so results may vary.
|
| 91 |
|
| 92 |
Args:
|
| 93 |
+
image_array: RGB image as numpy array (must be 8x8 aligned)
|
| 94 |
quality: JPEG quality (used for QT approximation)
|
| 95 |
|
| 96 |
Returns:
|
|
|
|
| 101 |
im_ycbcr = cv2.cvtColor(image_array, cv2.COLOR_RGB2YCrCb)
|
| 102 |
y_channel = im_ycbcr[:, :, 0].astype(np.float32) - 128 # Center around 0
|
| 103 |
|
| 104 |
+
# Image should already be 8x8 aligned
|
| 105 |
h, w = y_channel.shape
|
| 106 |
+
assert h % 8 == 0 and w % 8 == 0, f"Image must be 8x8 aligned, got {h}x{w}"
|
| 107 |
|
| 108 |
# Compute DCT for each 8x8 block
|
| 109 |
+
dct_coeffs = np.zeros((h, w), dtype=np.float32)
|
| 110 |
for i in range(0, h, 8):
|
| 111 |
for j in range(0, w, 8):
|
| 112 |
block = y_channel[i:i+8, j:j+8]
|
|
|
|
| 290 |
full_mask = full_mask / np.maximum(count_map, 1)
|
| 291 |
final_mask = (full_mask > 0.5).astype(np.uint8)
|
| 292 |
|
| 293 |
+
# Pad mask back to true original size if it was 8x8 aligned smaller
|
| 294 |
+
if final_mask.shape[0] < true_orig_h or final_mask.shape[1] < true_orig_w:
|
| 295 |
+
padded_mask = np.zeros((true_orig_h, true_orig_w), dtype=np.uint8)
|
| 296 |
+
padded_mask[:final_mask.shape[0], :final_mask.shape[1]] = final_mask
|
| 297 |
+
final_mask = padded_mask
|
| 298 |
+
else:
|
| 299 |
+
# Crop if somehow larger (shouldn't happen)
|
| 300 |
+
final_mask = final_mask[:true_orig_h, :true_orig_w]
|
| 301 |
|
| 302 |
# Create heatmap overlay with original image (no padding)
|
| 303 |
heatmap = self.create_heatmap(im_orig_np, final_mask)
|