File size: 5,875 Bytes
4e79318
 
 
 
 
 
edb6fcc
4e79318
 
 
 
b698ace
4e79318
 
 
edb6fcc
 
 
 
 
 
 
4e79318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b698ace
 
 
 
 
4e79318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edb6fcc
4e79318
edb6fcc
4e79318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import numpy as np
import cv2
import torch
import scipy.sparse as sp
import sys
import os
import random
from zipfile import ZipFile
from .plotting import plot_side_by_side_comparison

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.HybridGNet2IGSC import HybridGNetHF

# Module-wide model instance. It starts out unset; loadModel() assigns it, and
# segment() calls loadModel() itself when it finds the value still None.
hybrid = None

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator (defaults to 42).

    When CUDA is available, all CUDA devices are seeded with the same value.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def zip_files(files, output_name="complete_results.zip"):
    """Bundle the given files into a single flat zip archive.

    Args:
        files: Iterable of paths to existing files.
        output_name: Path of the archive to create (overwritten if present).

    Returns:
        ``output_name``, unchanged, so callers can pass the path along.

    Each file is stored under its base name only, so directory components
    are dropped. Files that share a base name end up as duplicate entries.
    """
    with ZipFile(output_name, "w") as archive:
        for path in files:
            # os.path.basename understands the platform's separators, unlike
            # splitting on "/" which leaves Windows-style paths untouched.
            archive.write(path, arcname=os.path.basename(path))
    return output_name

def getMasks(landmarks, h, w):
    """Rasterise the three landmark groups into filled binary masks.

    Args:
        landmarks: Array of 2-D points. Rows ``[:44]``, ``[44:94]`` and
            ``[94:]`` form the groups this module names RL, LL and H
            (presumably right lung, left lung and heart -- confirm).
        h: Height of each output mask.
        w: Width of each output mask.

    Returns:
        Tuple ``(RL_mask, LL_mask, H_mask)`` of ``uint8`` arrays shaped
        ``(h, w)``, with 255 inside each filled contour and 0 elsewhere.
    """
    group_bounds = (slice(None, 44), slice(44, 94), slice(94, None))

    filled = []
    for bounds in group_bounds:
        # drawContours wants integer points laid out as (N, 1, 2).
        contour = landmarks[bounds].reshape(-1, 1, 2).astype('int')
        canvas = np.zeros([h, w], dtype='uint8')
        filled.append(cv2.drawContours(canvas, [contour], -1, 255, -1))

    rl_mask, ll_mask, h_mask = filled
    return rl_mask, ll_mask, h_mask

def pad_to_square(img):
    """Zero-pad an image along its shorter side so height equals width.

    Only the first two axes (height, width) are padded; any trailing axes
    (for example a channel axis) are left untouched, so both 2-D and
    multi-channel arrays are accepted.

    Args:
        img: Array with at least two dimensions, laid out as (height, width, ...).

    Returns:
        Tuple ``(padded, (padh, padw, auxh, auxw))`` where ``padh`` / ``padw``
        are the total rows / columns added and ``auxh`` / ``auxw`` are 1 when
        that total is odd (the extra row / column goes on the bottom / right
        side) and 0 otherwise.
    """
    h, w = img.shape[:2]
    total = abs(h - w)
    extra = total % 2
    # Split the padding evenly; an odd remainder lands on the trailing side.
    split = (total // 2, total // 2 + extra)
    untouched = ((0, 0),) * (img.ndim - 2)

    if h > w:
        pad_width = ((0, 0), split) + untouched
        padding = (0, total, 0, extra)
    else:
        pad_width = (split, (0, 0)) + untouched
        padding = (total, 0, extra, 0)

    return np.pad(img, pad_width, 'constant'), padding

def preprocess(img):
    """Square-pad an image and bring it to 1024x1024.

    Args:
        img: Image array accepted by ``pad_to_square``.

    Returns:
        Tuple ``(image, (h, w, padding))``: the processed image, then the
        height and width measured after padding but before resizing, and the
        padding tuple reported by ``pad_to_square``.
    """
    squared, padding = pad_to_square(img)
    h, w = squared.shape[:2]

    already_target_size = h == 1024 and w == 1024
    if not already_target_size:
        squared = cv2.resize(squared, (1024, 1024), interpolation=cv2.INTER_CUBIC)

    return squared, (h, w, padding)

def removePreprocess(output, info):
    """Map model-space coordinates back to the original image frame.

    Args:
        output: Array indexed as ``[:, :, 0]`` and ``[:, :, 1]``; the input
            array itself is not modified.
        info: ``(h, w, (padh, padw, auxh, auxw))`` in the layout returned by
            ``preprocess``.

    Returns:
        A new array: ``output`` multiplied by ``h`` (or by 1024 when the
        padded image was already 1024x1024), with ``padw // 2`` subtracted
        from index 0 of the last axis and ``padh // 2`` from index 1.
    """
    side_h, side_w, (pad_h, pad_w, _aux_h, _aux_w) = info

    at_model_size = side_h == 1024 and side_w == 1024
    scale = 1024 if at_model_size else side_h

    # Multiplying allocates a fresh array, so the in-place shifts below
    # never touch the caller's data.
    restored = output * scale
    restored[:, :, 0] -= pad_w // 2
    restored[:, :, 1] -= pad_h // 2
    return restored

def loadModel(device):
    """Load the pretrained HybridGNetHF model and store it in the module global.

    Args:
        device: Device forwarded to ``HybridGNetHF.from_pretrained``.

    Returns:
        The loaded model, switched to eval mode. The same object is also
        assigned to the module-level ``hybrid`` variable.
    """
    global hybrid
    pretrained_args = {
        "repo_id": "mcosarinsky/CheXmask-U",
        "subfolder": "v1_skip",
        "device": device,
    }
    hybrid = HybridGNetHF.from_pretrained(**pretrained_args)
    hybrid.eval()
    return hybrid

def predict_landmarks(img, n_samples=100):
    """Estimate landmark positions and their spread from repeated latent samples.

    The image is preprocessed and encoded once; ``n_samples`` latent vectors
    are then drawn with ``hybrid.sampling`` and decoded together. The module
    global ``hybrid`` must already hold a loaded model.

    Args:
        img: Image array accepted by ``preprocess``.
        n_samples: Number of latent samples to draw and decode.

    Returns:
        Tuple ``(means, stds)``: per-landmark mean and standard deviation
        across the samples, computed after the decoded coordinates have been
        mapped back to the original frame and truncated to integers.
    """
    global hybrid
    img_proc, info = preprocess(img)

    # Add batch and channel axes, then move to wherever the model lives.
    tensor = torch.from_numpy(img_proc).unsqueeze(0).unsqueeze(0)
    model_device = next(hybrid.parameters()).device
    data = tensor.to(model_device).float()

    with torch.no_grad():
        mu, log_var, conv6, conv5 = hybrid.encode(data)

        draws = []
        for _ in range(n_samples):
            draws.append(hybrid.sampling(mu, log_var))
        z_exp = torch.stack(draws, dim=0)

        # The encoder feature maps are shared, so tile them once per sample.
        conv6_exp = conv6.repeat(n_samples, 1, 1, 1)
        conv5_exp = conv5.repeat(n_samples, 1, 1, 1)
        decoded, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)

    coords = decoded.cpu().numpy().reshape(n_samples, -1, 2)
    coords = removePreprocess(coords, info).astype('int')

    means = np.mean(coords, axis=0)
    stds = np.std(coords, axis=0)
    return means, stds


