# CytoSight — backend/app/models/segmentation_model.py
# Last commit 60479f2 by Kaifulimaan:
#   "Enhance segmentation UI layout, implement local download, and improve diagnosis typography"
"""
Segmentation model wrapper for VQ-VAE based binary mask generation.
CPU-first implementation with tile-based inference for large images.
"""
from __future__ import annotations
import io
import logging
import os
from typing import Any, Dict, Tuple
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from transformers import AutoModel
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)

# Where the segmentation checkpoint is published on the HuggingFace Hub; used as
# the download source when no local copy is found.
HF_SEGMENTATION_REPO = "kimaan28/VQVAE_Segmentation"
HF_SEGMENTATION_FILENAME = "vqvae_best_model.pth"

# Packaged checkpoint location, relative to this module:
#   <this directory>/../ml_models/segmentation/vqvae_best_model.pth
SEGMENTATION_MODEL_DEFAULT_PATH = os.path.join(
    os.path.dirname(__file__),
    os.pardir,
    "ml_models",
    "segmentation",
    HF_SEGMENTATION_FILENAME,
)
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is two repetitions of ReLU -> 3x3 conv (stride 1, padding 1) -> BatchNorm,
    so both the channel count ``dim`` and the spatial size are unchanged.
    """

    def __init__(self, dim: int):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.extend(
                [
                    nn.ReLU(),
                    nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(dim),
                ]
            )
        # The attribute name and the Sequential positions become the checkpoint
        # keys (block.1.weight, block.2.running_mean, ...), so keep both stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return x + residual
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer with optional EMA codebook updates.

    ``forward`` takes an (N, C, H, W) tensor with ``C == embedding_dim`` and returns
    ``(loss, quantized, perplexity, encodings)``:

    * ``loss``: codebook MSE + ``commitment_cost`` * commitment MSE.
    * ``quantized``: (N, C, H, W), passed through with a straight-through estimator.
    * ``perplexity``: exp of the entropy of average code usage in the batch.
    * ``encodings``: one-hot code assignments, shape (N*H*W, num_embeddings).

    In training mode (and when ``decay`` is not None) the codebook is also
    replaced by an exponential moving average of the assigned inputs.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        commitment_cost: float = 1.0,
        decay: float = 0.99,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.eps = eps

        codebook = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1 / num_embeddings
        codebook.weight.data.uniform_(-bound, bound)
        self.embedding = codebook

        # EMA statistics: per-code assignment counts and per-code summed inputs.
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embedding_avg", codebook.weight.data.clone())

    def _nearest_code_indices(self, flat: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
        """Return a (rows, 1) index tensor of the closest codebook entry per row.

        Uses the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e.
        """
        squared_distance = (
            torch.sum(flat**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(flat, codebook.t())
        )
        return torch.argmin(squared_distance, dim=1).unsqueeze(1)

    def _ema_update(self, one_hot: torch.Tensor, flat: torch.Tensor) -> None:
        """Blend this batch's assignment counts and sums into the EMA buffers and
        overwrite the codebook with their ratio.

        NOTE(review): the counts are not Laplace-smoothed, so a code that has never
        been selected is divided by ``eps`` alone; confirm this matches the
        training setup before relying on it for further training.
        """
        keep = self.decay
        blend = 1 - self.decay
        batch_counts = torch.sum(one_hot, dim=0)
        self.cluster_size.data.mul_(keep).add_(batch_counts, alpha=blend)
        batch_sums = torch.matmul(one_hot.t(), flat)
        self.embedding_avg.data.mul_(keep).add_(batch_sums, alpha=blend)
        self.embedding.weight.data.copy_(
            self.embedding_avg / (self.cluster_size.unsqueeze(1) + self.eps)
        )

    def forward(self, inputs: torch.Tensor):
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        indices = self._nearest_code_indices(flat, codebook)
        one_hot = torch.zeros(indices.shape[0], self.num_embeddings, device=inputs.device)
        one_hot.scatter_(1, indices, 1)
        quantized = torch.matmul(one_hot, codebook).view(channels_last.shape)

        # Commitment term pulls encoder outputs toward their codes; codebook term
        # pulls codes toward encoder outputs.
        commitment_loss = F.mse_loss(quantized.detach(), channels_last)
        codebook_loss = F.mse_loss(quantized, channels_last.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is the code, gradient flows to inputs.
        quantized = channels_last + (quantized - channels_last).detach()

        if self.training and self.decay is not None:
            self._ema_update(one_hot, flat)

        usage = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, one_hot
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: encoder -> VectorQuantizer -> decoder.

    The encoder applies three stride-2 convolutions (each followed by BatchNorm,
    LeakyReLU and a ResidualBlock), then a 3x3 projection to ``embedding_dim``.
    The decoder mirrors it with three stride-2 transposed convolutions and ends
    in a sigmoid, so reconstructions lie in [0, 1].

    The attribute names (``encoder``, ``vq``, ``decoder``) and the positional
    layout of each ``nn.Sequential`` determine the checkpoint keys, so the layer
    order below must not change.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 4096,
        commitment_cost: float = 1.0,
    ):
        super().__init__()
        quarter, half, full = hidden_dim // 4, hidden_dim // 2, hidden_dim

        # Encoder: three downsampling stages, then project to the embedding size.
        encoder_layers = []
        in_ch = input_channels
        for out_ch in (quarter, half, full):
            encoder_layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
                ResidualBlock(out_ch),
            ]
            in_ch = out_ch
        encoder_layers += [
            nn.Conv2d(full, embedding_dim, 3, 1, 1),
            nn.LeakyReLU(0.2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Built between encoder and decoder to keep the original parameter-init order.
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

        # Decoder: lift embeddings back to hidden width, then three upsampling stages.
        decoder_layers = [
            nn.Conv2d(embedding_dim, full, 3, 1, 1),
            nn.BatchNorm2d(full),
            nn.LeakyReLU(0.2),
            ResidualBlock(full),
        ]
        for in_ch, out_ch in ((full, half), (half, quarter)):
            decoder_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
                ResidualBlock(out_ch),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(quarter, input_channels, 4, 2, 1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction, vq_loss, perplexity)`` for an (N, C, H, W) batch."""
        latents = self.encoder(x)
        vq_loss, quantized, perplexity, _ = self.vq(latents)
        return self.decoder(quantized), vq_loss, perplexity
