harelcain commited on
Commit
3eaf9e9
·
verified ·
1 Parent(s): 6ea3fc2

Upload 4 files

Browse files
Files changed (1) hide show
  1. app.py +83 -72
app.py CHANGED
@@ -26,7 +26,9 @@ import uvicorn
26
 
27
  def extract_features(img: np.ndarray) -> tuple:
28
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
29
- sift = cv2.SIFT_create(nfeatures=10000, contrastThreshold=0.02, edgeThreshold=15)
 
 
30
  keypoints, descriptors = sift.detectAndCompute(gray, None)
31
  return keypoints, descriptors
32
 
@@ -68,16 +70,12 @@ def compute_homography(kp1, kp2, matches, ransac_reproj_thresh=8.0, confidence=0
68
 
69
  def create_inlier_mask(keypoints, matches, inlier_mask, image_shape, radius=50):
70
  h, w = image_shape[:2]
71
- mask = np.zeros((h, w), dtype=bool)
72
  for i, m in enumerate(matches):
73
  if inlier_mask[i]:
74
  pt = keypoints[m.trainIdx].pt
75
- x, y = int(pt[0]), int(pt[1])
76
- y_min, y_max = max(0, y - radius), min(h, y + radius + 1)
77
- x_min, x_max = max(0, x - radius), min(w, x + radius + 1)
78
- yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
79
- circle = (xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2
80
- mask[y_min:y_max, x_min:x_max] |= circle
81
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius, radius))
82
  mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=2).astype(bool)
83
  return mask
@@ -90,12 +88,7 @@ def _build_histogram_lookup(src_channel, tgt_channel, n_bins=256):
90
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
91
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
92
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
93
- lookup = np.zeros(n_bins, dtype=np.uint8)
94
- tgt_idx = 0
95
- for src_idx in range(n_bins):
96
- while tgt_idx < n_bins - 1 and tgt_cdf[tgt_idx] < src_cdf[src_idx]:
97
- tgt_idx += 1
98
- lookup[src_idx] = tgt_idx
99
  return lookup
100
 
101
 
@@ -106,12 +99,7 @@ def _build_histogram_lookup_float(src_channel, tgt_channel, n_bins=256):
106
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
107
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
108
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
109
- lookup = np.zeros(n_bins, dtype=np.float32)
110
- tgt_idx = 0
111
- for src_idx in range(n_bins):
112
- while tgt_idx < n_bins - 1 and tgt_cdf[tgt_idx] < src_cdf[src_idx]:
113
- tgt_idx += 1
114
- lookup[src_idx] = tgt_idx
115
  return lookup
116
 
117
 
@@ -168,12 +156,7 @@ def piecewise_linear_histogram_transfer(source, target, n_bins=256, mask=None):
168
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
169
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
170
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
171
- lookup = np.zeros(n_bins, dtype=np.float32)
172
- tgt_idx = 0
173
- for src_idx in range(n_bins):
174
- while tgt_idx < n_bins - 1 and tgt_cdf[tgt_idx] < src_cdf[src_idx]:
175
- tgt_idx += 1
176
- lookup[src_idx] = tgt_idx
177
  src_img = source[:, :, c].astype(np.float32)
178
  src_floor = np.floor(src_img).astype(np.int32)
179
  src_ceil = np.minimum(src_floor + 1, n_bins - 1)
@@ -183,6 +166,22 @@ def piecewise_linear_histogram_transfer(source, target, n_bins=256, mask=None):
183
  return np.clip(result, 0, 255).astype(np.uint8)
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def full_histogram_matching(source, target, mask=None):
187
  lab_matched = histogram_matching_lab(source, target, mask)
188
  cdf_matched = piecewise_linear_histogram_transfer(source, target, mask=mask)
@@ -236,30 +235,32 @@ def estimate_motion_blur(image):
236
  magnitude = np.log1p(np.abs(fshift))
237
  h, w = magnitude.shape
238
  cy, cx = h // 2, w // 2
239
- best_angle = 0.0
240
- min_energy = float('inf')
241
- max_energy = float('-inf')
242
  radius = min(h, w) // 4