def _corrupt_image(img, mask, noise_std):
    """Return a copy of ``img`` with the optional mask and Gaussian noise applied."""
    if mask is not None:
        # The mask goes through an RGB->gray conversion and is then inverted,
        # so bright mask pixels darken the image via the element-wise minimum.
        gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        corrupted = np.minimum(img, 1.0 - gray)
    else:
        corrupted = img.copy()

    if noise_std > 0:
        noise = np.random.normal(0, noise_std, corrupted.shape)
        corrupted = np.clip(corrupted + noise, 0.0, 1.0)
    return corrupted


def _save_region_outputs(means, stds, height, width, out_dir="tmp"):
    """Write per-region landmark, mask and std files; return their paths in write order."""
    os.makedirs(out_dir, exist_ok=True)
    region_names = ("RL", "LL", "H")
    region_bounds = (slice(None, 44), slice(44, 94), slice(94, None))

    landmark_paths = []
    for name, bounds in zip(region_names, region_bounds):
        path = f"{out_dir}/{name}_landmarks.txt"
        np.savetxt(path, means[bounds], delimiter=" ", fmt="%d")
        landmark_paths.append(path)

    mask_paths = []
    for name, region_mask in zip(region_names, getMasks(means, height, width)):
        path = f"{out_dir}/{name}_mask.png"
        cv2.imwrite(path, region_mask)
        mask_paths.append(path)

    std_paths = []
    for name, bounds in zip(region_names, region_bounds):
        path = f"{out_dir}/{name}_std.txt"
        np.savetxt(path, stds[bounds], delimiter=" ", fmt="%.4f")
        std_paths.append(path)

    return landmark_paths + mask_paths + std_paths


def segment(input_img, noise_std=0.0):
    """Predict landmarks for an image and for a corrupted copy, then save the results.

    Args:
        input_img: dict with key "image" (numpy array, divided by 255 here)
            and optionally "mask" (array converted from RGB to grayscale).
        noise_std: standard deviation of Gaussian noise added to the
            corrupted copy; no noise is added when it is not positive.

    Returns:
        Tuple ``(output_path, saved_files)``: the path of the side-by-side
        comparison figure, and the list of landmark, mask and std files
        written under ``tmp/`` followed by the path of the zip archive that
        bundles them.

    The model is loaded on first use when the module global ``hybrid`` is
    still None. Both predictions are made after reseeding with the same
    value, so they draw the same sequence of random numbers.
    """
    global hybrid

    if hybrid is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)

    # Original image and corrupted version
    img_orig = input_img["image"].astype(np.float32) / 255.0
    mask = input_img.get("mask", None)
    img_corr = _corrupt_image(img_orig, mask, noise_std)

    # Predict landmarks
    seed_everything(123)
    means_orig, stds_orig = predict_landmarks(img_orig)
    seed_everything(123)
    means_corr, stds_corr = predict_landmarks(img_corr)

    # Save landmarks and masks; the file list is built once so the archive
    # contents and the returned list cannot drift apart.
    result_files = _save_region_outputs(
        means_orig, stds_orig, img_orig.shape[0], img_orig.shape[1]
    )
    zipf = zip_files(result_files)

    # Plot side-by-side comparison
    import matplotlib.pyplot as plt

    fig = plot_side_by_side_comparison(
        img_orig, means_orig, stds_orig, img_corr, means_corr, stds_corr
    )
    output_path = "tmp/segmentation_comparison.png"
    try:
        fig.savefig(output_path, dpi=300)
    finally:
        # Close the figure even if saving fails so it does not stay open.
        plt.close(fig)

    saved_files = result_files + [zipf]
    return output_path, saved_files