File size: 6,328 Bytes
fa32eee
b8fa1f9
70ab9ba
fa32eee
b8fa1f9
fa32eee
70ab9ba
 
5cf9982
6b15e42
 
5cf9982
70ab9ba
fa32eee
 
efd13e5
 
 
5cf9982
9effdfb
0ea2406
dd2b8ab
5c8a1fa
fa32eee
 
 
6b15e42
5cf9982
 
83926b9
5cf9982
 
1fa2488
5cf9982
fa32eee
c698f72
43e729e
c4d1f2d
5cf9982
fa32eee
 
738bc7d
fa32eee
1fa2488
7616fd2
06da229
022b93a
be88075
fa32eee
 
3dc02c6
 
fa32eee
 
 
 
d4040be
0827ffe
d4040be
 
0779132
7891692
 
 
6b15e42
b0b2f83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18cdc5f
e8e6d2c
536c5aa
b0b2f83
18cdc5f
4b254cd
 
 
f44bae5
b0b2f83
f44bae5
 
b78510d
fa32eee
 
 
3981489
cee233d
717a321
c6996d3
717a321
 
fa32eee
cee233d
 
717a321
 
 
 
 
 
 
 
 
 
 
 
c5aa05d
cee233d
717a321
 
cee233d
 
 
 
 
 
 
717a321
 
 
 
 
 
 
 
cee233d
 
717a321
 
 
cee233d
 
 
717a321
 
 
 
cee233d
717a321
 
 
 
 
cee233d
717a321
c5aa05d
6b15e42
5cf9982
 
b0b2f83
 
 
 
 
 
 
0557153
b78510d
 
5f9b7fc
 
 
 
e9f9ced
5f9b7fc
e9f9ced
5f9b7fc
5848797
5f9b7fc
 
b78510d
fa32eee
b0b2f83
 
 
 
6b15e42
b85a5ac
0557153
5f8dad5
4180a5d
 
0557153
 
 
6b15e42
18cdc5f
 
0557153
d048a3a
c6996d3
18cdc5f
b0b2f83
fe53c06
18cdc5f
f54cc1e
6b15e42
6e20ae6
 
08fe971
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from torch import nn
import os
import io
import sys
import pdb
from matplotlib import pyplot as plt

###################
# Setup label names
# Classes the segmentation head predicts; the position in this list is the class index.
target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
               'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars', 
               'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor'  
]
# Choices offered in the UI: a combined "All" entry followed by every single class.
target_list_all = ["All"] + target_list
classes, nclasses = target_list, len(target_list)
# Class name -> class index.
label2id = {c: i for i, c in enumerate(target_list)}
# Class index -> class name. label2id.items() yields (name, index) pairs, so the
# pair has to be unpacked in that order to really invert the mapping.
id2label = {i: c for c, i in label2id.items()}


# SegModel
# NOTE: `device` is not referenced anywhere else in the visible code.
device = torch.device("cpu")

# Label metadata passed through to the model configuration: one output channel
# per entry of target_list plus the two name/index lookup tables.
_label_config = dict(
    num_labels=len(target_list),
    id2label=id2label,
    label2id=label2id,
)
# Instantiate the SegFormer network from the "nvidia/mit-b1" checkpoint
# (presumably resolved via the Hugging Face Hub or its local cache — confirm).
segformer = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b1", **_label_config)

class SegModel(nn.Module):
    """Wrap a segmentation network and return its logits upsampled by a factor of 4.

    The wrapped module is called with the input batch and must return an object
    exposing a ``logits`` attribute. Those logits are enlarged 4x along the
    spatial axes with nearest-neighbour interpolation.
    NOTE(review): the factor 4 presumably restores the input resolution — confirm.
    """

    def __init__(self,segformer):
        super().__init__()
        # Attribute names are kept as-is: they are part of the state-dict keys.
        self.segformer = segformer
        self.upsample = nn.Upsample(scale_factor=4, mode='nearest')

    def forward(self, x):
        """Run the wrapped network on ``x`` and return the upsampled logits."""
        logits = self.segformer(x).logits
        return self.upsample(logits)
 
# Read the trained weights from disk (mapped onto the CPU), then wrap the
# backbone, load the weights into the wrapper and switch it to evaluation mode.
_CHECKPOINT_PATH = "runs/2025-12-30_rich-paper-1/best_model_state_dict.pth"
state_dict = torch.load(_CHECKPOINT_PATH, map_location="cpu")

