# NOTE: removed web-page scraping residue (Space status banner, file size,
# commit hashes, line-number gutter) — not part of the Python source.
import numpy as np
import cv2
import torch
import scipy.sparse as sp
import sys
import os
import random
from zipfile import ZipFile
from .plotting import plot_side_by_side_comparison
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.HybridGNet2IGSC import HybridGNetHF
hybrid = None  # lazily-initialised global model handle (set by loadModel)


def seed_everything(seed=42):
    """Seed Python, NumPy and torch RNGs (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def zip_files(files, output_name="complete_results.zip"):
    """Bundle *files* into a flat zip archive and return its path.

    Each entry is stored under its basename only. Using os.path.basename
    instead of splitting on "/" keeps arcnames correct for Windows-style
    paths as well.
    """
    with ZipFile(output_name, "w") as archive:
        for path in files:
            archive.write(path, arcname=os.path.basename(path))
    return output_name
def getMasks(landmarks, h, w):
    """Rasterise landmark contours into three filled binary (0/255) masks.

    The landmark array is split into right lung (first 44 points), left lung
    (next 50) and heart (remainder); masks are returned in that order as
    h-by-w uint8 arrays.
    """
    sections = (landmarks[:44], landmarks[44:94], landmarks[94:])
    masks = []
    for pts in sections:
        canvas = np.zeros([h, w], dtype='uint8')
        contour = pts.reshape(-1, 1, 2).astype('int')
        masks.append(cv2.drawContours(canvas, [contour], -1, 255, -1))
    return masks[0], masks[1], masks[2]
def pad_to_square(img):
    """Zero-pad a 2-D image to a square along its shorter axis.

    Returns the padded image and a record (padh, padw, auxh, auxw), where
    padh/padw is the total padding added to that axis and auxh/auxw is the
    odd extra pixel placed at the bottom/right.
    """
    h, w = img.shape[:2]
    if h > w:
        total = h - w
        extra = total % 2
        padded = np.pad(img, ((0, 0), (total // 2, total // 2 + extra)), 'constant')
        return padded, (0, total, 0, extra)
    total = w - h
    extra = total % 2
    padded = np.pad(img, ((total // 2, total // 2 + extra), (0, 0)), 'constant')
    return padded, (total, 0, extra, 0)
def preprocess(img):
    """Pad to square and resize to the model's 1024x1024 input size.

    Returns the network-ready image and (h, w, padding), where h/w are the
    square dimensions *before* resizing — everything removePreprocess needs
    to map outputs back to the original frame.
    """
    squared, padding = pad_to_square(img)
    h, w = squared.shape[:2]
    needs_resize = (h != 1024 or w != 1024)
    if needs_resize:
        squared = cv2.resize(squared, (1024, 1024), interpolation=cv2.INTER_CUBIC)
    return squared, (h, w, padding)
def removePreprocess(output, info):
    """Map normalised landmark coordinates back to original-image pixels.

    output is an (n_samples, n_points, 2) array of x/y coordinates in [0, 1];
    info is the (h, w, padding) tuple produced by preprocess. Coordinates are
    rescaled to the pre-resize square size and the padding offset is removed.
    """
    h, w, padding = info
    padh, padw, auxh, auxw = padding
    # Equivalent to: scale by 1024 only when no resize happened.
    scale = 1024 if (h == 1024 and w == 1024) else h
    output = output * scale
    output[:, :, 0] -= padw // 2
    output[:, :, 1] -= padh // 2
    return output
def loadModel(device):
    """Load the pretrained HybridGNet (v1_skip variant) in eval mode.

    Stores the model in the module-level ``hybrid`` global and returns it.
    """
    global hybrid
    model = HybridGNetHF.from_pretrained(
        repo_id="mcosarinsky/CheXmask-U",
        subfolder="v1_skip",
        device=device,
    )
    model.eval()
    hybrid = model
    return hybrid
def predict_landmarks(img, n_samples=100):
    """Sample the latent space n_samples times and return landmark mean/std.

    img is a single-channel float array (assumed scaled to [0, 1] — see
    segment()); returns (means, stds), both (n_points, 2) arrays in original
    image pixel coordinates (means/stds are computed after rounding to int).
    """
    global hybrid
    processed, (h, w, padding) = preprocess(img)
    device = next(hybrid.parameters()).device
    batch = torch.from_numpy(processed)[None, None].float().to(device)
    with torch.no_grad():
        mu, log_var, conv6, conv5 = hybrid.encode(batch)
        # Draw n_samples latents; skip-connection features are shared.
        latents = torch.stack(
            [hybrid.sampling(mu, log_var) for _ in range(n_samples)], dim=0
        )
        skips6 = conv6.repeat(n_samples, 1, 1, 1)
        skips5 = conv5.repeat(n_samples, 1, 1, 1)
        preds, _, _ = hybrid.decode(latents, skips6, skips5)
    coords = preds.cpu().numpy().reshape(n_samples, -1, 2)
    coords = removePreprocess(coords, (h, w, padding)).astype('int')
    return np.mean(coords, axis=0), np.std(coords, axis=0)
def segment(input_img, noise_std=0.0):
    """Run uncertainty-aware landmark segmentation on a chest X-ray.

    Parameters
    ----------
    input_img : dict
        Must contain "image" (numpy array, uint8 scale); may contain "mask",
        an RGB occlusion mask whose white regions black out the image.
    noise_std : float
        Std-dev of Gaussian noise added to the corrupted copy (0 disables).

    Returns
    -------
    (str, list)
        Path of the side-by-side comparison figure, and the list of saved
        artifact paths (landmarks, masks, stds, then the zip bundle).
    """
    global hybrid
    if hybrid is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)

    # Build the original and (optionally) corrupted inputs in [0, 1].
    img_orig = input_img["image"].astype(np.float32) / 255.0
    mask = input_img.get("mask", None)
    if mask is not None:
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        # White mask pixels become 0 after inversion, blacking out the image.
        img_corr = np.minimum(img_orig, 1.0 - mask)
    else:
        img_corr = img_orig.copy()
    if noise_std > 0:
        noise = np.random.normal(0, noise_std, img_corr.shape)
        img_corr = np.clip(img_corr + noise, 0.0, 1.0)

    # Re-seed before each prediction so both images use identical latent samples.
    seed_everything(123)
    means_orig, stds_orig = predict_landmarks(img_orig)
    seed_everything(123)
    means_corr, stds_corr = predict_landmarks(img_corr)

    # Save all artifacts once; reuse the same list for the zip and the return
    # value (previously the 9 paths were hard-coded twice and could drift).
    files = _save_prediction_outputs(means_orig, stds_orig, img_orig.shape)
    zipf = zip_files(files)

    fig = plot_side_by_side_comparison(img_orig, means_orig, stds_orig,
                                       img_corr, means_corr, stds_corr)
    output_path = "tmp/segmentation_comparison.png"
    fig.savefig(output_path, dpi=300)
    import matplotlib.pyplot as plt  # local import: matplotlib not a top-level dep
    plt.close(fig)

    return output_path, files + [zipf]


def _save_prediction_outputs(means, stds, img_shape):
    """Write per-structure landmark/std text files and PNG masks under tmp/.

    Returns the 9 file paths in a stable order (landmarks, masks, stds),
    matching the archive layout produced previously.
    """
    os.makedirs("tmp", exist_ok=True)
    spans = [("RL", 0, 44), ("LL", 44, 94), ("H", 94, len(means))]
    files = []
    for name, a, b in spans:
        path = f"tmp/{name}_landmarks.txt"
        np.savetxt(path, means[a:b], delimiter=" ", fmt="%d")
        files.append(path)
    masks = getMasks(means, img_shape[0], img_shape[1])
    for (name, _, _), m in zip(spans, masks):
        path = f"tmp/{name}_mask.png"
        cv2.imwrite(path, m)
        files.append(path)
    for name, a, b in spans:
        path = f"tmp/{name}_std.txt"
        np.savetxt(path, stds[a:b], delimiter=" ", fmt="%.4f")
        files.append(path)
    return files