| """ |
| Segmentation model wrapper for VQ-VAE based binary mask generation. |
| CPU-first implementation with tile-based inference for large images. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import io |
| import logging |
| import os |
| from typing import Any, Dict, Tuple |
|
|
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
|
|
| from transformers import AutoModel |
| from huggingface_hub import hf_hub_download |
|
|
logger = logging.getLogger(__name__)

# HuggingFace Hub location of the checkpoint; SegmentationPipeline.load()
# downloads from here when no local checkpoint file is found.
HF_SEGMENTATION_REPO = "kimaan28/VQVAE_Segmentation"
HF_SEGMENTATION_FILENAME = "vqvae_best_model.pth"

# Default local checkpoint: <directory of this module>/../ml_models/segmentation/.
SEGMENTATION_MODEL_DEFAULT_PATH = os.path.join(
    os.path.dirname(__file__), os.pardir, "ml_models", "segmentation", "vqvae_best_model.pth"
)
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + block(x)``.

    ``block`` is two repetitions of ReLU -> 3x3 conv (stride 1, padding 1) ->
    BatchNorm, all with ``dim`` channels, so the output has the same shape as
    the input and the skip connection is a plain addition.
    """

    def __init__(self, dim: int):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.extend((nn.ReLU(), nn.Conv2d(dim, dim, 3, 1, 1), nn.BatchNorm2d(dim)))
        # Must stay a single Sequential named ``block`` (children 0-5) so the
        # state-dict keys match saved checkpoints.
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return x + residual
|
|
|
|
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer with an EMA codebook update in training.

    Every ``embedding_dim``-sized vector of a channel-first input is replaced by
    the closest codebook row (squared Euclidean distance). Gradients reach the
    encoder through a straight-through estimator.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the input's
            channel dimension.
        commitment_cost: Weight of the encoder-commitment term in the loss.
        decay: EMA decay used to update the codebook while ``self.training``;
            ``None`` disables the EMA update.
        eps: Added to the EMA cluster sizes to avoid division by zero.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        commitment_cost: float = 1.0,
        decay: float = 0.99,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.eps = eps

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

        # EMA statistics: per-code assignment counts and running sums of the
        # vectors assigned to each code. Stored as buffers, so they are saved
        # in the state dict but not optimised.
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embedding_avg", self.embedding.weight.data.clone())

    def forward(self, inputs: torch.Tensor):
        """Quantize ``inputs`` of shape ``(B, C, H, W)`` with ``C == embedding_dim``.

        Returns:
            ``(loss, quantized, perplexity, encodings)``:

            * ``loss``: codebook loss plus ``commitment_cost`` times the commitment loss.
            * ``quantized``: same shape as ``inputs``.
            * ``perplexity``: exp-entropy of code usage in this batch.
            * ``encodings``: the ``(B*H*W, num_embeddings)`` one-hot assignment
              matrix, in the input's dtype.
        """
        # Channels last, so each spatial position becomes one row of flat_input.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        flat_input = inputs.view(-1, self.embedding_dim)

        codebook = self.embedding.weight
        # ||x||^2 + ||e||^2 - 2 x.e : all pairwise squared distances without
        # materialising the (N, K, D) difference tensor.
        distances = (
            torch.sum(flat_input**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(flat_input, codebook.t())
        )
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # Use the input's dtype, not a hard-coded float32. Otherwise half- or
        # double-precision models fail in the matmuls that consume `encodings`.
        encodings = torch.zeros(
            encoding_indices.shape[0],
            self.num_embeddings,
            device=inputs.device,
            dtype=flat_input.dtype,
        )
        encodings.scatter_(1, encoding_indices, 1)

        # Direct row lookup. This gives exactly the values of `encodings @ codebook`
        # without the dense (N x K) x (K x D) matmul.
        quantized = self.embedding(encoding_indices.squeeze(1)).view(input_shape)

        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is `quantized`, gradient is
        # passed to `inputs` unchanged.
        quantized = inputs + (quantized - inputs).detach()

        if self.training and self.decay is not None:
            n = torch.sum(encodings, dim=0)
            self.cluster_size.data.mul_(self.decay).add_(n, alpha=1 - self.decay)

            dw = torch.matmul(encodings.t(), flat_input)
            self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

            normalized_avg = self.embedding_avg / (self.cluster_size.unsqueeze(1) + self.eps)
            self.embedding.weight.data.copy_(normalized_avg)

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
|
|
|
|
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: downsampling encoder -> codebook -> mirrored decoder.

    The encoder has three 4x4/stride-2 conv stages (spatial size / 8) of width
    ``hidden_dim // 4``, ``hidden_dim // 2`` and ``hidden_dim``. It ends in a 3x3
    projection to ``embedding_dim``. The decoder mirrors it with transposed
    convolutions and a final Sigmoid.

    Args:
        input_channels: Channels of the input image and of the reconstruction.
        hidden_dim: Width of the deepest encoder/decoder stage.
        embedding_dim: Size of each latent vector / codebook entry.
        num_embeddings: Codebook size.
        commitment_cost: Commitment-loss weight forwarded to the quantizer.

    The child order inside ``encoder``/``decoder`` defines the state-dict keys
    (``encoder.0.weight`` ...), so it must stay in sync with saved checkpoints.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 4096,
        commitment_cost: float = 1.0,
    ):
        super().__init__()
        widths = [input_channels, hidden_dim // 4, hidden_dim // 2, hidden_dim]

        down_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down_layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
                ResidualBlock(c_out),
            ]
        down_layers += [nn.Conv2d(hidden_dim, embedding_dim, 3, 1, 1), nn.LeakyReLU(0.2)]
        self.encoder = nn.Sequential(*down_layers)

        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

        up_layers = [
            nn.Conv2d(embedding_dim, hidden_dim, 3, 1, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2),
            ResidualBlock(hidden_dim),
        ]
        for c_in, c_out in ((hidden_dim, hidden_dim // 2), (hidden_dim // 2, hidden_dim // 4)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
                ResidualBlock(c_out),
            ]
        up_layers += [nn.ConvTranspose2d(hidden_dim // 4, input_channels, 4, 2, 1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction, vq_loss, perplexity)`` for the image batch ``x``."""
        vq_loss, quantized, perplexity, _ = self.vq(self.encoder(x))
        return self.decoder(quantized), vq_loss, perplexity
|
|
|
|
class SegmentationPipeline:
    """CPU inference pipeline: VQ-VAE reconstruction -> Otsu threshold -> binary mask.

    Images whose (estimated) file size is at most 1.5 MB are resized to a single
    ``tile_size`` square. Larger ones are padded and processed tile by tile at
    native resolution. Call :meth:`load` once before :meth:`segment`.

    Args:
        model_path: Preferred local checkpoint path (tried first by :meth:`load`).
        tile_size: Side length in pixels of the square model input / tile.
        embedding_dim: Codebook vector size; must match the checkpoint.
    """

    # Files larger than this many bytes are segmented with tiling instead of resizing.
    _TILING_THRESHOLD_BYTES = int(1.5 * 1024 * 1024)

    # Last-resort checkpoint locations, resolved against the current working directory.
    _CWD_RELATIVE_CHECKPOINTS = (
        "./ml_models/segmentation/vqvae_best_model.pth",
        "../ml_models/segmentation/vqvae_best_model.pth",
    )

    def __init__(self, model_path: str, tile_size: int = 384, embedding_dim: int = 256):
        self.model_path = os.path.abspath(model_path)
        self.tile_size = tile_size
        self.embedding_dim = embedding_dim
        self.device = torch.device("cpu")

        self.model = VQVAE(
            input_channels=3,
            hidden_dim=256,
            embedding_dim=self.embedding_dim,
            num_embeddings=4096,
            commitment_cost=1.0,
        ).to(self.device)
        # This wrapper only runs inference. Start in eval mode so BatchNorm uses
        # running stats and the quantizer's EMA update is off, even before load().
        self.model.eval()

        self.to_tensor = transforms.ToTensor()

    def load(self) -> None:
        """Load model weights from a local checkpoint or the HuggingFace Hub.

        Local candidates are tried in order: ``self.model_path``, the module's
        default path, then two CWD-relative fallbacks. If none of them is a
        file, the checkpoint is downloaded from ``HF_SEGMENTATION_REPO``, using
        the ``HF_TOKEN`` environment variable when set.

        Raises:
            FileNotFoundError: If no local file exists and the download fails.
        """
        checkpoint_file = self._find_local_checkpoint()
        if checkpoint_file is None:
            checkpoint_file = self._download_checkpoint()

        # NOTE(review): torch.load unpickles the file. Only load checkpoints from
        # trusted sources; this one may come from the network (HF Hub).
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Strip only a *leading* "module." (added by DataParallel/DDP wrappers).
        # str.replace would also rewrite keys that merely contain the substring.
        prefix = "module."
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # strict=False tolerates extra checkpoint entries. Without logging, a
        # mismatch would silently leave layers at their random initialisation.
        result = self.model.load_state_dict(cleaned, strict=False)
        if result.missing_keys:
            logger.warning(
                "[SEGMENTATION] Checkpoint is missing %d model keys (first: %s)",
                len(result.missing_keys),
                result.missing_keys[:10],
            )
        if result.unexpected_keys:
            logger.warning(
                "[SEGMENTATION] Checkpoint has %d unexpected keys (first: %s)",
                len(result.unexpected_keys),
                result.unexpected_keys[:10],
            )
        self.model.eval()
        logger.info("[SEGMENTATION] Model loaded successfully")

    def _find_local_checkpoint(self) -> str | None:
        """Return the absolute path of the first existing local checkpoint file, else None."""
        candidates = (self.model_path, SEGMENTATION_MODEL_DEFAULT_PATH) + self._CWD_RELATIVE_CHECKPOINTS
        for path in candidates:
            # isfile, not exists: a directory (e.g. from SEGMENTATION_MODEL_PATH)
            # falls through to the next candidate instead of failing in torch.load.
            if os.path.isfile(path):
                found = os.path.abspath(path)
                logger.info("[SEGMENTATION] Found local model at: %s", found)
                return found
        return None

    def _download_checkpoint(self) -> str:
        """Download the checkpoint from the HuggingFace Hub and return its local path."""
        logger.info("[SEGMENTATION] Local model not found, downloading from HuggingFace: %s", HF_SEGMENTATION_REPO)
        token = os.getenv("HF_TOKEN")
        try:
            checkpoint_file = hf_hub_download(
                repo_id=HF_SEGMENTATION_REPO,
                filename=HF_SEGMENTATION_FILENAME,
                token=token,
                resume_download=True,
            )
        except Exception as e:
            logger.error("[SEGMENTATION] Failed to download from HuggingFace: %s", e)
            raise FileNotFoundError(
                f"Segmentation model not found locally or on HuggingFace. "
                f"Repo: {HF_SEGMENTATION_REPO} | Error: {e}"
            ) from e
        logger.info("[SEGMENTATION] Downloaded from HuggingFace: %s", checkpoint_file)
        return checkpoint_file

    @torch.inference_mode()
    def _reconstruct(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the VQ-VAE on a ``(C, H, W)`` or ``(N, C, H, W)`` tensor.

        The returned reconstruction always has a batch dimension.
        """
        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        recon, _, _ = self.model(image_tensor)
        return recon

    def _fill_holes_intelligently(self, binary_mask: np.ndarray) -> np.ndarray:
        """Fill small holes inside the 0-valued regions of a 0/255 mask.

        Port of the Kaggle hole-filling step. The mask is inverted so that the
        0-valued regions become contours. With ``RETR_CCOMP``, every contour that
        has a parent is a hole in such a region. A hole is painted 0 when the
        parent area > 50 px, hole area > 5 px, hole/parent ratio < 0.4 and hole
        area < 2000 px. The input array is never modified; filling is done on a copy.
        """
        # Too few 0-valued pixels to contain anything worth filling.
        if np.sum(binary_mask == 0) < 100:
            return binary_mask

        contours, hierarchy = cv2.findContours(
            cv2.bitwise_not(binary_mask), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
        )
        if hierarchy is None or len(contours) == 0:
            return binary_mask

        filled_mask = binary_mask.copy()
        # hierarchy[0][i] is [next, previous, first_child, parent] for contour i.
        for hole_contour, (_, _, _, parent_idx) in zip(contours, hierarchy[0]):
            if parent_idx == -1:
                continue  # outer boundary, not a hole
            parent_area = cv2.contourArea(contours[parent_idx])
            hole_area = cv2.contourArea(hole_contour)
            if parent_area <= 50 or hole_area <= 5:
                continue
            if hole_area / parent_area < 0.4 and hole_area < 2000:
                cv2.drawContours(filled_mask, [hole_contour], -1, 0, -1)

        return filled_mask

    @staticmethod
    def _tensor_to_rgb_uint8(recon_tensor: torch.Tensor) -> np.ndarray:
        """Convert a ``(1, C, H, W)`` reconstruction to an ``(H, W, C)`` uint8 array.

        Values are clipped to [0, 1] and scaled to 0-255 (truncating).
        """
        recon_np = recon_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return (np.clip(recon_np, 0, 1) * 255).astype(np.uint8)

    def _make_binary_mask(self, recon_tensor: torch.Tensor) -> np.ndarray:
        """Threshold a reconstruction into a 0/255 uint8 mask with small holes filled.

        Port of the Kaggle ``create_binary_mask_from_recon``:
        1. Convert to grayscale.
        2. Apply a 5x5 Gaussian blur.
        3. Apply Otsu's threshold.
        4. Flip polarity so 255 is the majority value.
        5. Fill holes.
        """
        gray = cv2.cvtColor(self._tensor_to_rgb_uint8(recon_tensor), cv2.COLOR_RGB2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        _, binary_mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Normalise polarity so that 255 covers at least half of the pixels.
        if np.sum(binary_mask == 255) / binary_mask.size < 0.5:
            binary_mask = cv2.bitwise_not(binary_mask)

        if np.sum(binary_mask == 255) > 100:
            return self._fill_holes_intelligently(binary_mask)
        return binary_mask

    @staticmethod
    def _ensure_rgb(image: Image.Image) -> Image.Image:
        """Return ``image`` in RGB mode. The model's first conv expects 3 channels,
        so RGBA/L/P inputs would otherwise fail."""
        return image if image.mode == "RGB" else image.convert("RGB")

    def _segment_small(self, image: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """Segment ``image`` with a single forward pass at ``tile_size`` x ``tile_size``.

        The mask is resized back with nearest-neighbour sampling (keeps it
        binary). The reconstruction is resized back with Lanczos.
        """
        image = self._ensure_rgb(image)
        original_size = image.size
        resized = image.resize((self.tile_size, self.tile_size), Image.Resampling.LANCZOS)
        recon_tensor = self._reconstruct(self.to_tensor(resized))

        mask = self._make_binary_mask(recon_tensor)
        recon_rgb = self._tensor_to_rgb_uint8(recon_tensor)

        final_mask = np.array(Image.fromarray(mask).resize(original_size, Image.Resampling.NEAREST), dtype=np.uint8)
        final_recon = np.array(Image.fromarray(recon_rgb).resize(original_size, Image.Resampling.LANCZOS), dtype=np.uint8)
        return final_mask, final_recon

    def _segment_large_tiled(self, image: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """Segment ``image`` tile by tile at native resolution.

        The image is padded on the right and bottom with its mean colour up to a
        multiple of ``tile_size``. Each tile is reconstructed and thresholded
        independently (Otsu per tile). The results are stitched together and
        cropped back to the original size.
        """
        image = self._ensure_rgb(image)
        width, height = image.size
        tile_size = self.tile_size

        tiles_x = (width + tile_size - 1) // tile_size
        tiles_y = (height + tile_size - 1) // tile_size
        padded_w = tiles_x * tile_size
        padded_h = tiles_y * tile_size

        mean_color = tuple(int(v) for v in np.mean(np.asarray(image), axis=(0, 1)))
        padded = Image.new("RGB", (padded_w, padded_h), color=mean_color)
        padded.paste(image, (0, 0))

        stitched_mask = np.zeros((padded_h, padded_w), dtype=np.uint8)
        stitched_recon = np.zeros((padded_h, padded_w, 3), dtype=np.uint8)

        for y0 in range(0, padded_h, tile_size):
            for x0 in range(0, padded_w, tile_size):
                y1 = y0 + tile_size
                x1 = x0 + tile_size
                tile = padded.crop((x0, y0, x1, y1))
                recon_tensor = self._reconstruct(self.to_tensor(tile))
                stitched_mask[y0:y1, x0:x1] = self._make_binary_mask(recon_tensor)
                stitched_recon[y0:y1, x0:x1] = self._tensor_to_rgb_uint8(recon_tensor)

        return stitched_mask[:height, :width], stitched_recon[:height, :width]

    def _estimate_image_size_bytes(self, image: Image.Image) -> int:
        """Return the size in bytes of ``image`` encoded as PNG.

        PNG is compressed, so this approximates a file size, not the raw pixel
        size. Encoding can be slow for very large images. If encoding fails,
        ``width * height * 3`` is returned instead.
        """
        try:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.tell()
        except Exception as err:
            width, height = image.size
            logger.warning("[SEGMENTATION] Could not estimate image size: %s. Using pixel-based estimate.", err)
            return width * height * 3

    def segment(self, image: Image.Image, file_size_bytes: int | None = None) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        """Segment image using tiling only if image is > 1.5 MB.

        Args:
            image: PIL image to segment. Non-RGB modes are converted to RGB.
            file_size_bytes: Optional file size in bytes. If not provided, it is
                estimated from a PNG encoding of ``image``.

        Returns:
            ``(mask, recon, metadata)``:

            * ``mask``: ``(height, width)`` uint8 array with values 0/255.
            * ``recon``: ``(height, width, 3)`` uint8 reconstruction.
            * ``metadata``: has the keys ``width``, ``height``, ``tile_size``
              and ``tiling_used``.
        """
        width, height = image.size

        if file_size_bytes is None:
            file_size_bytes = self._estimate_image_size_bytes(image)

        use_tiling = file_size_bytes > self._TILING_THRESHOLD_BYTES
        if use_tiling:
            mask, recon = self._segment_large_tiled(image)
        else:
            mask, recon = self._segment_small(image)

        metadata = {
            "width": width,
            "height": height,
            "tile_size": self.tile_size,
            "tiling_used": use_tiling,
        }
        return mask, recon, metadata

    @staticmethod
    def mask_to_png_bytes(mask: np.ndarray) -> bytes:
        """Encode a 2-D mask as 8-bit grayscale PNG bytes."""
        img = Image.fromarray(mask.astype(np.uint8), mode="L")
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        return buffer.getvalue()

    @staticmethod
    def image_to_base64(image_np: np.ndarray) -> str:
        """Convert numpy image (RGB) to base64 PNG string."""
        import base64
        img = Image.fromarray(image_np.astype(np.uint8))
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"
|
|
|
|
# Module-level cache for the lazily created pipeline.
_segmentation_pipeline: SegmentationPipeline | None = None


def get_segmentation_pipeline() -> SegmentationPipeline:
    """Return the shared SegmentationPipeline, creating and loading it on first use.

    The checkpoint path comes from the ``SEGMENTATION_MODEL_PATH`` environment
    variable, falling back to ``SEGMENTATION_MODEL_DEFAULT_PATH``. If ``load()``
    raises, nothing is cached, so the next call retries.

    NOTE(review): there is no lock here; concurrent first calls may each build
    and load their own pipeline.
    """
    global _segmentation_pipeline

    if _segmentation_pipeline is not None:
        return _segmentation_pipeline

    model_path = os.getenv("SEGMENTATION_MODEL_PATH", SEGMENTATION_MODEL_DEFAULT_PATH)
    pipeline = SegmentationPipeline(model_path=model_path, tile_size=384, embedding_dim=256)
    pipeline.load()
    # Publish to the cache only after a successful load.
    _segmentation_pipeline = pipeline
    return _segmentation_pipeline
|
|