| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from transformers import Pipeline,AutoModel |
| from tqdm import tqdm |
|
|
|
|
|
|
class InkDetectionPipeline(Pipeline):
    """
    Sliding-window ink-detection pipeline for 3D volumes.

    1. Takes in a 3D volume of shape (m, n, d).
    2. Cuts it into (tile_size, tile_size, d) tiles with a given stride. The
       last row/column of tiles is snapped flush to the bottom/right border,
       so the whole (m, n) area is covered even when the stride does not
       divide the volume evenly.
    3. Runs the model on the tiles in mini-batches (the model maps a 3D tile
       to a 2D prediction map).
    4. Upsamples each prediction by ``scale_factor`` and averages overlapping
       tiles into a full-size (m, n) probability mask.
    """

    def __init__(self, model, device='cuda', tile_size=64, stride=32,
                 scale_factor=16, batch_size=32, **kwargs):
        """
        Args:
            model: torch module applied to the tiles; moved to ``device``.
            device: target device, e.g. 'cuda', 'cuda:1' or 'cpu' (anything
                ``torch.device`` accepts).
            tile_size: side length, in pixels, of each square input tile.
            stride: step, in pixels, between neighbouring tile origins.
            scale_factor: factor by which model outputs are upsampled in
                ``postprocess``; model output side * scale_factor must equal
                ``tile_size``.
            batch_size: number of tiles per forward pass.
            **kwargs: accepted for signature compatibility; ignored.

        Raises:
            ValueError: if ``tile_size``, ``stride`` or ``batch_size`` is not
                positive.
        """
        for name, value in (('tile_size', tile_size), ('stride', stride),
                            ('batch_size', batch_size)):
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")

        torch_device = torch.device(device)
        # The base Pipeline is given an integer device index (-1 = CPU). Keep
        # the CUDA ordinal so e.g. 'cuda:1' is not reported as CPU.
        pipeline_device = (torch_device.index or 0) if torch_device.type == 'cuda' else -1
        super().__init__(model=model, tokenizer=None, device=pipeline_device)

        self.model = model.to(device)
        # Replaces whatever the base class may have stored; _forward uses this
        # value both for .to() and to choose the autocast device type.
        self.device = device
        self.tile_size = tile_size
        self.stride = stride
        self.scale_factor = scale_factor
        self.batch_size = batch_size

    def _tile_starts(self, length):
        """
        Return tile origins along one axis of size ``length``.

        Origins step by ``self.stride`` from 0. If that leaves a strip at the
        far edge uncovered, one extra origin flush with the border
        (``length - tile_size``) is appended. Assumes ``length >= tile_size``.
        """
        last = length - self.tile_size
        starts = list(range(0, last + 1, self.stride))
        if starts[-1] != last:
            starts.append(last)
        return starts

    def preprocess(self, inputs):
        """
        Cut a volume into overlapping (tile_size, tile_size, d) tiles.

        Args:
            inputs: array-like volume of shape (m, n, d), sliceable as
                ``inputs[y1:y2, x1:x2]``.

        Returns:
            tiles: float16 np.ndarray of shape (B, d, tile_size, tile_size).
            coords: list of (x1, y1, x2, y2) pixel boxes, one per tile.
            full_shape: (m, n).

        Raises:
            ValueError: if ``inputs`` is not 3D, or is smaller than one tile
                in either spatial dimension.
        """
        volume = inputs
        if len(volume.shape) != 3:
            raise ValueError(
                f"expected a 3D volume of shape (m, n, d), got shape {tuple(volume.shape)}"
            )
        m, n, _ = volume.shape
        if m < self.tile_size or n < self.tile_size:
            raise ValueError(
                f"volume of shape {tuple(volume.shape)} is smaller than a "
                f"{self.tile_size}x{self.tile_size} tile"
            )

        tiles = []
        coords = []
        x_starts = self._tile_starts(n)
        for y1 in self._tile_starts(m):
            y2 = y1 + self.tile_size
            for x1 in x_starts:
                x2 = x1 + self.tile_size
                # (tile, tile, d) -> (d, tile, tile): depth becomes the leading axis.
                tiles.append(volume[y1:y2, x1:x2].transpose(2, 0, 1))
                coords.append((x1, y1, x2, y2))

        # float16 halves host memory versus float32.
        # NOTE(review): values above 65504 (e.g. raw uint16 intensities)
        # overflow to inf in float16, so inputs are presumably pre-normalised
        # -- confirm with the data loader.
        return np.array(tiles, dtype=np.float16), coords, (m, n)

    def _forward(self, model_inputs):
        """
        Run the model over all tiles in mini-batches of ``self.batch_size``.

        Args:
            model_inputs: np.ndarray of tiles shaped (B, d, tile_size, tile_size),
                as produced by ``preprocess``. Each mini-batch gains a singleton
                channel axis, so the model receives (b, 1, d, tile_size, tile_size).

        Returns:
            float32 np.ndarray of sigmoid probabilities with shape (B, 1, h, w),
            where (h, w) is the model's per-tile output size -- presumably
            tile_size / scale_factor (see ``postprocess``); TODO confirm.

        Raises:
            ValueError: if ``model_inputs`` contains no tiles.
        """
        patches = model_inputs
        if len(patches) == 0:
            raise ValueError("no tiles to run the model on")

        # autocast takes a device *type*; strip any ordinal such as 'cuda:1'.
        autocast_type = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

        batch_preds = []
        for start in tqdm(range(0, len(patches), self.batch_size)):
            chunk = torch.from_numpy(patches[start:start + self.batch_size].astype(np.float32))
            # (b, d, H, W) -> (b, 1, d, H, W): add the singleton channel axis.
            chunk = chunk.unsqueeze(1).to(self.device)
            with torch.no_grad(), torch.autocast(autocast_type):
                logits = self.model(chunk)
            # Sigmoid is applied outside autocast, on the tensor dtype autocast produced.
            probs = torch.sigmoid(logits)
            batch_preds.append(probs.detach().cpu().float().numpy())

        return np.concatenate(batch_preds, axis=0)

    def postprocess(self, model_outputs, coords, full_shape):
        """
        Stitch per-tile predictions back into a single (m, n) probability mask.

        Each prediction is upsampled bilinearly by ``self.scale_factor`` so it
        spans tile_size x tile_size pixels, then added into an accumulator at
        its tile position. The accumulator is divided by the number of tiles
        covering each pixel (plain, unweighted averaging). Pixels covered by
        no tile stay 0.

        Args:
            model_outputs: array of shape (B, 1, h, w) with
                h * scale_factor == w * scale_factor == tile_size.
            coords: list of (x1, y1, x2, y2) per tile, as from ``preprocess``.
            full_shape: (m, n) of the original volume.

        Returns:
            float32 np.ndarray of shape (m, n).

        Raises:
            ValueError: if the number of predictions does not match ``coords``,
                or upsampled predictions are not tile_size x tile_size.
        """
        m, n = full_shape
        size = self.tile_size

        preds = torch.from_numpy(np.asarray(model_outputs, dtype=np.float32))
        if self.scale_factor != 1:
            preds = F.interpolate(
                preds, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
            )
        preds = preds.squeeze(1).numpy()

        if len(preds) != len(coords):
            raise ValueError(
                f"got {len(preds)} predictions for {len(coords)} tile coordinates"
            )
        if preds.shape[1:] != (size, size):
            raise ValueError(
                f"upsampled predictions have shape {preds.shape[1:]}, expected "
                f"({size}, {size}); check scale_factor against the model output size"
            )

        mask_sum = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)
        for pred, (x1, y1, _, _) in zip(preds, coords):
            mask_sum[y1:y1 + size, x1:x1 + size] += pred
            mask_count[y1:y1 + size, x1:x1 + size] += 1.0

        # Average overlapping tiles; uncovered pixels (count 0) remain 0.
        return np.divide(mask_sum, mask_count, out=np.zeros_like(mask_sum), where=mask_count != 0)

    def _sanitize_parameters(self, **kwargs):
        """
        Pipeline hook mapping call-time kwargs to the preprocess / forward /
        postprocess stages. This pipeline takes no call-time options, so every
        kwarg is ignored and three empty dicts are returned.
        """
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """
        Run the full tile -> infer -> stitch pipeline on one volume.

        Args:
            image: np.ndarray of shape (m, n, d) input volume.

        Returns:
            mask_pred: float32 np.ndarray of shape (m, n) with per-pixel
                probabilities averaged over overlapping tiles.
        """
        tiles, coords, full_shape = self.preprocess(image)
        outputs = self._forward(tiles)
        return self.postprocess(outputs, coords, full_shape)