# Masterarbeit / app.py
# Last change: commit 18cdc5f ("Update app.py", verified) by Alic22.
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from torch import nn
import os
import io
import sys
import pdb
from matplotlib import pyplot as plt
###################
# Setup label names
# Damage/object classes predicted by the model; the list order defines the
# channel index of each class in the model output.
target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
               'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars',
               'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor'
               ]
# Dropdown choices: the combined "All" view followed by every single class.
target_list_all = ["All"] + target_list
classes, nclasses = target_list, len(target_list)
# name -> channel index
label2id = {c: i for i, c in enumerate(target_list)}
# channel index -> name (the previous comprehension unpacked the (name, index)
# pairs into the wrong variables and produced name -> index again)
id2label = dict(enumerate(target_list))
# SegModel backbone. Inference in this app runs on the CPU.
device = torch.device("cpu")
# SegFormer (MiT-b1) configured for our label set; its weights are replaced by
# the trained state dict that is loaded further down via load_state_dict.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1", num_labels=len(target_list), id2label=id2label, label2id=label2id
)
class SegModel(nn.Module):
    """Wrap a SegFormer model and upsample its logits 4x.

    ``forward`` calls the wrapped ``segformer`` and enlarges the ``.logits`` of
    its output by a factor of 4 in both spatial dimensions using
    nearest-neighbour interpolation (presumably to bring SegFormer's reduced
    output resolution back to the input size — confirm).
    """

    def __init__(self, segformer):
        super().__init__()
        # Keep these attribute names: they determine the state-dict keys that
        # load_state_dict matches against the saved checkpoint.
        self.segformer = segformer
        self.upsample = nn.Upsample(scale_factor=4, mode="nearest")

    def forward(self, x):
        logits = self.segformer(x).logits
        return self.upsample(logits)
