Spaces:
Sleeping
Sleeping
import io
import os
import pdb
import sys

import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
###################
# Setup label names
###################
# Damage / object classes predicted by the segmentation head. The position of
# each name in this list is its class index.
target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
               'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars',
               'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor'
               ]
# Dropdown choices: entry 0 ("All") is the combined mask, entry k + 1 is class k.
target_list_all = ["All"] + target_list
classes, nclasses = target_list, len(target_list)
# name -> index
label2id = {c: i for i, c in enumerate(target_list)}
# index -> name. label2id.items() yields (name, index) pairs, so they must be
# unpacked in that order to invert the mapping.
id2label = {i: c for c, i in label2id.items()}
# SegModel
# Target device for inference (CPU only).
device = torch.device("cpu")
_SEGFORMER_CHECKPOINT = "nvidia/mit-b1"
# Segformer configured with one output channel per entry of target_list and
# with the label <-> index maps stored in its config.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    _SEGFORMER_CHECKPOINT,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(target_list),
)
class SegModel(nn.Module):
    """Wrap a Segformer model and enlarge its logits 4x in each spatial dimension.

    The wrapped model's output ``.logits`` are upsampled with nearest-neighbour
    interpolation by a fixed factor of 4. NOTE(review): presumably this brings
    Segformer's reduced-resolution logits back to the input size; confirm for
    inputs other than 512x512.
    """

    def __init__(self, segformer):
        super().__init__()
        # Keep these attribute names: they are the prefixes of the state_dict
        # keys the saved weights are loaded against.
        self.segformer = segformer
        self.upsample = nn.Upsample(scale_factor=4, mode='nearest')

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its upsampled logits."""
        logits = self.segformer(x).logits
        upsampled = self.upsample(logits)
        return upsampled
# Wrap the backbone and restore the fine-tuned weights.
model = SegModel(segformer)
MODEL_WEIGHTS_PATH = "runs/2025-12-30_rich-paper-1/best_model_state_dict.pth"
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, instead of executing arbitrary pickled code.
state_dict = torch.load(MODEL_WEIGHTS_PATH,
                        map_location="cpu",
                        weights_only=True
                        )
model.load_state_dict(state_dict)
# Switch layers such as dropout to evaluation behaviour.
model.eval()
print("Model ready!")
##################
# Image preprocess
##################
# Shared torchvision transforms, built once at import time.
to_tensor = transforms.ToTensor()        # PIL image -> float tensor (C, H, W)
to_array = transforms.ToPILImage()       # tensor -> PIL image (despite the name)
resize = transforms.Resize((512, 512))   # size of the tensor fed to the model
# Size of the overlay background. NOTE(review): presumably chosen to match the
# tight-cropped matplotlib mask renders made during inference - confirm.
resize_small = transforms.Resize((369, 369))
# Standard ImageNet per-channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
def process_pil(img):
    """Turn an image into the normalized 3x512x512 tensor the model expects.

    PIL images are converted to RGB first. RGBA (e.g. PNG) or grayscale
    uploads would otherwise yield 4- or 1-channel tensors, which the
    3-channel ``normalize`` transform rejects.

    Args:
        img: PIL image (any mode) or anything ``to_tensor`` accepts.

    Returns:
        Float tensor of shape (C, 512, 512), normalized per channel.
    """
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    tensor = to_tensor(img)
    tensor = resize(tensor)
    return normalize(tensor)
# Build the background image used for the transparent overlay.
def resize_pil(img):
    """Return a resized PIL copy of ``img`` (size given by ``resize_small``).

    The image makes a round trip through a tensor so the same torchvision
    ``Resize`` implementation is used as for the model input.
    """
    return to_array(resize_small(to_tensor(img)))
# Combine a foreground (e.g. the combined mask) and a background (the original
# image) into a single image.
def transparent(fg, bg, alpha_factor):
    """Paste ``fg`` over ``bg`` at a uniform opacity and return the result.

    Args:
        fg: Foreground image (PIL image or array-like).
        bg: Background image (PIL image or array-like).
        alpha_factor: Foreground opacity, nominally in [0, 1]; the resulting
            alpha value is clamped to 0..255.

    Returns:
        A new PIL image in the background's mode. ``bg`` is not modified
        because both inputs are copied through NumPy arrays first.
    """
    background = Image.fromarray(np.asarray(bg))
    foreground = Image.fromarray(np.asarray(fg))
    alpha = min(max(int(255 * alpha_factor), 0), 255)
    foreground.putalpha(alpha)
    # The foreground doubles as the paste mask so its alpha band controls blending.
    background.paste(foreground, (0, 0), foreground)
    return background
def show_img(mask_images, label, bg, alpha):
    """Overlay the mask render selected by ``label`` onto the background image.

    Args:
        mask_images: List of mask images, indexed like ``target_list_all``
            (combined "All" mask first). None until masks have been generated.
        label: Entry of ``target_list_all``; a cleared dropdown (None or "")
            falls back to "All".
        bg: Background PIL image. None until masks have been generated.
        alpha: Overlay opacity, nominally in [0, 1]; the resulting alpha value
            is clamped to 0..255.

    Returns:
        RGBA PIL image with the selected mask blended over the background.

    Raises:
        gr.Error: If masks have not been generated yet; Gradio shows the
            message to the user instead of a stack trace.
    """
    # gr.State() starts out as None, so this catches "button 2 before button 1".
    if mask_images is None or bg is None:
        raise gr.Error('Please click "1) Generate Masks" first.')
    idx = target_list_all.index(label or "All")
    foreground = mask_images[idx].convert("RGBA")
    background = bg.convert("RGBA")
    foreground.putalpha(min(max(int(255 * alpha), 0), 255))
    # The foreground doubles as the paste mask so its alpha band controls blending.
    background.paste(foreground, (0, 0), foreground)
    return background
###########
# Inference
###########
def _figure_to_pil(fig, **savefig_kwargs):
    """Render a matplotlib Figure to an in-memory PNG and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", **savefig_kwargs)
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now instead of lazily on first access
    return image


