Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Developed by Nikhil Nageshwar Inturi | |
| This module provides MaskStitcher for stitching tiled .npy masks | |
| back into full-size masks, one per original image stem. | |
| """ | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import logging | |
| class NPYMaskStitcher: | |
| """ | |
| Scans an input directory for files matching | |
| <stem>_<row>_<col>.npy, groups them by stem, and | |
| stitches each group into a single full-size mask. | |
| """ | |
| TILE_PATTERN = re.compile(r'^(?P<stem>.+)_(?P<row>\d+)_(?P<col>\d+)\.npy$') | |
| def __init__(self, input_dir: Path, output_dir: Path) -> None: | |
| self.input_dir = Path(input_dir) | |
| self.output_dir = Path(output_dir) | |
| self.logger = logging.getLogger(self.__class__.__name__) | |
| self._setup_output_directory() | |
| def _setup_output_directory(self) -> None: | |
| try: | |
| self.output_dir.mkdir(parents=True, exist_ok=True) | |
| self.logger.debug(f"Output directory ready: {self.output_dir}") | |
| except Exception as e: | |
| self.logger.error(f"Could not create output directory {self.output_dir}: {e}") | |
| raise | |
| def stitch_all(self) -> None: | |
| """ | |
| Find all .npy tiles, group by stem, and stitch each group. | |
| """ | |
| all_files = list(self.input_dir.glob("*.npy")) | |
| if not all_files: | |
| self.logger.warning(f"No .npy files found in {self.input_dir}") | |
| return | |
| # group files by stem | |
| stems = {} | |
| for p in all_files: | |
| m = self.TILE_PATTERN.match(p.name) | |
| if not m: | |
| self.logger.warning(f"Skipping unrecognized file name: {p.name}") | |
| continue | |
| stem = m.group("stem") | |
| stems.setdefault(stem, []).append(p) | |
| for stem, paths in stems.items(): | |
| try: | |
| self._stitch_stem(stem, paths) | |
| self.logger.info(f"Stitched mask for '{stem}' → {stem}.npy") | |
| except Exception: | |
| self.logger.exception(f"Failed to stitch tiles for '{stem}'") | |
| def _stitch_stem(self, stem: str, paths: list[Path]) -> None: | |
| """ | |
| Given all tile paths for a single stem, reconstruct the full mask. | |
| """ | |
| # load each tile into a dict keyed by (row, col) | |
| mask_map = {} | |
| rows = set() | |
| cols = set() | |
| for p in paths: | |
| m = self.TILE_PATTERN.match(p.name) | |
| row, col = int(m.group("row")), int(m.group("col")) | |
| tile = np.load(p) | |
| mask_map[(row, col)] = tile | |
| rows.add(row) | |
| cols.add(col) | |
| all_rows = sorted(rows) | |
| all_cols = sorted(cols) | |
| # determine max height per row, max width per col | |
| row_heights = {r: max(mask_map[(r, c)].shape[0] | |
| for c in all_cols if (r, c) in mask_map) | |
| for r in all_rows} | |
| col_widths = {c: max(mask_map[(r, c)].shape[1] | |
| for r in all_rows if (r, c) in mask_map) | |
| for c in all_cols} | |
| # compute offsets | |
| row_offsets = {r: sum(row_heights[rr] for rr in all_rows if rr < r) | |
| for r in all_rows} | |
| col_offsets = {c: sum(col_widths[cc] for cc in all_cols if cc < c) | |
| for c in all_cols} | |
| # total dims | |
| total_h = sum(row_heights.values()) | |
| total_w = sum(col_widths.values()) | |
| # create canvas | |
| full_mask = np.zeros((total_h, total_w), dtype=np.uint16) | |
| # place tiles | |
| for (r, c), tile in mask_map.items(): | |
| y0, x0 = row_offsets[r], col_offsets[c] | |
| h, w = tile.shape | |
| full_mask[y0:y0+h, x0:x0+w] = tile | |
| # save combined mask | |
| out_path = self.output_dir / f"{stem}.npy" | |
| np.save(out_path, full_mask) | |
| # # Path to mask files | |
| # mask_folder = image_dir # update this | |
| # mask_files = [f for f in os.listdir(mask_folder) if f.endswith('.npy')] | |
| # # Pattern to extract row and column | |
| # pattern = re.compile(r'_(\d+)_(\d+)\.npy') | |
| # # Map to hold each mask and its (row, col) | |
| # mask_map = {} | |
| # row_col_set = set() | |
| # # Organize masks by (row, col) | |
| # for f in mask_files: | |
| # match = pattern.search(f) | |
| # if match: | |
| # row = int(match.group(1)) # y | |
| # col = int(match.group(2)) # x | |
| # mask = np.load(os.path.join(mask_folder, f)) | |
| # mask_map[(row, col)] = mask | |
| # row_col_set.add((row, col)) | |
| # # Determine row and column counts | |
| # all_rows = sorted({r for r, _ in row_col_set}) | |
| # all_cols = sorted({c for _, c in row_col_set}) | |
| # # Build a lookup for tile dimensions per row/col | |
| # row_heights = {} | |
| # col_widths = {} | |
| # for row in all_rows: | |
| # for col in all_cols: | |
| # if (row, col) in mask_map: | |
| # h, w = mask_map[(row, col)].shape | |
| # row_heights[row] = max(row_heights.get(row, 0), h) | |
| # col_widths[col] = max(col_widths.get(col, 0), w) | |
| # # Compute cumulative row/column positions | |
| # row_offsets = {r: sum(row_heights[rr] for rr in all_rows if rr < r) for r in all_rows} | |
| # col_offsets = {c: sum(col_widths[cc] for cc in all_cols if cc < c) for c in all_cols} | |
| # # Total dimensions | |
| # total_height = sum(row_heights[r] for r in all_rows) | |
| # total_width = sum(col_widths[c] for c in all_cols) | |
| # # Create blank canvas | |
| # combined_mask = np.zeros((total_height, total_width), dtype=np.uint16) | |
| # # Stitch masks into the full canvas | |
| # for (row, col), mask in mask_map.items(): | |
| # y = row_offsets[row] | |
| # x = col_offsets[col] | |
| # h, w = mask.shape | |
| # combined_mask[y:y+h, x:x+w] = mask | |
| # # Save result | |
| # np.save('combined_full_mask_testing_model.npy', combined_mask) |