# Build the full network and load the trained weights of the selected run.
model = SegModel(segformer)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
# map_location keeps every tensor on the inference device defined above.
state_dict = torch.load(
    "runs/2025-12-30_rich-paper-1/best_model_state_dict.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
# Evaluation mode disables training-only behaviour such as dropout.
model.eval()
print("Model ready!")
##################
# Image preprocess
##################
# PIL image -> float tensor (C, H, W).
to_tensor = transforms.ToTensor()
# NOTE: despite its name, this converts a tensor back into a PIL image.
to_array = transforms.ToPILImage()
# Input resolution fed to the model by process_pil.
resize = transforms.Resize(size=(512, 512))
# Size of the overlay background produced by resize_pil.
# NOTE(review): 369 px presumably matches the pixel size of the per-mask PNGs
# that inference renders with matplotlib (tight bbox, default figure) — confirm.
resize_small = transforms.Resize(size=(369, 369))
# ImageNet per-channel mean/std.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def process_pil(img):
    """Turn an input image into a normalized model input tensor.

    The image is converted with ``to_tensor``, resized with ``resize`` and
    normalized with ``normalize``. PIL images that are not RGB (e.g. RGBA
    PNGs or grayscale images) are converted to RGB first, because
    ``normalize`` carries exactly three channel statistics and would fail on
    any other channel count.

    Args:
        img: PIL image (or any other input accepted by ``to_tensor``).

    Returns:
        A normalized float tensor of shape (C, H, W), sized by ``resize``.
    """
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    tensor = to_tensor(img)
    tensor = resize(tensor)
    return normalize(tensor)
# Background for the overlay view: the input image at the overlay size.
def resize_pil(img):
    """Return *img* resized by ``resize_small`` as a PIL image.

    The image round-trips PIL -> tensor (``to_tensor``) -> resized tensor
    (``resize_small``) -> PIL (``to_array``).
    """
    return to_array(resize_small(to_tensor(img)))
# Combine a foreground (e.g. a mask rendering) and a background (e.g. the
# original image) into one image.
def transparent(fg, bg, alpha_factor):
    """Paste *fg* over *bg* with a uniform opacity and return the result.

    Args:
        fg: foreground image, as a PIL image or an array accepted by
            ``Image.fromarray``.
        bg: background image, as a PIL image or such an array.
        alpha_factor: overlay opacity; ``255 * alpha_factor`` is truncated to
            an int and clamped to 0..255.

    Returns:
        A new PIL image; ``fg`` and ``bg`` are not modified.
    """
    # np.asarray first so both PIL images and NumPy arrays are accepted
    # (the previous code built the arrays but then ignored them).
    foreground = Image.fromarray(np.asarray(fg))
    background = Image.fromarray(np.asarray(bg))
    alpha = min(max(int(255 * alpha_factor), 0), 255)
    foreground.putalpha(alpha)
    # The foreground doubles as the paste mask, so its alpha sets the opacity.
    background.paste(foreground, (0, 0), foreground)
    return background
def show_img(mask_images, label, bg, alpha):
    """Overlay the rendered mask for *label* on the background image.

    Args:
        mask_images: list of mask images as returned by ``inference``
            (combined "All" mask first, then one per class in
            ``target_list`` order); ``None`` until step 1 has run.
        label: entry of ``target_list_all`` selected in the dropdown.
        bg: background PIL image returned by ``inference``; ``None`` until
            step 1 has run.
        alpha: overlay opacity in [0, 1]; ``255 * alpha`` is truncated to an
            int and clamped to 0..255.

    Returns:
        A new RGBA PIL image; ``bg`` and the cached masks are not modified.

    Raises:
        gr.Error: if no masks have been generated yet.
    """
    if mask_images is None or bg is None:
        # The states are only filled by "1) Generate Masks"; without this
        # guard indexing None raised an opaque TypeError.
        raise gr.Error('Please run "1) Generate Masks" first.')
    idx = target_list_all.index(label)
    # convert() returns new images, so the cached state stays untouched.
    foreground = mask_images[idx].convert("RGBA")
    background = bg.convert("RGBA")
    foreground.putalpha(min(max(int(255 * alpha), 0), 255))
    background.paste(foreground, (0, 0), foreground)
    return background
###########
# Inference
def _predict_masks(img, threshold=0.5):
    """Run the model on *img* and return the stacked prediction masks.

    The result has one leading channel holding the per-pixel number of active
    classes ("All"), followed by one boolean-valued channel per class.
    """
    batch = process_pil(img).unsqueeze(0)
    # No autograd graph is needed for inference.
    with torch.no_grad():
        logits = model(batch)[0]
    # Independent sigmoid per class (multi-label), then a fixed threshold.
    mask_preds = torch.sigmoid(logits).numpy() > threshold
    mask_all = mask_preds.sum(axis=0, keepdims=True)
    return np.concatenate((mask_all, mask_preds), axis=0)


def _figure_to_image(fig, **savefig_kwargs):
    """Render *fig* to an in-memory PNG, close the figure, return a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", **savefig_kwargs)
    # Close the figure so repeated requests do not accumulate open figures.
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _render_overview(mask_preds, labels):
    """Plot every mask in a 5x4 grid, titled with its label."""
    fig, axes = plt.subplots(5, 4, figsize=(10, 10))
    for ax, label, mask in zip(axes.flat, labels, mask_preds):
        ax.imshow(mask)
        ax.set_title(label)
    fig.tight_layout()
    return _figure_to_image(fig)


def _render_mask(mask):
    """Render a single mask without axes, ticks or padding around it."""
    fig, ax = plt.subplots()
    ax.imshow(mask)
    ax.axis("off")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return _figure_to_image(fig, bbox_inches="tight", pad_inches=0)


def inference(img):
    """Segment *img* and prepare everything the UI needs.

    Args:
        img: input PIL image from the "Original Image" component.

    Returns:
        Tuple ``(overview, mask_images, background)``: a PNG grid of all
        masks, a list with one axis-free rendering per mask (combined "All"
        mask first, then one per class), and the resized input image used as
        the overlay background.
    """
    background = resize_pil(img)
    mask_preds = _predict_masks(img)
    overview = _render_overview(mask_preds, ["ALL"] + target_list)
    all_images = [_render_mask(mask) for mask in mask_preds]
    return overview, all_images, background
# UI texts (German): page title ("Master's thesis - building damage detection")
# and the step-by-step instructions shown above the app.
title = "Masterarbeit - Bauschadenerkennung"
description = """
KI-basierte Segmentierung von Bauschäden
Arbeitsschritte:
1. Laden Sie ein Bild hoch.
2. Klicken Sie auf den Button "1) Generate Masks"
3. Wählen Sie einen Schaden oder Objekttyp in "Select Label" und wählen Sie einen Alpha Factor
4. Klicken Sie auf "2) Generate Transparent Mask (with Alpha Factor)"
"""
# Example images offered below the input. gr.Examples feeds a single input
# component (the original image), so every entry holds exactly one value.
examples = [
    ["Assets/freiliegende Bewehrung 2.jpg"],
    ["Assets/freiliegende Bewehrung.jpg"],
    ["Assets/Graffiti.jpg"],
    ["Assets/Kiesnest.jpg"],
    ["Assets/Risse, Abplatzungen.jpg"],
    ["Assets/dacl10k_v2_validation_0263.jpg"],
    ["Assets/Risse, Verfärbungen.jpg"],
    ["Assets/Risse.jpg"],
    ["Assets/Rost.jpg"],
    ["Assets/dacl10k_v2_validation_0609.jpg"],
    ["Assets/dacl10k_v2_validation_0708.jpg"],
]
# Gradio UI: instructions, input image with examples, two output images, the
# label/alpha controls and the two buttons of the two-step workflow.
# NOTE(review): this file's indentation was lost; the nesting of gr.Examples,
# the gr.State objects and the buttons below is reconstructed and should be
# confirmed against the intended layout.
with gr.Blocks(title=title) as app:
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        input_img = gr.Image(type="pil", label="Original Image")
        gr.Examples(examples=examples, inputs=[input_img])
    with gr.Row():
        # Outputs: grid of all masks, and the selected mask over the image.
        img = gr.Image(type="pil", label="All Masks")
        transparent_img = gr.Image(type="pil", label="Transparent Image")
    with gr.Row():
        dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
        slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")
    # Hidden state: step 1 stores the rendered masks and the background here,
    # step 2 reads them back.
    mask_state = gr.State()
    background_state = gr.State()
    # Step 1: run the model; outputs map to inference's return tuple
    # (overview image, list of mask images, background).
    gr.Button("1) Generate Masks").click(fn=inference,
                                          inputs=[input_img],
                                          outputs=[img, mask_state, background_state])
    # Step 2: overlay the selected mask on the background with the chosen alpha.
    submit_transparent_img = gr.Button("2) Generate Transparent Mask (with Alpha Factor)")
    submit_transparent_img.click(fn=show_img, inputs=[mask_state, dropdown, background_state, slider], outputs=[transparent_img])
app.launch()