def _render_mask_grid(masks, titles, ncols=4):
    """Draw each mask in its own titled subplot and return the whole grid as one image."""
    nrows = -(-len(masks) // ncols)  # ceiling division
    # A standalone Figure (not pyplot) is neither registered in pyplot's global
    # figure list nor shared between threads, so nothing leaks per request.
    fig = Figure(figsize=(10, 10))
    axes = fig.subplots(nrows, ncols, squeeze=False)
    for ax, title, mask in zip(axes.flat, titles, masks):
        ax.imshow(mask)
        ax.set_title(title)
    for ax in axes.flat[len(masks):]:
        ax.set_axis_off()  # hide grid cells left over when masks don't fill the grid
    fig.tight_layout()
    return _figure_to_pil(fig)


def _render_mask(mask):
    """Render one mask without axes or padding around it and return it as an image."""
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(mask)
    ax.set_axis_off()
    return _figure_to_pil(fig, bbox_inches="tight", pad_inches=0)


def inference(img):
    """Segment ``img`` and render the predicted masks.

    Args:
        img: Uploaded PIL image (any mode; converted to RGB).

    Returns:
        Tuple ``(grid, mask_images, background)``:
        - ``grid``: one image showing every mask in a titled subplot grid;
        - ``mask_images``: one axis-free render per mask, ordered like
          ``target_list_all`` (combined mask first, then one per class);
        - ``background``: the input resized by ``resize_pil`` for overlays.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # RGBA / grayscale uploads would otherwise give the model a 4- or 1-channel tensor.
    img = img.convert("RGB")
    background = resize_pil(img)

    # No autograd graph is needed for a forward pass.
    with torch.no_grad():
        logits = model(process_pil(img).unsqueeze(0))[0]

    # Each class is scored independently with a sigmoid, so a pixel may be
    # assigned to several classes at once.
    mask_probs = torch.sigmoid(logits).numpy()
    threshold = 0.5
    mask_preds = mask_probs > threshold

    # Prepend the per-pixel count of predicted classes as the combined "All" mask.
    mask_all = mask_preds.sum(axis=0, keepdims=True)
    mask_preds = np.concatenate((mask_all, mask_preds), axis=0)

    grid = _render_mask_grid(mask_preds, target_list_all)
    all_images = [_render_mask(mask) for mask in mask_preds]
    return grid, all_images, background
# UI texts (German). Title: "Master's thesis - structural damage detection".
title = "Masterarbeit - Bauschadenerkennung"
# Step-by-step usage instructions shown at the top of the app.
description = """
KI-basierte Segmentierung von Bauschäden
Arbeitsschritte:
1. Laden Sie ein Bild hoch.
2. Klicken Sie auf den Button "1) Generate Masks"
3. Wählen Sie einen Schaden oder Objekttyp in "Select Label" und wählen Sie einen Alpha Factor
4. Klicken Sie auf "2) Generate Transparent Mask (with Alpha Factor)"
"""
# Example images. gr.Examples is wired to a single input component, so every
# row must hold exactly one value (the image path).
examples = [
    ["Assets/freiliegende Bewehrung 2.jpg"],
    ["Assets/freiliegende Bewehrung.jpg"],
    ["Assets/Graffiti.jpg"],
    ["Assets/Kiesnest.jpg"],
    ["Assets/Risse, Abplatzungen.jpg"],
    ["Assets/dacl10k_v2_validation_0263.jpg"],
    ["Assets/Risse, Verfärbungen.jpg"],
    ["Assets/Risse.jpg"],
    ["Assets/Rost.jpg"],
    ["Assets/dacl10k_v2_validation_0609.jpg"],
    ["Assets/dacl10k_v2_validation_0708.jpg"],
]
# Gradio UI. "1) Generate Masks" runs `inference` on the uploaded image. The
# mask renders and the resized background it returns are kept in gr.State
# components. "2) Generate Transparent Mask" passes those states, the selected
# label and the alpha value to `show_img`.
with gr.Blocks(title=title) as app:
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        input_img = gr.Image(type="pil", label="Original Image")
        gr.Examples(examples=examples, inputs=[input_img])
    with gr.Row():
        # Grid of all masks (first output of `inference`).
        img = gr.Image(type="pil", label="All Masks")
        # Overlay produced by `show_img`.
        transparent_img = gr.Image(type="pil", label="Transparent Image")
    with gr.Row():
        # Choices match the mask order returned by `inference` ("All" first).
        dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
        slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")
    # Hold the intermediate results of step 1 for step 2; both start as None.
    mask_state = gr.State()
    background_state = gr.State()
    gr.Button("1) Generate Masks").click(fn=inference,
                                         inputs=[input_img],
                                         outputs=[img, mask_state, background_state])
    submit_transparent_img = gr.Button("2) Generate Transparent Mask (with Alpha Factor)")
    submit_transparent_img.click(fn=show_img, inputs=[mask_state, dropdown, background_state, slider], outputs=[transparent_img])
app.launch()