243
- for angle_deg in range(0, 180, 5):
244
- angle_rad = np.deg2rad(angle_deg)
245
- dx, dy = np.cos(angle_rad), np.sin(angle_rad)
246
- energy, count = 0.0, 0
247
- for r in range(5, radius):
248
- x, y = int(cx + r * dx), int(cy + r * dy)
249
- if 0 <= x < w and 0 <= y < h:
250
- energy += magnitude[y, x]
251
- count += 1
252
- x, y = int(cx - r * dx), int(cy - r * dy)
253
- if 0 <= x < w and 0 <= y < h:
254
- energy += magnitude[y, x]
255
- count += 1
256
- if count > 0:
257
- avg_energy = energy / count
258
- if avg_energy < min_energy:
259
- min_energy = avg_energy
260
- best_angle = angle_deg
261
- if avg_energy > max_energy:
262
- max_energy = avg_energy
 
 
 
 
 
263
  blur_angle = (best_angle + 90) % 180
264
  anisotropy = (max_energy - min_energy) / (max_energy + 1e-6)
265
  kernel_size = 1 if anisotropy < 0.05 else max(1, int(anisotropy * 25))
@@ -271,17 +272,16 @@ def estimate_crf(image):
271
  h, w = gray.shape
272
  laplacian = cv2.Laplacian(gray, cv2.CV_64F)
273
  hf_energy = np.mean(np.abs(laplacian))
274
- block_diffs = []
275
- for x in range(4, w - 1, 4):
276
- block_diffs.append(np.mean(np.abs(gray[:, x] - gray[:, x - 1])))
277
- for y in range(4, h - 1, 4):
278
- block_diffs.append(np.mean(np.abs(gray[y, :] - gray[y - 1, :])))
279
- interior_diffs = []
280
- for x in range(3, w - 1, 4):
281
- if x % 4 != 0:
282
- interior_diffs.append(np.mean(np.abs(gray[:, x] - gray[:, x - 1])))
283
- avg_block = np.median(block_diffs) if block_diffs else 0
284
- avg_interior = np.median(interior_diffs) if interior_diffs else 1
285
  blockiness = avg_block / (avg_interior + 1e-6)
286
  if hf_energy > 30:
287
  crf_from_hf = 15
@@ -433,7 +433,7 @@ def paste_unedited_regions(aligned, target, mask):
433
 
434
  # ============== Alignment Pipeline ==============
435
 
436
- def align_image(source_img, target_img, pp_level=2):
437
  target_h, target_w = target_img.shape[:2]
438
  target_size = (target_w, target_h)
439
  source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
@@ -458,18 +458,21 @@ def align_image(source_img, target_img, pp_level=2):
458
  else:
459
  aligned = source_resized
460
 
461
- result = full_histogram_matching(aligned, target_img, mask=color_mask)
462
 
463
- # Paste back unedited regions from target
464
  pre_paste = result.copy()
465
- unedited_mask = detect_unedited_mask(result, target_img)
466
- result = paste_unedited_regions(result, target_img, unedited_mask)
 
 
467
 
468
  # Post-processing (only affects edited regions, then re-paste)
469
  pp_result = None
470
  if pp_level > 0:
471
  pp_result = postprocess_foreground(result, target_img, level=pp_level)
472
- pp_result = paste_unedited_regions(pp_result, target_img, unedited_mask)
 
473
 
474
  final = pp_result if pp_result is not None else result
475
  return final, naive_resized, result, pre_paste, unedited_mask, pp_result
@@ -578,7 +581,8 @@ def encode_image_png(img: np.ndarray) -> bytes:
578
  async def align_api(
579
  source: UploadFile = File(..., description="Source image to align"),
580
  target: UploadFile = File(..., description="Target reference image"),
581
- pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)")
 
582
  ):
583
  """
584
  Align source image to target image.
@@ -595,7 +599,7 @@ async def align_api(
595
  if source_img is None or target_img is None:
596
  raise HTTPException(status_code=400, detail="Failed to decode images")
597
 
598
- final, *_ = align_image(source_img, target_img, pp_level=pp_level)
599
  png_bytes = encode_image_png(final)
600
 
601
  return Response(content=png_bytes, media_type="image/png")
@@ -608,7 +612,8 @@ async def align_api(
608
  async def align_base64_api(
609
  source: UploadFile = File(...),
610
  target: UploadFile = File(...),
611
- pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)")
 
612
  ):
613
  """
