astrosbd commited on
Commit
6da6870
·
verified ·
1 Parent(s): 01a9aca

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -5
inference.py CHANGED
@@ -90,7 +90,7 @@ class DTDPredictor:
90
  The model was trained on quantized coefficients, so results may vary.
91
 
92
  Args:
93
- image_array: RGB image as numpy array
94
  quality: JPEG quality (used for QT approximation)
95
 
96
  Returns:
@@ -101,11 +101,12 @@ class DTDPredictor:
101
  im_ycbcr = cv2.cvtColor(image_array, cv2.COLOR_RGB2YCrCb)
102
  y_channel = im_ycbcr[:, :, 0].astype(np.float32) - 128 # Center around 0
103
 
104
- # Apply 2D DCT in 8x8 blocks
105
  h, w = y_channel.shape
 
106
 
107
  # Compute DCT for each 8x8 block
108
- dct_coeffs = np.zeros_like(y_channel)
109
  for i in range(0, h, 8):
110
  for j in range(0, w, 8):
111
  block = y_channel[i:i+8, j:j+8]
@@ -289,8 +290,14 @@ class DTDPredictor:
289
  full_mask = full_mask / np.maximum(count_map, 1)
290
  final_mask = (full_mask > 0.5).astype(np.uint8)
291
 
292
- # Crop back to TRUE original size (before any padding)
293
- final_mask = final_mask[:true_orig_h, :true_orig_w]
 
 
 
 
 
 
294
 
295
  # Create heatmap overlay with original image (no padding)
296
  heatmap = self.create_heatmap(im_orig_np, final_mask)
 
90
  The model was trained on quantized coefficients, so results may vary.
91
 
92
  Args:
93
+ image_array: RGB image as numpy array (must be 8x8 aligned)
94
  quality: JPEG quality (used for QT approximation)
95
 
96
  Returns:
 
101
  im_ycbcr = cv2.cvtColor(image_array, cv2.COLOR_RGB2YCrCb)
102
  y_channel = im_ycbcr[:, :, 0].astype(np.float32) - 128 # Center around 0
103
 
104
+ # Image should already be 8x8 aligned
105
  h, w = y_channel.shape
106
+ assert h % 8 == 0 and w % 8 == 0, f"Image must be 8x8 aligned, got {h}x{w}"
107
 
108
  # Compute DCT for each 8x8 block
109
+ dct_coeffs = np.zeros((h, w), dtype=np.float32)
110
  for i in range(0, h, 8):
111
  for j in range(0, w, 8):
112
  block = y_channel[i:i+8, j:j+8]
 
290
  full_mask = full_mask / np.maximum(count_map, 1)
291
  final_mask = (full_mask > 0.5).astype(np.uint8)
292
 
293
+ # Pad mask back to true original size if it was 8x8 aligned smaller
294
+ if final_mask.shape[0] < true_orig_h or final_mask.shape[1] < true_orig_w:
295
+ padded_mask = np.zeros((true_orig_h, true_orig_w), dtype=np.uint8)
296
+ padded_mask[:final_mask.shape[0], :final_mask.shape[1]] = final_mask
297
+ final_mask = padded_mask
298
+ else:
299
+ # Crop if somehow larger (shouldn't happen)
300
+ final_mask = final_mask[:true_orig_h, :true_orig_w]
301
 
302
  # Create heatmap overlay with original image (no padding)
303
  heatmap = self.create_heatmap(im_orig_np, final_mask)