Spaces:
Sleeping
Sleeping
File size: 5,699 Bytes
#!/usr/bin/env python3
"""
Developed by Nikhil Nageshwar Inturi
This module provides NPYMaskStitcher for stitching tiled .npy masks
back into full-size masks, one per original image stem.
"""
import re
from pathlib import Path
import numpy as np
import logging
class NPYMaskStitcher:
    """
    Stitch tiled ``.npy`` masks back into one full-size mask per image stem.

    Scans an input directory for files named ``<stem>_<row>_<col>.npy``,
    groups them by stem, and writes each stitched result to
    ``<output_dir>/<stem>.npy``.

    Tiles are laid out on a grid: every grid row is as tall as its tallest
    tile and every grid column as wide as its widest tile. Grid cells with
    no tile (or with a smaller tile) are left zero-filled.
    """

    # ``stem`` is greedy, so the *last* two ``_<digits>`` groups of the file
    # name are always interpreted as row and column.
    TILE_PATTERN = re.compile(r'^(?P<stem>.+)_(?P<row>\d+)_(?P<col>\d+)\.npy$')

    # Integer dtype of the stitched canvas whenever every tile value fits in it.
    DEFAULT_DTYPE = np.uint16

    def __init__(self, input_dir: Path, output_dir: Path) -> None:
        """
        Args:
            input_dir: Directory containing the ``<stem>_<row>_<col>.npy`` tiles.
            output_dir: Directory that receives one ``<stem>.npy`` per stem;
                created (with parents) if it does not exist.

        Raises:
            Exception: Whatever ``Path.mkdir`` raises if the output directory
                cannot be created (logged, then re-raised).
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.logger = logging.getLogger(self.__class__.__name__)
        self._setup_output_directory()

    def _setup_output_directory(self) -> None:
        """Create the output directory, logging and re-raising any failure."""
        try:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            self.logger.debug(f"Output directory ready: {self.output_dir}")
        except Exception as e:
            self.logger.error(f"Could not create output directory {self.output_dir}: {e}")
            raise

    def stitch_all(self) -> None:
        """
        Find all .npy tiles, group by stem, and stitch each group.

        Files whose names do not match ``TILE_PATTERN`` are skipped with a
        warning. A failure while stitching one stem is logged with its
        traceback and does not stop the remaining stems.
        """
        # Sorted so processing order (and therefore the log) is deterministic;
        # glob order is otherwise filesystem-dependent.
        all_files = sorted(self.input_dir.glob("*.npy"))
        if not all_files:
            self.logger.warning(f"No .npy files found in {self.input_dir}")
            return

        # group files by stem
        stems: dict[str, list[Path]] = {}
        for p in all_files:
            m = self.TILE_PATTERN.match(p.name)
            if not m:
                self.logger.warning(f"Skipping unrecognized file name: {p.name}")
                continue
            stems.setdefault(m.group("stem"), []).append(p)

        for stem, paths in stems.items():
            try:
                self._stitch_stem(stem, paths)
            except Exception:
                # Best-effort batch: report this stem and carry on.
                self.logger.exception(f"Failed to stitch tiles for '{stem}'")
            else:
                self.logger.info(f"Stitched mask for '{stem}' → {stem}.npy")

    def _stitch_stem(self, stem: str, paths: list[Path]) -> None:
        """
        Given all tile paths for a single stem, reconstruct the full mask
        and save it as ``<output_dir>/<stem>.npy``.

        Args:
            stem: Common file-name prefix of the tiles; names the output file.
            paths: Tile files belonging to ``stem``; each name must match
                ``TILE_PATTERN``.

        Raises:
            ValueError: If a path does not look like a tile file name or a
                tile is not a 2-D array.
        """
        # load each tile into a dict keyed by (row, col)
        mask_map: dict[tuple[int, int], np.ndarray] = {}
        for p in paths:
            m = self.TILE_PATTERN.match(p.name)
            if m is None:
                raise ValueError(f"Not a tile file name: {p.name}")
            tile = np.load(p)
            if tile.ndim != 2:
                raise ValueError(
                    f"Tile {p.name} has shape {tile.shape}; expected a 2-D mask"
                )
            mask_map[(int(m.group("row")), int(m.group("col")))] = tile

        # determine max height per row, max width per col
        row_heights: dict[int, int] = {}
        col_widths: dict[int, int] = {}
        for (r, c), tile in mask_map.items():
            h, w = tile.shape
            row_heights[r] = max(row_heights.get(r, 0), h)
            col_widths[c] = max(col_widths.get(c, 0), w)

        # compute offsets and total dims in one pass per axis
        row_offsets, total_h = self._cumulative_offsets(row_heights)
        col_offsets, total_w = self._cumulative_offsets(col_widths)

        # create canvas; widened beyond DEFAULT_DTYPE only when a tile holds
        # values that would otherwise wrap around on assignment
        canvas_dtype = self._canvas_dtype(list(mask_map.values()))
        full_mask = np.zeros((total_h, total_w), dtype=canvas_dtype)

        # place tiles
        for (r, c), tile in mask_map.items():
            y0, x0 = row_offsets[r], col_offsets[c]
            h, w = tile.shape
            full_mask[y0:y0 + h, x0:x0 + w] = tile

        # save combined mask
        out_path = self.output_dir / f"{stem}.npy"
        np.save(out_path, full_mask)

    @staticmethod
    def _cumulative_offsets(sizes: dict[int, int]) -> tuple[dict[int, int], int]:
        """
        Map each grid index to the pixel offset where it starts.

        Args:
            sizes: Pixel extent of every grid row (or column), keyed by index.

        Returns:
            ``(offsets, total)`` where ``offsets[i]`` is the sum of the sizes
            of all indices smaller than ``i`` and ``total`` is the sum of all
            sizes.
        """
        offsets: dict[int, int] = {}
        total = 0
        for idx in sorted(sizes):
            offsets[idx] = total
            total += sizes[idx]
        return offsets, total

    @classmethod
    def _canvas_dtype(cls, tiles: list[np.ndarray]) -> np.dtype:
        """
        Choose a canvas dtype that can hold every tile value.

        Returns ``DEFAULT_DTYPE`` when all numeric tile values lie inside its
        range (so output is unchanged for ordinary masks). If any tile holds a
        value outside that range, returns the common NumPy promotion of the
        tile dtypes instead, because assigning such a tile into a narrower
        canvas would silently wrap the values.
        """
        default = np.dtype(cls.DEFAULT_DTYPE)
        bounds = np.iinfo(default)
        for tile in tiles:
            # Only signed/unsigned ints and floats can exceed the range;
            # empty tiles have no min/max to inspect.
            if tile.size == 0 or tile.dtype.kind not in "iuf":
                continue
            if tile.min() < bounds.min or tile.max() > bounds.max:
                return np.result_type(*(t.dtype for t in tiles))
        return default
# # Path to mask files
# mask_folder = image_dir # update this
# mask_files = [f for f in os.listdir(mask_folder) if f.endswith('.npy')]
# # Pattern to extract row and column
# pattern = re.compile(r'_(\d+)_(\d+)\.npy')
# # Map to hold each mask and its (row, col)
# mask_map = {}
# row_col_set = set()
# # Organize masks by (row, col)
# for f in mask_files:
# match = pattern.search(f)
# if match:
# row = int(match.group(1)) # y
# col = int(match.group(2)) # x
# mask = np.load(os.path.join(mask_folder, f))
# mask_map[(row, col)] = mask
# row_col_set.add((row, col))
# # Determine row and column counts
# all_rows = sorted({r for r, _ in row_col_set})
# all_cols = sorted({c for _, c in row_col_set})
# # Build a lookup for tile dimensions per row/col
# row_heights = {}
# col_widths = {}
# for row in all_rows:
# for col in all_cols:
# if (row, col) in mask_map:
# h, w = mask_map[(row, col)].shape
# row_heights[row] = max(row_heights.get(row, 0), h)
# col_widths[col] = max(col_widths.get(col, 0), w)
# # Compute cumulative row/column positions
# row_offsets = {r: sum(row_heights[rr] for rr in all_rows if rr < r) for r in all_rows}
# col_offsets = {c: sum(col_widths[cc] for cc in all_cols if cc < c) for c in all_cols}
# # Total dimensions
# total_height = sum(row_heights[r] for r in all_rows)
# total_width = sum(col_widths[c] for c in all_cols)
# # Create blank canvas
# combined_mask = np.zeros((total_height, total_width), dtype=np.uint16)
# # Stitch masks into the full canvas
# for (row, col), mask in mask_map.items():
# y = row_offsets[row]
# x = col_offsets[col]
# h, w = mask.shape
# combined_mask[y:y+h, x:x+w] = mask
# # Save result
# np.save('combined_full_mask_testing_model.npy', combined_mask)