"""Gradio demo: SegFormer-based semantic segmentation of structural damage.

The script loads a fine-tuned SegFormer checkpoint, exposes an ``inference``
callback that renders one mask image per label (plus a combined one), and a
``show_img`` callback that overlays a selected mask on the resized input.
"""

import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from torch import nn
import os
import io
import sys
import pdb
from matplotlib import pyplot as plt

###################
# Setup label names
target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
               'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars',
               'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor'
               ]
# Dropdown choices: the combined mask first, then one entry per label.
target_list_all = ["All"] + target_list
classes, nclasses = target_list, len(target_list)
label2id = {c: i for i, c in enumerate(target_list)}
# Invert label2id. Its items are (name, index) pairs, so the unpacking order
# must be `c, i` for the result to map index -> name.
id2label = {i: c for c, i in label2id.items()}

# SegModel
device = torch.device("cpu")
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    num_labels=len(target_list),
    id2label=id2label,
    label2id=label2id
)


class SegModel(nn.Module):
    """Wrap a SegFormer model and upsample its logits by a factor of 4."""

    def __init__(self, segformer):
        super().__init__()
        self.segformer = segformer
        # Nearest-neighbour upsampling of the logits; presumably restores the
        # input resolution -- confirm against the SegFormer output stride.
        self.upsample = nn.Upsample(scale_factor=4, mode='nearest')

    def forward(self, x):
        """Return the upsampled logits of the wrapped model for batch ``x``."""
        return self.upsample(self.segformer(x).logits)


model = SegModel(segformer)
state_dict = torch.load("runs/2025-12-30_rich-paper-1/best_model_state_dict.pth",
                        map_location="cpu"
                        )
model.load_state_dict(state_dict)
model.eval()
print("Model ready!")

##################
# Image preprocess
##################

to_tensor = transforms.ToTensor()
to_array = transforms.ToPILImage()
resize = transforms.Resize((512, 512))
# The overlay background uses 369x369, which appears to be chosen to match the
# pixel size of the Matplotlib mask renderings -- TODO confirm.
resize_small = transforms.Resize((369, 369))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Probability above which a pixel counts as belonging to a label.
MASK_THRESHOLD = 0.5


def process_pil(img):
    """Convert a PIL image into a normalized 512x512 tensor for the model."""
    img = to_tensor(img)
    img = resize(img)
    img = normalize(img)
    return img


# the background of the image
def resize_pil(img):
    """Return ``img`` resized to 369x369 as a PIL image (overlay background)."""
    img = to_tensor(img)
    img = resize_small(img)
    img = to_array(img)
    return img


# combine the foreground (mask_all) and background (original image) to create one image
def transparent(fg, bg, alpha_factor):
    """Paste ``fg`` over ``bg`` with a uniform opacity and return the result.

    Args:
        fg: Foreground image (PIL image or array-like).
        bg: Background image (PIL image or array-like).
        alpha_factor: Opacity of the foreground, scaled by 255 for the alpha
            channel.

    Returns:
        A new PIL image; the inputs are copied through NumPy and not modified.
    """
    # Round-trip through NumPy so PIL images and arrays are both accepted and
    # the caller's objects are left untouched.
    foreground = Image.fromarray(np.array(fg))
    background = Image.fromarray(np.array(bg))
    new_alpha_factor = int(255 * alpha_factor)
    foreground.putalpha(new_alpha_factor)
    background.paste(foreground, (0, 0), foreground)
    return background


def show_img(mask_images, label, bg, alpha):
    """Overlay the mask rendering selected by ``label`` on the background.

    Args:
        mask_images: List of PIL mask renderings, ordered like
            ``target_list_all`` (as produced by ``inference``).
        label: One of ``target_list_all``.
        bg: Background PIL image (as produced by ``inference``).
        alpha: Opacity of the mask, scaled by 255 for the alpha channel.

    Returns:
        An RGBA PIL image with the mask pasted over the background.

    Raises:
        gr.Error: If no masks have been generated yet.
        ValueError: If ``label`` is not in ``target_list_all``.
    """
    if mask_images is None or bg is None:
        raise gr.Error('Please run "1) Generate Masks" first.')

    idx = target_list_all.index(label)
    foreground = mask_images[idx].convert("RGBA")
    background = bg.convert("RGBA")
    # Keep the overlay aligned even if the rendered mask and the background
    # ended up with different pixel sizes.
    if foreground.size != background.size:
        foreground = foreground.resize(background.size, Image.NEAREST)
    foreground.putalpha(int(255 * alpha))
    background.paste(foreground, (0, 0), foreground)
    return background


