Spaces:
Running
Running
| import numpy as np | |
| import cv2 | |
| import torch | |
| import scipy.sparse as sp | |
| import sys | |
| import os | |
| import random | |
| from zipfile import ZipFile | |
| from .plotting import plot_side_by_side_comparison | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models.HybridGNet2IGSC import HybridGNetHF | |
| hybrid = None | |
def seed_everything(seed=42):
    """Seed the random number generators used by this module.

    Calls ``random.seed``, ``np.random.seed`` and ``torch.manual_seed`` with
    the same value; when CUDA is available ``torch.cuda.manual_seed_all`` is
    called as well.

    Args:
        seed: Seed value handed to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # The CUDA generators are only touched when a GPU is actually usable.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def zip_files(files, output_name="complete_results.zip"):
    """Bundle the given files into a flat zip archive.

    Every file is stored under its base name only, so the archive has no
    directory structure. Inputs that share a base name therefore map to the
    same archive entry name.

    Args:
        files: Iterable of file paths (``str`` or ``os.PathLike``).
        output_name: Path of the archive to create. It is opened in ``"w"``
            mode, so an existing file at that path is replaced.

    Returns:
        ``output_name``, unchanged.
    """
    with ZipFile(output_name, "w") as archive:
        for path in files:
            # os.path.basename accepts PathLike objects and honours the
            # platform's path separators, unlike splitting on "/" by hand.
            archive.write(path, arcname=os.path.basename(path))
    return output_name
def getMasks(landmarks, h, w):
    """Rasterise the three landmark contours into filled binary masks.

    The landmark array is split by row index into three groups: rows
    ``[0, 44)``, ``[44, 94)`` and ``[94, end)``. The original code names them
    RL, LL and H (presumably right lung, left lung and heart). Each group is
    drawn as a filled contour (value 255) on its own ``h`` x ``w`` uint8
    canvas.

    Args:
        landmarks: Array of (x, y) points; each group must reshape to
            ``(-1, 1, 2)``.
        h: Mask height in pixels.
        w: Mask width in pixels.

    Returns:
        Tuple ``(RL_mask, LL_mask, H_mask)`` of uint8 arrays of shape (h, w).
    """
    organ_rows = (slice(None, 44), slice(44, 94), slice(94, None))
    masks = []
    for rows in organ_rows:
        # drawContours expects integer point coordinates shaped (N, 1, 2).
        contour = landmarks[rows].reshape(-1, 1, 2).astype('int')
        canvas = np.zeros([h, w], dtype='uint8')
        # contourIdx=-1 draws every contour; thickness=-1 fills the interior.
        masks.append(cv2.drawContours(canvas, [contour], -1, 255, -1))
    return tuple(masks)
def pad_to_square(img):
    """Zero-pad an image along its shorter side so height equals width.

    The padding is split between both sides of the shorter axis; when the
    difference is odd, the extra pixel goes on the trailing side (right or
    bottom). Only the first two axes are padded, so any further axes (for
    example a channel axis) keep their size.

    Args:
        img: Array whose first two axes are (height, width).

    Returns:
        Tuple ``(padded, (padh, padw, auxh, auxw))`` where ``padh`` / ``padw``
        are the total number of rows / columns added and ``auxh`` / ``auxw``
        are ``padh % 2`` / ``padw % 2``.
    """
    h, w = img.shape[:2]
    # Axes beyond (height, width) are never padded.
    trailing = ((0, 0),) * (img.ndim - 2)

    if h > w:
        padw = h - w
        auxw = padw % 2
        pad_width = ((0, 0), (padw // 2, padw // 2 + auxw)) + trailing
        return np.pad(img, pad_width, 'constant'), (0, padw, 0, auxw)

    # Wide or already-square input: pad rows (a zero-width pad when h == w).
    padh = w - h
    auxh = padh % 2
    pad_width = ((padh // 2, padh // 2 + auxh), (0, 0)) + trailing
    return np.pad(img, pad_width, 'constant'), (padh, 0, auxh, 0)
def preprocess(img):
    """Bring an image to the 1024 x 1024 size used for prediction.

    The image is first zero-padded to a square with ``pad_to_square``; unless
    the padded image is already 1024 x 1024 it is then resized with bicubic
    interpolation.

    Args:
        img: Input image array.

    Returns:
        Tuple ``(img, (h, w, padding))`` where ``h`` and ``w`` are the
        dimensions after padding but before resizing, and ``padding`` is the
        tuple returned by ``pad_to_square``.
    """
    squared, padding = pad_to_square(img)
    h, w = squared.shape[:2]
    info = (h, w, padding)

    if (h, w) == (1024, 1024):
        return squared, info

    resized = cv2.resize(squared, (1024, 1024), interpolation=cv2.INTER_CUBIC)
    return resized, info
def removePreprocess(output, info):
    """Undo the scaling and padding applied by ``preprocess`` on coordinates.

    Coordinates are multiplied by the padded image side and then shifted back
    by the leading half of the padding on each axis. The input is presumably
    expressed as a fraction of the image side (TODO confirm against the
    model's output range).

    Args:
        output: Array indexed as ``[sample, point, (x, y)]``.
        info: Tuple ``(h, w, padding)`` as returned by ``preprocess``, with
            ``padding = (padh, padw, auxh, auxw)``.

    Returns:
        New array with rescaled, un-padded coordinates; ``output`` itself is
        not modified.
    """
    h, w, padding = info
    padh, padw, _auxh, _auxw = padding

    # A 1024 x 1024 padded image was never resized, so scale by 1024;
    # otherwise scale by the padded height.
    scale = 1024 if (h == 1024 and w == 1024) else h
    restored = output * scale

    # Only the leading (left / top) half of the padding offsets coordinates.
    restored[:, :, 0] -= padw // 2
    restored[:, :, 1] -= padh // 2
    return restored
def loadModel(device):
    """Load the pretrained HybridGNet and store it in the module global.

    The model is obtained through ``HybridGNetHF.from_pretrained`` from the
    ``mcosarinsky/CheXmask-U`` repository (subfolder ``v1_skip``), assigned to
    the module-level ``hybrid`` variable and switched to evaluation mode via
    ``eval()``.

    Args:
        device: Device passed through to ``from_pretrained``.

    Returns:
        The loaded model (the same object stored in ``hybrid``).
    """
    global hybrid
    model = HybridGNetHF.from_pretrained(
        repo_id="mcosarinsky/CheXmask-U",
        subfolder="v1_skip",
        device=device,
    )
    # Publish the model globally before switching it to evaluation mode.
    hybrid = model
    model.eval()
    return model
def predict_landmarks(img, n_samples=100):
    """Estimate per-landmark mean and standard deviation by latent sampling.

    The image is preprocessed and encoded once; ``n_samples`` latent vectors
    are drawn with ``hybrid.sampling`` and decoded in a single batch. The
    decoded coordinates are mapped back to the input image frame, truncated
    to integers, and summarised across samples.

    Requires the module-level ``hybrid`` model to be loaded already.

    Args:
        img: Single-channel image array (it is given batch and channel axes
            via two ``unsqueeze(0)`` calls).
        n_samples: Number of latent samples to draw (default 100).

    Returns:
        Tuple ``(means, stds)``: mean and standard deviation over the sample
        axis of the integer landmark coordinates, each of shape (points, 2).
    """
    prepared, (h, w, padding) = preprocess(img)

    # Add batch and channel axes, then move to wherever the model lives.
    tensor = torch.from_numpy(prepared).unsqueeze(0).unsqueeze(0)
    model_device = next(hybrid.parameters()).device
    batch = tensor.to(model_device).float()

    with torch.no_grad():
        mu, log_var, conv6, conv5 = hybrid.encode(batch)

        # Draw the latent samples one by one, in order.
        latent_draws = []
        for _ in range(n_samples):
            latent_draws.append(hybrid.sampling(mu, log_var))
        z_exp = torch.stack(latent_draws, dim=0)

        # Replicate the skip features so each latent sample has its own copy.
        conv6_exp = conv6.repeat(n_samples, 1, 1, 1)
        conv5_exp = conv5.repeat(n_samples, 1, 1, 1)
        decoded, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)

    samples = decoded.cpu().numpy().reshape(n_samples, -1, 2)
    # Coordinates are truncated to integers before the statistics are taken.
    samples = removePreprocess(samples, (h, w, padding)).astype('int')
    return np.mean(samples, axis=0), np.std(samples, axis=0)
def segment(input_img, noise_std=0.0):
    """Predict landmarks for an image and for a corrupted copy of it.

    The corrupted copy is the original with the optional mask applied
    (painted regions are darkened via an element-wise minimum with the
    inverted mask) and optional Gaussian noise added. Both versions are run
    through ``predict_landmarks`` under the same fixed seed. Landmarks, masks
    and standard deviations of the *original* image are written to ``tmp/``
    and zipped, and a side-by-side comparison figure is saved.

    Args:
        input_img: dict with key "image" (numpy array, scaled here by 1/255)
            and optionally "mask" (RGB array, converted to grayscale).
        noise_std: Standard deviation of the Gaussian noise added to the
            corrupted image; no noise is added when it is not positive. The
            noise is drawn from NumPy's global generator before the fixed
            seed is applied.

    Returns:
        Tuple ``(output_path, saved_files)``: path of the comparison figure
        and the list of written artifact paths, ending with the zip archive.
    """
    global hybrid
    # Load the model lazily on first use.
    if hybrid is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)

    # Original image and its corrupted counterpart.
    clean = input_img["image"].astype(np.float32) / 255.0
    occlusion = input_img.get("mask", None)
    if occlusion is None:
        corrupted = clean.copy()
    else:
        occlusion = cv2.cvtColor(occlusion, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        # Invert so painted pixels become dark, then darken the image there.
        occlusion = 1.0 - occlusion
        corrupted = np.minimum(clean, occlusion)

    if noise_std > 0:
        perturbation = np.random.normal(0, noise_std, corrupted.shape)
        corrupted = np.clip(corrupted + perturbation, 0.0, 1.0)

    # Re-seed before each prediction so both runs see the same random draws.
    predictions = []
    for picture in (clean, corrupted):
        seed_everything(123)
        predictions.append(predict_landmarks(picture))
    (means_orig, stds_orig), (means_corr, stds_corr) = predictions

    # Persist landmarks, masks and standard deviations of the original image.
    os.makedirs("tmp", exist_ok=True)
    organ_rows = (slice(None, 44), slice(44, 94), slice(94, None))
    landmark_paths = ["tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt"]
    mask_paths = ["tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png"]
    std_paths = ["tmp/RL_std.txt", "tmp/LL_std.txt", "tmp/H_std.txt"]

    for path, rows in zip(landmark_paths, organ_rows):
        np.savetxt(path, means_orig[rows], delimiter=" ", fmt="%d")

    organ_masks = getMasks(means_orig, clean.shape[0], clean.shape[1])
    for path, organ_mask in zip(mask_paths, organ_masks):
        cv2.imwrite(path, organ_mask)

    for path, rows in zip(std_paths, organ_rows):
        np.savetxt(path, stds_orig[rows], delimiter=" ", fmt="%.4f")

    artifacts = landmark_paths + mask_paths + std_paths
    zipf = zip_files(artifacts)

    # Side-by-side comparison figure of original vs. corrupted predictions.
    fig = plot_side_by_side_comparison(clean, means_orig, stds_orig, corrupted, means_corr, stds_corr)
    output_path = "tmp/segmentation_comparison.png"
    fig.savefig(output_path, dpi=300)
    import matplotlib.pyplot as plt
    plt.close(fig)

    saved_files = artifacts + [zipf]
    return output_path, saved_files