model = SegModel(segformer)
model.load_state_dict(state_dict)
model.eval()

print("Model ready!")

##################
# Image preprocess
##################

# Per-channel statistics used for input normalisation
# (these are the values commonly quoted as the ImageNet mean / std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> tensor, and tensor -> PIL image (despite its name, `to_array`
# produces a PIL image, not an array).
to_tensor = transforms.ToTensor()
to_array = transforms.ToPILImage()

# 512x512 is the size fed to the model; 369x369 is the size of the background
# picture used for the overlay.
resize = transforms.Resize((512, 512))
resize_small = transforms.Resize((369, 369))

normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

def process_pil(img):
    """Convert a PIL image into the model input tensor.

    The image is turned into a tensor, resized to 512x512 and normalised
    per channel. No batch dimension is added here.
    """
    return normalize(resize(to_tensor(img)))

# the background of the image
def resize_pil(img):
    """Return ``img`` as a PIL image resized to 369x369.

    The picture makes a round trip through a tensor so the module-level
    ``resize_small`` transform can be applied; no normalisation happens.
    """
    as_tensor = to_tensor(img)
    return to_array(resize_small(as_tensor))

# combine the foreground (mask_all) and background (original image) to create one image
def transparent(fg, bg, alpha_factor):
    """Overlay ``fg`` on ``bg`` with a uniform opacity and return the result.

    Args:
        fg: Foreground (mask) picture, given as an array or a PIL image.
        bg: Background picture, given as an array or a PIL image.
        alpha_factor: Foreground opacity; it is scaled by 255 and truncated to
            an int, so 0 means invisible and 1 fully opaque.

    Returns:
        A PIL image built from ``bg`` with ``fg`` pasted at the top-left corner.
    """
    # Convert once and build the PIL images from the converted arrays, so the
    # conversion result is what actually gets used.
    background = Image.fromarray(np.asarray(bg))
    foreground = Image.fromarray(np.asarray(fg))

    # Give every foreground pixel the same alpha value; the foreground is then
    # used as its own mask so paste() blends instead of overwriting.
    foreground.putalpha(int(255 * alpha_factor))
    background.paste(foreground, (0, 0), foreground)

    return background

def show_img(mask_images, label, bg, alpha):
    """Blend the rendered mask belonging to ``label`` over the background image.

    Args:
        mask_images: Sequence of PIL mask images ordered like
            ``target_list_all`` (combined mask first, then one per class).
        label: Entry of ``target_list_all`` selecting which mask to show.
        bg: PIL background image.
        alpha: Mask opacity; scaled by 255 and truncated to an int.

    Returns:
        An RGBA PIL image: the background with the selected mask on top.

    Raises:
        gr.Error: If no masks / background are available yet, i.e. the mask
            generation step has not been run.
        ValueError: If ``label`` is not an entry of ``target_list_all``.
    """
    # Both values arrive as None when this handler runs before the masks exist.
    if mask_images is None or bg is None:
        raise gr.Error('No masks available yet - please run "1) Generate Masks" first.')

    idx = target_list_all.index(label)

    foreground = mask_images[idx].convert("RGBA")
    background = bg.convert("RGBA")

    # The masks are rendered through matplotlib, so their pixel size is not
    # guaranteed to equal the background's. Align them so the overlay covers
    # the whole picture; NEAREST keeps the mask colours unblended.
    if foreground.size != background.size:
        foreground = foreground.resize(background.size, Image.NEAREST)

    # Uniform alpha on the mask, then paste it using itself as the blend mask.
    foreground.putalpha(int(255 * alpha))
    background.paste(foreground, (0, 0), foreground)

    return background

###########
# Inference

def _figure_to_pil(fig, **savefig_kwargs):
    """Save ``fig`` as an in-memory PNG, close the figure and return the PNG as a PIL image."""
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png', **savefig_kwargs)
    # pyplot keeps a reference to every figure it created until it is closed;
    # without this each request would leave its figures in memory.
    plt.close(fig)
    img_buf.seek(0)
    return Image.open(img_buf)