def _figure_to_pil(fig, **savefig_kwargs):
    """Render ``fig`` to PNG, close it, and return the result as a PIL image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', **savefig_kwargs)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    image.load()
    return image


###########
# Inference

def inference(img):
    """Segment ``img`` and render the predicted masks.

    Args:
        img: Input PIL image.

    Returns:
        A tuple ``(overview, mask_images, background)``:
        ``overview`` is a PIL image of a 5x4 grid showing every mask,
        ``mask_images`` is a list of PIL renderings (combined mask first, then
        one per label in ``target_list`` order), and ``background`` is the
        input resized by ``resize_pil``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    background = resize_pil(img)
    img = process_pil(img)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        mask = model(img.unsqueeze(0))[0]
        # Get probability values (logits to probs)
        mask_probs = torch.sigmoid(mask).numpy()

    # Make binary mask
    mask_preds = mask_probs > MASK_THRESHOLD

    # All combined: per-pixel count of labels above the threshold.
    mask_all = mask_preds.sum(axis=0, keepdims=True)

    # Concat all combined with normal preds
    mask_preds = np.concatenate((mask_all, mask_preds), axis=0)
    labs = ["ALL"] + target_list
    all_masks = list(mask_preds)

    # Overview grid with one titled subplot per mask.
    fig, axes = plt.subplots(5, 4, figsize=(10, 10))
    for ax, label, mask_pred in zip(axes.flat, labs, all_masks):
        ax.imshow(mask_pred)
        ax.set_title(label)
    fig.tight_layout()
    im = _figure_to_pil(fig)

    # Render every mask on its own, with hidden x and y axes and without a
    # white border.
    all_images = []
    for mask_pred in all_masks:
        mask_fig = plt.figure()
        ax = mask_fig.gca()
        ax.imshow(mask_pred)
        ax.axis('off')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        all_images.append(
            _figure_to_pil(mask_fig, bbox_inches='tight', pad_inches=0)
        )

    return im, all_images, background


title = "Masterarbeit - Bauschadenerkennung"
description = """
KI-basierte Segmentierung von Bauschäden
Arbeitsschritte:
1. Laden Sie ein Bild hoch.
2. Klicken Sie auf den Button "1) Generate Masks"
3. Wählen Sie einen Schaden oder Objekttyp in "Select Label" und wählen Sie einen Alpha Factor
4. Klicken Sie auf "2) Generate Transparent Mask (with Alpha Factor)"
"""

# NOTE(review): the "Assets/Rost.jpg" entry carries a second element although
# only one input component is wired to gr.Examples -- confirm it is intended.
examples = [
    ["Assets/freiliegende Bewehrung 2.jpg"],
    ["Assets/freiliegende Bewehrung.jpg"],
    ["Assets/Graffiti.jpg"],
    ["Assets/Kiesnest.jpg"],
    ["Assets/Risse, Abplatzungen.jpg"],
    ["Assets/dacl10k_v2_validation_0263.jpg"],
    ["Assets/Risse, Verfärbungen.jpg"],
    ["Assets/Risse.jpg"],
    ["Assets/Rost.jpg", "Rost.jpg"],
    ["Assets/dacl10k_v2_validation_0609.jpg"],
    ["Assets/dacl10k_v2_validation_0708.jpg"]
]

# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-collapsed source -- confirm which components belong to which row.
with gr.Blocks(title=title) as app:
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        input_img = gr.Image(type="pil", label="Original Image")
        gr.Examples(examples=examples, inputs=[input_img])
    with gr.Row():
        img = gr.Image(type="pil", label="All Masks")
        transparent_img = gr.Image(type="pil", label="Transparent Image")
    with gr.Row():
        dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
        slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")

    # Per-session storage handed from step 1 (inference) to step 2 (overlay).
    mask_state = gr.State()
    background_state = gr.State()

    gr.Button("1) Generate Masks").click(fn=inference,
                                         inputs=[input_img],
                                         outputs=[img, mask_state, background_state])

    submit_transparent_img = gr.Button("2) Generate Transparent Mask (with Alpha Factor)")
    submit_transparent_img.click(fn=show_img,
                                 inputs=[mask_state, dropdown, background_state, slider],
                                 outputs=[transparent_img])

app.launch()