614
  Align source image to target image.
@@ -625,7 +630,7 @@ async def align_base64_api(
625
  if source_img is None or target_img is None:
626
  raise HTTPException(status_code=400, detail="Failed to decode images")
627
 
628
- final, *_ = align_image(source_img, target_img, pp_level=pp_level)
629
  png_bytes = encode_image_png(final)
630
  b64 = base64.b64encode(png_bytes).decode('utf-8')
631
 
@@ -639,7 +644,8 @@ async def align_base64_api(
639
  async def align_viz_api(
640
  source: UploadFile = File(...),
641
  target: UploadFile = File(...),
642
- pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)")
 
643
  ):
644
  """
645
  Align source image to target and return visualization panel + final result.
@@ -656,7 +662,7 @@ async def align_viz_api(
656
  raise HTTPException(status_code=400, detail="Failed to decode images")
657
 
658
  final, naive_resized, pasted, pre_paste, unedited_mask, pp_result = \
659
- align_image(source_img, target_img, pp_level=pp_level)
660
 
661
  panel = create_visualization_panel(
662
  naive_resized, pasted, target_img,
@@ -848,6 +854,10 @@ HTML_CONTENT = """
848
  <option value="2" selected>2 - Medium (default)</option>
849
  <option value="3">3 - Strong</option>
850
  </select>
 
 
 
 
851
  </div>
852
 
853
  <button class="btn" id="alignBtn" disabled onclick="alignImages()">&#10024; Align Images</button>
@@ -932,6 +942,7 @@ console.log(data.image); // data:image/png;base64,...</code></pre>
932
  formData.append('source', sourceFile);
933
  formData.append('target', targetFile);
934
  formData.append('pp', document.getElementById('ppLevel').value);
 
935
 
936
  const response = await fetch('/api/align/viz', {
937
  method: 'POST',
 
26
 
27
  def extract_features(img: np.ndarray) -> tuple:
28
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
29
+ n_pixels = img.shape[0] * img.shape[1]
30
+ nfeatures = min(10000, max(2000, n_pixels // 200))
31
+ sift = cv2.SIFT_create(nfeatures=nfeatures, contrastThreshold=0.02, edgeThreshold=15)
32
  keypoints, descriptors = sift.detectAndCompute(gray, None)
33
  return keypoints, descriptors
34
 
 
70
 
71
  def create_inlier_mask(keypoints, matches, inlier_mask, image_shape, radius=50):
72
  h, w = image_shape[:2]
73
+ mask_img = np.zeros((h, w), dtype=np.uint8)
74
  for i, m in enumerate(matches):
75
  if inlier_mask[i]:
76
  pt = keypoints[m.trainIdx].pt
77
+ cv2.circle(mask_img, (int(pt[0]), int(pt[1])), radius, 1, -1)
78
+ mask = mask_img.astype(bool)
 
 
 
 
79
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius, radius))
80
  mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=2).astype(bool)
81
  return mask
 
88
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
89
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
90
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
91
+ lookup = np.searchsorted(tgt_cdf, src_cdf).astype(np.uint8)
 
 
 
 
 
92
  return lookup
93
 
94
 
 
99
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
100
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
101
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
102
+ lookup = np.searchsorted(tgt_cdf, src_cdf).astype(np.float32)
 
 
 
 
 
103
  return lookup
104
 
105
 
 
156
  src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
157
  tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
158
  tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
159
+ lookup = np.searchsorted(tgt_cdf, src_cdf).astype(np.float32)
 
 
 
 
 
160
  src_img = source[:, :, c].astype(np.float32)
161
  src_floor = np.floor(src_img).astype(np.int32)
162
  src_ceil = np.minimum(src_floor + 1, n_bins - 1)
 
166
  return np.clip(result, 0, 255).astype(np.uint8)
167
 
168
 
169
+ def fast_color_transfer(source, target, mask=None):
170
+ src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
171
+ tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
172
+ if mask is not None:
173
+ src_stats = src_lab[mask]
174
+ tgt_stats = tgt_lab[mask]
175
+ else:
176
+ src_stats = src_lab.reshape(-1, 3)
177
+ tgt_stats = tgt_lab.reshape(-1, 3)
178
+ for i in range(3):
179
+ s_mean, s_std = src_stats[:, i].mean(), src_stats[:, i].std() + 1e-6
180
+ t_mean, t_std = tgt_stats[:, i].mean(), tgt_stats[:, i].std() + 1e-6
181
+ src_lab[:, :, i] = (src_lab[:, :, i] - s_mean) * (t_std / s_std) + t_mean
182
+ return cv2.cvtColor(np.clip(src_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2BGR)
183
+
184
+
185
  def full_histogram_matching(source, target, mask=None):
186
  lab_matched = histogram_matching_lab(source, target, mask)
187
  cdf_matched = piecewise_linear_histogram_transfer(source, target, mask=mask)
 
235
  magnitude = np.log1p(np.abs(fshift))
236
  h, w = magnitude.shape
237
  cy, cx = h // 2, w // 2
 
 
 
238
  radius = min(h, w) // 4
239
+ angles_deg = np.arange(0, 180, 5)
240
+ angles_rad = np.deg2rad(angles_deg)
241
+ rs = np.arange(5, radius)
242
+ dx = np.cos(angles_rad)
243
+ dy = np.sin(angles_rad)
244
+ X_pos = (cx + np.outer(rs, dx)).astype(int)
245
+ Y_pos = (cy + np.outer(rs, dy)).astype(int)
246
+ X_neg = (cx - np.outer(rs, dx)).astype(int)
247
+ Y_neg = (cy - np.outer(rs, dy)).astype(int)
248
+ valid_pos = (X_pos >= 0) & (X_pos < w) & (Y_pos >= 0) & (Y_pos < h)
249
+ valid_neg = (X_neg >= 0) & (X_neg < w) & (Y_neg >= 0) & (Y_neg < h)
250
+ energy_pos = np.where(valid_pos, magnitude[np.clip(Y_pos, 0, h-1), np.clip(X_pos, 0, w-1)], 0.0)
251
+ energy_neg = np.where(valid_neg, magnitude[np.clip(Y_neg, 0, h-1), np.clip(X_neg, 0, w-1)], 0.0)
252
+ total_energy = energy_pos.sum(axis=0) + energy_neg.sum(axis=0)
253
+ total_count = valid_pos.sum(axis=0) + valid_neg.sum(axis=0)
254
+ valid_angles = total_count > 0
255
+ avg_energies = np.where(valid_angles, total_energy / (total_count + 1e-10), 0.0)
256
+ if valid_angles.any():
257
+ min_idx = np.argmin(np.where(valid_angles, avg_energies, np.inf))
258
+ max_idx = np.argmax(np.where(valid_angles, avg_energies, -np.inf))
259
+ best_angle = angles_deg[min_idx]
260
+ min_energy = avg_energies[min_idx]
261
+ max_energy = avg_energies[max_idx]
262
+ else:
263
+ best_angle, min_energy, max_energy = 0.0, 0.0, 0.0
264
  blur_angle = (best_angle + 90) % 180
265
  anisotropy = (max_energy - min_energy) / (max_energy + 1e-6)
266
  kernel_size = 1 if anisotropy < 0.05 else max(1, int(anisotropy * 25))
 
272
  h, w = gray.shape
273
  laplacian = cv2.Laplacian(gray, cv2.CV_64F)
274
  hf_energy = np.mean(np.abs(laplacian))
275
+ cols_4 = np.arange(4, w - 1, 4)
276
+ rows_4 = np.arange(4, h - 1, 4)
277
+ block_diffs_x = np.mean(np.abs(gray[:, cols_4] - gray[:, cols_4 - 1]), axis=0) if len(cols_4) else np.array([])
278
+ block_diffs_y = np.mean(np.abs(gray[rows_4, :] - gray[rows_4 - 1, :]), axis=1) if len(rows_4) else np.array([])
279
+ block_diffs = np.concatenate([block_diffs_x, block_diffs_y])
280
+ cols_interior = np.arange(3, w - 1, 4)
281
+ cols_interior = cols_interior[cols_interior % 4 != 0]
282
+ interior_diffs = np.mean(np.abs(gray[:, cols_interior] - gray[:, cols_interior - 1]), axis=0) if len(cols_interior) else np.array([])
283
+ avg_block = np.median(block_diffs) if len(block_diffs) else 0
284
+ avg_interior = np.median(interior_diffs) if len(interior_diffs) else 1
 
285
  blockiness = avg_block / (avg_interior + 1e-6)
286
  if hf_energy > 30:
287
  crf_from_hf = 15
 
433
 
434
  # ============== Alignment Pipeline ==============
435
 
436
+ def align_image(source_img, target_img, pp_level=2, paste_back=True):
437
  target_h, target_w = target_img.shape[:2]
438
  target_size = (target_w, target_h)
439
  source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
 
458
  else:
459
  aligned = source_resized
460
 
461
+ result = fast_color_transfer(aligned, target_img, mask=color_mask)
462
 
463
+ # Optionally paste back unedited regions from target
464
  pre_paste = result.copy()
465
+ unedited_mask = None
466
+ if paste_back:
467
+ unedited_mask = detect_unedited_mask(result, target_img)
468
+ result = paste_unedited_regions(result, target_img, unedited_mask)
469
 
470
  # Post-processing (only affects edited regions, then re-paste)
471
  pp_result = None
472
  if pp_level > 0:
473
  pp_result = postprocess_foreground(result, target_img, level=pp_level)
474
+ if paste_back and unedited_mask is not None:
475
+ pp_result = paste_unedited_regions(pp_result, target_img, unedited_mask)
476
 
477
  final = pp_result if pp_result is not None else result
478
  return final, naive_resized, result, pre_paste, unedited_mask, pp_result
 
581
  async def align_api(
582
  source: UploadFile = File(..., description="Source image to align"),
583
  target: UploadFile = File(..., description="Target reference image"),
584
+ pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
585
+ paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
586
  ):
587
  """
588
  Align source image to target image.
 
599
  if source_img is None or target_img is None:
600
  raise HTTPException(status_code=400, detail="Failed to decode images")
601
 
602
+ final, *_ = align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
603
  png_bytes = encode_image_png(final)
604
 
605
  return Response(content=png_bytes, media_type="image/png")
 
612
  async def align_base64_api(
613
  source: UploadFile = File(...),
614
  target: UploadFile = File(...),
615
+ pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
616
+ paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
617
  ):
618
  """
619
  Align source image to target image.
 
630
  if source_img is None or target_img is None:
631
  raise HTTPException(status_code=400, detail="Failed to decode images")
632
 
633
+ final, *_ = align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
634
  png_bytes = encode_image_png(final)
635
  b64 = base64.b64encode(png_bytes).decode('utf-8')
636
 
 
644
  async def align_viz_api(
645
  source: UploadFile = File(...),
646
  target: UploadFile = File(...),
647
+ pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
648
+ paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
649
  ):
650
  """
651
  Align source image to target and return visualization panel + final result.
 
662
  raise HTTPException(status_code=400, detail="Failed to decode images")
663
 
664
  final, naive_resized, pasted, pre_paste, unedited_mask, pp_result = \
665
+ align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
666
 
667
  panel = create_visualization_panel(
668
  naive_resized, pasted, target_img,
 
854
  <option value="2" selected>2 - Medium (default)</option>
855
  <option value="3">3 - Strong</option>
856
  </select>
857
+ <label style="display:flex;align-items:center;gap:0.4rem;cursor:pointer;">
858
+ <input type="checkbox" id="pasteBack" checked style="width:18px;height:18px;cursor:pointer;">
859
+ Paste back unedited regions
860
+ </label>
861
  </div>
862
 
863
  <button class="btn" id="alignBtn" disabled onclick="alignImages()">&#10024; Align Images</button>
 
942
  formData.append('source', sourceFile);
943
  formData.append('target', targetFile);
944
  formData.append('pp', document.getElementById('ppLevel').value);
945
+ formData.append('paste_back', document.getElementById('pasteBack').checked ? 'true' : 'false');
946
 
947
  const response = await fetch('/api/align/viz', {
948
  method: 'POST',