# crack_demo / app.py
# Hugging Face Space by toth235a — "Update app.py" (commit a00629c, verified)
import gradio as gr
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import io
# Load your pre-trained model and processor
# Binary semantic-segmentation label map: class 0 = background, class 1 = crack.
id2label = {0: 'background', 1: 'crack'}
# Fine-tuned Mask2Former checkpoint from the Hub. ignore_mismatched_sizes=True
# lets the 2-class head load over the checkpoint's original classifier shape
# without raising on shape mismatch.
model = Mask2FormerForUniversalSegmentation.from_pretrained(
"toth235a/mask2former-swin-large-crack-semantic",
id2label=id2label,
ignore_mismatched_sizes=True
)
# Processor is built from scratch (not from_pretrained): resize + rescale +
# normalize inputs; 255 is treated as the ignore index, labels are not reduced.
processor = Mask2FormerImageProcessor(
ignore_index=255, reduce_labels=False, do_resize=True,
do_rescale=True, do_normalize=True
)
# Inference-only: disable dropout/batch-norm updates and move to GPU if available.
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def predict(image):
    """Run crack segmentation on ``image`` and return a rendered overlay.

    Parameters
    ----------
    image : PIL.Image.Image | np.ndarray
        Input image as delivered by Gradio (``inputs="image"`` may pass a
        numpy array, not a PIL image).

    Returns
    -------
    np.ndarray
        The matplotlib figure (input image with the predicted crack mask
        overlaid) rasterized to a PNG and decoded back into an array.
    """
    # Preprocess into a (1, 3, H', W') normalized tensor on the model device.
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # Recover the original (height, width). Without target_sizes the
    # post-processed mask stays at the processor's working resolution and
    # cannot be overlaid on the input image (the likely reason the overlay
    # was previously commented out).
    if isinstance(image, Image.Image):
        target_size = image.size[::-1]  # PIL .size is (W, H) -> (H, W)
    else:
        target_size = image.shape[:2]   # ndarray is (H, W[, C])

    predicted_maps = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[target_size]
    )
    segmentation_map = predicted_maps[0].cpu().numpy()

    # Render the input with the mask overlaid. plt.show() is intentionally
    # omitted: it is a no-op (or blocks) on a headless inference server.
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.imshow(segmentation_map, alpha=0.5, cmap='gray')
    plt.axis('off')

    # Rasterize the figure to an in-memory PNG, then decode it to an array;
    # close the figure so per-request figures don't accumulate.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    buf.seek(0)
    img_arr = np.array(Image.open(buf))
    return img_arr
# Wire the prediction function into a simple image-in / image-out Gradio UI.
# "image" shorthand components let Gradio pick sensible defaults for upload
# and display widgets.
iface = gr.Interface(
fn=predict,
inputs="image",
outputs="image",
title="Crack Segmentation Demo",
description="Upload an image to show cracks."
)
# Start the web server (blocking call; default host/port).
iface.launch()