class SegmentationPipeline:
    """CPU inference wrapper producing a binary mask from an image via a VQ-VAE.

    The VQ-VAE reconstructs the image (in one pass or tile by tile). The
    reconstruction is converted to grayscale, blurred, Otsu-thresholded and
    hole-filled, giving a uint8 mask whose values are 0 or 255.
    """

    def __init__(self, model_path: str, tile_size: int = 384, embedding_dim: int = 256):
        """Create the pipeline with a randomly initialised VQ-VAE; call :meth:`load` next.

        Args:
            model_path: Preferred local checkpoint path (tried first by ``load``).
            tile_size: Square side length, in pixels, the model is run at. The
                encoder halves the resolution three times and the decoder doubles
                it three times, so sizes that are not multiples of 8 come back a
                different size (which breaks tiled stitching).
            embedding_dim: Codebook vector size; must match the checkpoint's shapes.
        """
        self.model_path = os.path.abspath(model_path)
        self.tile_size = tile_size
        self.embedding_dim = embedding_dim
        self.device = torch.device("cpu")
        self.model = VQVAE(
            input_channels=3,
            hidden_dim=256,
            embedding_dim=self.embedding_dim,
            num_embeddings=4096,
            commitment_cost=1.0,
        ).to(self.device)
        self.to_tensor = transforms.ToTensor()

    # ------------------------------------------------------------------ loading

    def load(self) -> None:
        """Load model weights from a local checkpoint or the HuggingFace Hub.

        Local candidates are tried in order, and the first existing file wins.
        If none exists, the checkpoint is downloaded from ``HF_SEGMENTATION_REPO``.
        The model is then switched to eval mode.

        Raises:
            FileNotFoundError: no local file exists and the download failed.
        """
        checkpoint_file = self._find_local_checkpoint() or self._download_checkpoint()
        self._load_weights(checkpoint_file)
        self.model.eval()
        logger.info("[SEGMENTATION] Model loaded successfully")

    def _find_local_checkpoint(self) -> str | None:
        """Return the absolute path of the first local checkpoint file found, else None."""
        candidates = (
            self.model_path,
            SEGMENTATION_MODEL_DEFAULT_PATH,
            "./ml_models/segmentation/vqvae_best_model.pth",
            "../ml_models/segmentation/vqvae_best_model.pth",
        )
        for path in candidates:
            # isfile (not exists): a directory at one of these paths must not be
            # handed to torch.load.
            if os.path.isfile(path):
                found = os.path.abspath(path)
                logger.info("[SEGMENTATION] Found local model at: %s", found)
                return found
        return None

    @staticmethod
    def _download_checkpoint() -> str:
        """Download the checkpoint from the HuggingFace Hub and return its local path.

        Raises:
            FileNotFoundError: wrapping whatever error the download raised.
        """
        logger.info(
            "[SEGMENTATION] Local model not found, downloading from HuggingFace: %s",
            HF_SEGMENTATION_REPO,
        )
        try:
            # ``resume_download`` is deprecated in huggingface_hub (downloads always
            # resume when possible), so it is no longer passed.
            checkpoint_file = hf_hub_download(
                repo_id=HF_SEGMENTATION_REPO,
                filename=HF_SEGMENTATION_FILENAME,
                token=os.getenv("HF_TOKEN"),
            )
        except Exception as e:
            logger.error("[SEGMENTATION] Failed to download from HuggingFace: %s", e)
            raise FileNotFoundError(
                f"Segmentation model not found locally or on HuggingFace. "
                f"Repo: {HF_SEGMENTATION_REPO} | Error: {e}"
            ) from e
        logger.info("[SEGMENTATION] Downloaded from HuggingFace: %s", checkpoint_file)
        return checkpoint_file

    def _load_weights(self, checkpoint_file: str) -> None:
        """Load a state dict (bare, or under ``"model_state_dict"``) into the model.

        Loading is non-strict, as before, but missing or unexpected keys are now
        logged. Silently skipping them would leave parts of the network at their
        random initial values and produce meaningless masks.
        """
        # SECURITY NOTE(review): torch.load unpickles the file. Unless weights_only
        # is in effect (it depends on the torch version's default), a tampered
        # checkpoint (e.g. from a compromised Hub repo) can execute code. Consider
        # weights_only=True once the checkpoint is confirmed to hold only tensors
        # and plain containers.
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Strip only the leading DataParallel prefix. str.replace would also mangle
        # any key that merely contains "module." further in.
        prefix = "module."
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        result = self.model.load_state_dict(cleaned, strict=False)
        if result.missing_keys:
            logger.warning(
                "[SEGMENTATION] %d model keys missing from checkpoint (left at random init), e.g. %s",
                len(result.missing_keys),
                result.missing_keys[:10],
            )
        if result.unexpected_keys:
            logger.warning(
                "[SEGMENTATION] %d checkpoint keys not used by the model, e.g. %s",
                len(result.unexpected_keys),
                result.unexpected_keys[:10],
            )

    # ---------------------------------------------------------------- inference

    @torch.inference_mode()
    def _reconstruct(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the VQ-VAE and return its reconstruction (values in [0, 1] via the sigmoid).

        A 3-D (C, H, W) tensor is treated as a batch of one.
        """
        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        recon, _, _ = self.model(image_tensor)
        return recon

    @staticmethod
    def _recon_to_uint8(recon_tensor: torch.Tensor) -> np.ndarray:
        """Convert a (1, C, H, W) float reconstruction to an (H, W, C) uint8 image."""
        recon_np = recon_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return (np.clip(recon_np, 0, 1) * 255).astype(np.uint8)

    def _fill_holes_intelligently(self, binary_mask: np.ndarray) -> np.ndarray:
        """Fill small holes inside black (0) regions of a 0/255 mask.

        Ported from the original Kaggle code. Contours are found on the inverted
        mask with ``RETR_CCOMP``, so every contour with a parent is a white island
        enclosed by a black region. An island is painted black when it is small
        both in absolute terms (< 2000 px) and relative to its enclosing region
        (< 40% of that region's area).
        """
        if np.sum(binary_mask == 0) < 100:
            # Too little black for there to be regions worth filling.
            return binary_mask
        inverted = cv2.bitwise_not(binary_mask)
        contours, hierarchy = cv2.findContours(inverted, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        if hierarchy is None or len(contours) == 0:
            return binary_mask
        filled_mask = binary_mask.copy()
        # hierarchy has shape (1, N, 4); each row is [next, previous, first_child, parent].
        for hole_contour, (_, _, _, parent_idx) in zip(contours, hierarchy[0]):
            if parent_idx == -1:
                continue  # outer contour, not a hole
            parent_area = cv2.contourArea(contours[parent_idx])
            hole_area = cv2.contourArea(hole_contour)
            if parent_area <= 50 or hole_area <= 5:
                continue  # ignore noise-sized contours
            if hole_area / parent_area < 0.4 and hole_area < 2000:
                cv2.drawContours(filled_mask, [hole_contour], -1, 0, -1)
        return filled_mask

    def _binary_mask_from_rgb(self, recon_uint8: np.ndarray) -> np.ndarray:
        """Threshold an (H, W, 3) uint8 reconstruction into a hole-filled 0/255 mask."""
        gray = cv2.cvtColor(recon_uint8, cv2.COLOR_RGB2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        _, binary_mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Polarity normalisation: make white the majority class, so the minority
        # (presumably the cells; confirm against the training data) ends up black.
        white_ratio = np.sum(binary_mask == 255) / binary_mask.size
        if white_ratio < 0.5:
            binary_mask = cv2.bitwise_not(binary_mask)
        if np.sum(binary_mask == 255) > 100:
            return self._fill_holes_intelligently(binary_mask)
        return binary_mask

    def _make_binary_mask(self, recon_tensor: torch.Tensor) -> np.ndarray:
        """Create a binary mask from a reconstruction tensor (Otsu threshold + hole filling).

        Port of the Kaggle ``create_binary_mask_from_recon`` function.
        """
        return self._binary_mask_from_rgb(self._recon_to_uint8(recon_tensor))

    def _segment_small(self, image: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """Segment in one pass: resize to tile_size², run the model, resize results back."""
        original_size = image.size
        resized = image.resize((self.tile_size, self.tile_size), Image.Resampling.LANCZOS)
        recon_tensor = self._reconstruct(self.to_tensor(resized))
        recon_rgb = self._recon_to_uint8(recon_tensor)
        mask = self._binary_mask_from_rgb(recon_rgb)
        # NEAREST keeps the mask strictly 0/255; LANCZOS is fine for the RGB preview.
        final_mask = np.array(
            Image.fromarray(mask).resize(original_size, Image.Resampling.NEAREST), dtype=np.uint8
        )
        final_recon = np.array(
            Image.fromarray(recon_rgb).resize(original_size, Image.Resampling.LANCZOS), dtype=np.uint8
        )
        return final_mask, final_recon

    def _segment_large_tiled(self, image: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """Segment tile by tile at native resolution.

        The image is padded on the right and bottom with its mean colour up to a
        whole number of tiles. Each tile is reconstructed and thresholded
        independently, so the Otsu threshold and polarity are chosen per tile.
        The results are stitched and cropped back to the original size.
        """
        width, height = image.size
        tile_size = self.tile_size
        tiles_x = (width + tile_size - 1) // tile_size
        tiles_y = (height + tile_size - 1) // tile_size
        padded_w = tiles_x * tile_size
        padded_h = tiles_y * tile_size

        mean_color = tuple(int(v) for v in np.mean(np.asarray(image), axis=(0, 1)))
        padded = Image.new("RGB", (padded_w, padded_h), color=mean_color)
        padded.paste(image, (0, 0))

        stitched_mask = np.zeros((padded_h, padded_w), dtype=np.uint8)
        stitched_recon = np.zeros((padded_h, padded_w, 3), dtype=np.uint8)
        for y0 in range(0, padded_h, tile_size):
            for x0 in range(0, padded_w, tile_size):
                y1 = y0 + tile_size
                x1 = x0 + tile_size
                tile = padded.crop((x0, y0, x1, y1))
                recon_tensor = self._reconstruct(self.to_tensor(tile))
                recon_rgb = self._recon_to_uint8(recon_tensor)
                stitched_mask[y0:y1, x0:x1] = self._binary_mask_from_rgb(recon_rgb)
                stitched_recon[y0:y1, x0:x1] = recon_rgb

        return stitched_mask[:height, :width], stitched_recon[:height, :width]

    def _estimate_image_size_bytes(self, image: Image.Image) -> int:
        """Estimate the upload's file size as the size of ``image`` encoded as PNG.

        PNG is lossless but compressed, so this is a proxy for the on-disk size,
        not the raw pixel-buffer size. If encoding fails, falls back to
        ``width * height * 3`` (the raw 8-bit RGB size).
        """
        try:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.tell()
        except Exception as err:
            width, height = image.size
            logger.warning("[SEGMENTATION] Could not estimate image size: %s. Using pixel-based estimate.", err)
            return width * height * 3

    def segment(
        self, image: Image.Image, file_size_bytes: int | None = None
    ) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        """Segment an image, using tiling only when it is larger than 1.5 MB.

        Args:
            image: PIL image to segment. Non-RGB modes (RGBA, L, P, ...) are
                converted to RGB first, because the model expects 3-channel input.
            file_size_bytes: Size of the source file in bytes. When omitted, it is
                estimated by PNG-encoding ``image``.

        Returns:
            ``(mask, reconstruction, metadata)``:
            - ``mask``: (H, W) uint8 array with values 0/255, at the input's size.
            - ``reconstruction``: (H, W, 3) uint8 array, at the input's size.
            - ``metadata``: dict with ``width``, ``height``, ``tile_size`` and
              ``tiling_used``.
        """
        if image.mode != "RGB":
            # Without this, RGBA/L inputs reach the 3-channel conv layers with the
            # wrong channel count, and L-mode images break the tiling mean colour.
            image = image.convert("RGB")
        width, height = image.size

        if file_size_bytes is None:
            file_size_bytes = self._estimate_image_size_bytes(image)
        threshold_bytes = int(1.5 * 1024 * 1024)  # 1.5 MB
        use_tiling = file_size_bytes > threshold_bytes

        if use_tiling:
            mask, recon = self._segment_large_tiled(image)
        else:
            mask, recon = self._segment_small(image)

        metadata = {
            "width": width,
            "height": height,
            "tile_size": self.tile_size,
            "tiling_used": use_tiling,
        }
        return mask, recon, metadata

    # --------------------------------------------------------------- encoding

    @staticmethod
    def mask_to_png_bytes(mask: np.ndarray) -> bytes:
        """Encode a 2-D mask as a single-channel (mode "L") PNG and return the bytes."""
        # Pillow infers mode "L" from a 2-D uint8 array. The explicit ``mode=``
        # argument to fromarray is deprecated, so it is no longer passed; convert()
        # keeps the output single-channel for any other input shape.
        img = Image.fromarray(mask.astype(np.uint8))
        if img.mode != "L":
            img = img.convert("L")
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        return buffer.getvalue()

    @staticmethod
    def image_to_base64(image_np: np.ndarray) -> str:
        """Convert a numpy image (RGB) to a ``data:image/png;base64,...`` URI string."""
        import base64

        img = Image.fromarray(image_np.astype(np.uint8))
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"
# Module-wide pipeline instance; created lazily by get_segmentation_pipeline().
_segmentation_pipeline: SegmentationPipeline | None = None


def get_segmentation_pipeline() -> SegmentationPipeline:
    """Return the shared SegmentationPipeline, building and loading it on first use.

    The checkpoint path comes from the ``SEGMENTATION_MODEL_PATH`` environment
    variable, defaulting to ``SEGMENTATION_MODEL_DEFAULT_PATH``. The instance is
    cached only after ``load()`` returns, so a failed load is retried on the next
    call. There is no locking: concurrent first calls may each build a pipeline.
    """
    global _segmentation_pipeline
    if _segmentation_pipeline is not None:
        return _segmentation_pipeline

    fresh = SegmentationPipeline(
        model_path=os.getenv("SEGMENTATION_MODEL_PATH", SEGMENTATION_MODEL_DEFAULT_PATH),
        tile_size=384,
        embedding_dim=256,
    )
    fresh.load()
    _segmentation_pipeline = fresh
    return fresh