def inference(img):
    """Segment ``img`` and render the predicted masks.

    Args:
        img: PIL image to analyse.

    Returns:
        A tuple ``(im, all_images, background)``:
        ``im`` is a PIL image of a 5x4 grid showing the combined mask and
        every class mask with its label as title;
        ``all_images`` is a list of PIL images, one per mask, rendered without
        axes or padding - combined mask first, then the classes in
        ``target_list`` order;
        ``background`` is the input resized by ``resize_pil``.

    Raises:
        gr.Error: If ``img`` is None (nothing was uploaded).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    background = resize_pil(img)
    batch = process_pil(img).unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)[0]
        # Get probability values (logits to probs), one independent map per class
        mask_probs = torch.sigmoid(logits).numpy()

    # Make binary mask
    THRESHOLD = 0.5
    mask_preds = mask_probs > THRESHOLD

    # All combined: per pixel, the number of classes predicted there
    mask_all = mask_preds.sum(axis=0, keepdims=True)

    # Concat all combined with normal preds
    mask_preds = np.concatenate((mask_all, mask_preds), axis=0)
    labs = ["ALL"] + target_list

    # Overview figure: one subplot per mask.
    fig, axes = plt.subplots(5, 4, figsize=(10, 10))
    for ax, label, mask in zip(axes.flat, labs, mask_preds):
        ax.imshow(mask)
        ax.set_title(label)
    fig.tight_layout()
    im = _figure_to_pil(fig)

    # Render every mask on its own with invisible x and y axes and without a
    # white border, so it can be laid over the background picture later.
    all_images = []
    for mask in mask_preds:
        mask_fig, mask_ax = plt.subplots()
        mask_ax.imshow(mask)
        mask_ax.axis('off')
        mask_ax.get_xaxis().set_visible(False)
        mask_ax.get_yaxis().set_visible(False)
        all_images.append(
            _figure_to_pil(mask_fig, bbox_inches='tight', pad_inches=0)
        )

    return im, all_images, background

    

# Window title of the app.
title = "Masterarbeit - Bauschadenerkennung"
# Markdown shown at the top of the page (German usage instructions).
description = """
KI-basierte Segmentierung von Bauschäden
Arbeitsschritte:
1. Laden Sie ein Bild hoch.
2. Klicken Sie auf den Button "1) Generate Masks"
3. Wählen Sie einen Schaden oder Objekttyp in "Select Label" und wählen Sie einen Alpha Factor
4. Klicken Sie auf "2) Generate Transparent Mask (with Alpha Factor)"
"""

# Example pictures offered in the UI. Each row holds exactly one value because
# the examples are bound to a single input component (the uploaded image).
examples = [
    ["Assets/freiliegende Bewehrung 2.jpg"],
    ["Assets/freiliegende Bewehrung.jpg"],
    ["Assets/Graffiti.jpg"],
    ["Assets/Kiesnest.jpg"],
    ["Assets/Risse, Abplatzungen.jpg"],
    ["Assets/dacl10k_v2_validation_0263.jpg"],
    ["Assets/Risse, Verfärbungen.jpg"],
    ["Assets/Risse.jpg"],
    ["Assets/Rost.jpg"],
    ["Assets/dacl10k_v2_validation_0609.jpg"],
    ["Assets/dacl10k_v2_validation_0708.jpg"],
]


# Build the page: instructions, image upload with examples, the two result
# images, the label / opacity controls and the two action buttons.
with gr.Blocks(title=title) as app:
    with gr.Row():
        gr.Markdown(description)

    with gr.Row():
        input_img = gr.Image(type="pil", label="Original Image")
        gr.Examples(examples=examples, inputs=[input_img])

    with gr.Row():
        img = gr.Image(type="pil", label="All Masks")
        transparent_img = gr.Image(type="pil", label="Transparent Image")

    with gr.Row():
        dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
        slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")

    # State slots that carry the rendered masks and the resized background
    # from step 1 (inference) to step 2 (overlay).
    mask_state = gr.State()
    background_state = gr.State()

    generate_masks_btn = gr.Button("1) Generate Masks")
    submit_transparent_img = gr.Button("2) Generate Transparent Mask (with Alpha Factor)")

    # Step 1: run the model, show the overview grid and remember masks + background.
    generate_masks_btn.click(
        fn=inference,
        inputs=[input_img],
        outputs=[img, mask_state, background_state],
    )
    # Step 2: blend the selected mask over the remembered background.
    submit_transparent_img.click(
        fn=show_img,
        inputs=[mask_state, dropdown, background_state, slider],
        outputs=[transparent_img],
    )


app.launch()