# CheXmask-U: utils/segmentation.py
# Author: mcosarinsky (commit b698ace, "update")
import numpy as np
import cv2
import torch
import scipy.sparse as sp
import sys
import os
import random
from zipfile import ZipFile
from .plotting import plot_side_by_side_comparison
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.HybridGNet2IGSC import HybridGNetHF
# Module-level model handle. It stays None until loadModel() assigns it;
# segment() calls loadModel() on first use when this is still None.
hybrid = None
def seed_everything(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def zip_files(files, output_name="complete_results.zip"):
    """Bundle `files` into a flat ZIP archive and return its path.

    Each file is stored under its base name only (directories are dropped), so
    two inputs sharing a base name would produce duplicate archive entries.

    Args:
        files: iterable of paths to existing files.
        output_name: path of the archive to create (overwritten if present).

    Returns:
        `output_name`.
    """
    with ZipFile(output_name, "w") as zip_obj:
        for path in files:
            # os.path.basename honours the platform's separators, so Windows-style
            # "dir\\name" paths are flattened too, not just "/"-separated ones.
            zip_obj.write(path, arcname=os.path.basename(path))
    return output_name
def getMasks(landmarks, h, w):
    """Rasterize the three landmark contours into filled binary masks.

    `landmarks` is split along its first axis: entries 0-43 form the RL contour,
    44-93 the LL contour, and the rest the H contour (in this file it is called
    with an (N, 2) array of x/y points). Each contour is filled with 255 on its
    own h x w uint8 canvas.

    Returns:
        Tuple `(RL_mask, LL_mask, H_mask)`.
    """
    def _fill(contour_pts):
        # Filled polygon: contourIdx=-1 draws all contours, thickness=-1 fills.
        canvas = np.zeros([h, w], dtype='uint8')
        polygon = contour_pts.reshape(-1, 1, 2).astype('int')
        return cv2.drawContours(canvas, [polygon], -1, 255, -1)

    contours = (landmarks[:44], landmarks[44:94], landmarks[94:])
    return tuple(_fill(pts) for pts in contours)
def pad_to_square(img):
    """Zero-pad `img` symmetrically along its shorter spatial axis to make it square.

    Accepts 2-D (H, W) arrays and arrays with trailing channel axes such as
    (H, W, C). Only the first two axes are padded. When the total padding is
    odd, the extra pixel goes on the bottom (height) or right (width).

    Returns:
        `(padded, padding)`, where `padding` is `(padh, padw, auxh, auxw)`:
        the total padding added along height/width and its parity (0 or 1).
        At most one of `padh` / `padw` is non-zero.
    """
    h, w = img.shape[:2]
    # One (before, after) pair per axis; trailing channel axes are never padded.
    pad_width = [(0, 0)] * img.ndim
    if h > w:
        padw = h - w
        auxw = padw % 2
        pad_width[1] = (padw // 2, padw // 2 + auxw)
        return np.pad(img, pad_width, 'constant'), (0, padw, 0, auxw)
    padh = w - h
    auxh = padh % 2
    pad_width[0] = (padh // 2, padh // 2 + auxh)
    return np.pad(img, pad_width, 'constant'), (padh, 0, auxh, 0)
def preprocess(img):
    """Pad `img` to a square, then resize it to 1024x1024 if needed.

    Returns:
        `(image, (h, w, padding))`. `h` and `w` are the side lengths of the
        padded square before any resize. `padding` is the tuple returned by
        `pad_to_square`. `predict_landmarks` hands this info to
        `removePreprocess` to map predictions back.
    """
    square, padding = pad_to_square(img)
    side_h, side_w = square.shape[:2]
    if (side_h, side_w) == (1024, 1024):
        resized = square
    else:
        resized = cv2.resize(square, (1024, 1024), interpolation=cv2.INTER_CUBIC)
    return resized, (side_h, side_w, padding)
def removePreprocess(output, info):
    """Map model output coordinates back to pixel coordinates of the unpadded image.

    Args:
        output: array whose last axis holds (x, y) pairs, presumably normalized
            to the padded square (TODO confirm against the model). Any number of
            leading axes is accepted, e.g. (n_samples, n_points, 2) or
            (n_points, 2).
        info: `(h, w, padding)` as returned by `preprocess`. `padding` is
            `(padh, padw, auxh, auxw)` from `pad_to_square`.

    Returns:
        A new array; `output` itself is not modified.
    """
    h, w, padding = info
    padh, padw, auxh, auxw = padding
    # Scale by the padded side length. In the no-resize case h == 1024, so a
    # separate "* 1024" branch would compute exactly the same thing.
    coords = output * h
    # Only the leading (left/top) half of the padding shifts coordinates; the
    # odd extra pixel was appended on the right/bottom by pad_to_square.
    coords[..., 0] -= padw // 2
    coords[..., 1] -= padh // 2
    return coords
def loadModel(device):
    """Load the pretrained HybridGNetHF, store it in the module global, and return it.

    The weights come from `HybridGNetHF.from_pretrained` (repo
    "mcosarinsky/CheXmask-U", subfolder "v1_skip"), presumably fetched from the
    Hugging Face Hub. The model is placed on `device` and switched to eval mode.
    """
    global hybrid
    pretrained_kwargs = {
        "repo_id": "mcosarinsky/CheXmask-U",
        "subfolder": "v1_skip",
        "device": device,
    }
    hybrid = HybridGNetHF.from_pretrained(**pretrained_kwargs)
    hybrid.eval()
    return hybrid
def predict_landmarks(img, n_samples=100):
    """Predict landmarks for `img` by decoding several latent samples.

    The image is padded/resized by `preprocess` and encoded once. `n_samples`
    latent vectors are drawn with `hybrid.sampling` and decoded in a single
    batch. The results are mapped back to the original image's pixel
    coordinates.

    Args:
        img: 2-D image array; presumably grayscale, since it is fed to the
            model as a single channel.
        n_samples: number of latent samples to decode (must be >= 1).

    Returns:
        `(means, stds)`: per-landmark mean and standard deviation over the
        samples, each shaped (n_points, 2).

    Raises:
        ValueError: if `n_samples < 1`.
    """
    global hybrid
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if hybrid is None:
        # Same lazy initialisation as segment(), so direct callers do not hit
        # an AttributeError on the unloaded (None) model.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)

    img_proc, (h, w, padding) = preprocess(img)
    model_device = next(hybrid.parameters()).device
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    data = torch.from_numpy(img_proc).unsqueeze(0).unsqueeze(0).to(model_device).float()
    with torch.no_grad():
        mu, log_var, conv6, conv5 = hybrid.encode(data)
        zs = [hybrid.sampling(mu, log_var) for _ in range(n_samples)]
        z_exp = torch.stack(zs, dim=0)
        # Skip-connection features are shared by every sample, so tile them.
        conv6_exp = conv6.repeat(n_samples, 1, 1, 1)
        conv5_exp = conv5.repeat(n_samples, 1, 1, 1)
        output, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)
    output = output.cpu().numpy().reshape(n_samples, -1, 2)
    # NOTE: samples are truncated to integer pixels before the statistics are
    # taken, so stds include that quantization.
    output = removePreprocess(output, (h, w, padding)).astype('int')
    means, stds = np.mean(output, axis=0), np.std(output, axis=0)
    return means, stds
def _corrupt_image(img_orig, mask, noise_std):
    """Return a corrupted copy of `img_orig` for the robustness comparison.

    `mask` (optional) holds 0-255 values; bright (painted) pixels are blacked
    out in the copy. Multi-channel (RGB) masks are converted to grayscale;
    2-D masks are used as-is. If `noise_std > 0`, Gaussian noise with that
    standard deviation is added and the result is clipped to [0, 1].
    """
    if mask is not None:
        mask = np.asarray(mask)
        if mask.ndim == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        # Invert so painted pixels become 0; np.minimum then zeroes them in the image.
        keep = 1.0 - mask.astype(np.float32) / 255.0
        img_corr = np.minimum(img_orig, keep)
    else:
        img_corr = img_orig.copy()
    if noise_std > 0:
        # Drawn from NumPy's global RNG, which is not reseeded at this point.
        noise = np.random.normal(0, noise_std, img_corr.shape)
        img_corr = np.clip(img_corr + noise, 0.0, 1.0)
    return img_corr


def _save_segmentation_outputs(means, stds, h, w):
    """Write landmarks, masks and stds for RL/LL/H under `tmp/` and zip them.

    Returns:
        The written paths (landmark files, then mask files, then std files,
        each in RL, LL, H order), followed by the ZIP archive path.
    """
    os.makedirs("tmp", exist_ok=True)
    # Landmark index ranges, matching the split used by getMasks.
    parts = (("RL", slice(0, 44)), ("LL", slice(44, 94)), ("H", slice(94, None)))
    masks = getMasks(means, h, w)
    landmark_files, mask_files, std_files = [], [], []
    for (name, rows), organ_mask in zip(parts, masks):
        landmark_path = f"tmp/{name}_landmarks.txt"
        np.savetxt(landmark_path, means[rows], delimiter=" ", fmt="%d")
        landmark_files.append(landmark_path)

        mask_path = f"tmp/{name}_mask.png"
        cv2.imwrite(mask_path, organ_mask)
        mask_files.append(mask_path)

        std_path = f"tmp/{name}_std.txt"
        np.savetxt(std_path, stds[rows], delimiter=" ", fmt="%.4f")
        std_files.append(std_path)
    files = landmark_files + mask_files + std_files
    return files + [zip_files(files)]


def segment(input_img, noise_std=0.0):
    """Segment an image and compare it with a corrupted copy.

    input_img: dict with key "image" (numpy array with 0-255 values, presumably
        grayscale) and optionally "mask" (RGB or single-channel image whose
        painted pixels are blacked out in the corrupted copy).
    noise_std: standard deviation of Gaussian noise (in [0, 1] intensity units)
        added to the corrupted copy for robustness testing.

    Returns: (path to comparison figure, list of saved files). The files hold
        the landmarks, masks and per-landmark stds predicted on the *original*
        image, followed by a ZIP archive containing them.
    """
    global hybrid
    if hybrid is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)

    # Original image and corrupted version
    img_orig = input_img["image"].astype(np.float32) / 255.0
    img_corr = _corrupt_image(img_orig, input_img.get("mask", None), noise_std)

    # Reseed before each run, presumably so both draw the same latent samples
    # and differences reflect the corruption only.
    seed_everything(123)
    means_orig, stds_orig = predict_landmarks(img_orig)
    seed_everything(123)
    means_corr, stds_corr = predict_landmarks(img_corr)

    saved_files = _save_segmentation_outputs(
        means_orig, stds_orig, img_orig.shape[0], img_orig.shape[1]
    )

    fig = plot_side_by_side_comparison(
        img_orig, means_orig, stds_orig, img_corr, means_corr, stds_corr
    )
    output_path = "tmp/segmentation_comparison.png"
    import matplotlib.pyplot as plt
    try:
        fig.savefig(output_path, dpi=300)
    finally:
        # Close even if saving fails so a long-running app does not leak figures.
        plt.close(fig)
    return output_path, saved_files