"""Gradio demo that segments cracks in concrete-surface images.

The ResidualUNet checkpoint is fetched from the Hugging Face Hub and the model
is built when this module is imported; ``predict`` then runs inference on each
uploaded image.
"""

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from residual_unet import ResidualUNet

# Fetch the trained checkpoint from the Hugging Face Hub at import time.
weights_path = hf_hub_download(
    repo_id="keysun89/resunet_1",
    filename="best_residual_unet_model.pth",
)

# out_channels=7: presumably one logit map per segmentation class
# (TODO confirm against the training code).
model = ResidualUNet(out_channels=7)
# torch.load unpickles the file. The checkpoint comes from an external Hub
# repo, so restrict unpickling to tensors and plain containers
# (weights_only=True); a tampered file then cannot run arbitrary code at
# start-up. map_location="cpu" lets the app load on hosts without a GPU.
model.load_state_dict(
    torch.load(weights_path, map_location="cpu", weights_only=True)
)
# Put layers such as dropout / batch-norm (if the architecture has any) into
# inference behaviour.
model.eval()

# Fixed network input resolution; predictions are computed at this size.
IMG_HEIGHT, IMG_WIDTH = 128, 128
# No mean/std normalisation is applied -- presumably this matches the
# preprocessing used in training (TODO confirm).
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),  # PIL [0, 255] -> float tensor (C, H, W) in [0, 1]
])
|
|
def predict(image):
    """Return a binary crack mask for *image*.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A mode ``"L"`` PIL image with the same size as *image*. A pixel is
        255 where the model's per-pixel argmax is a non-zero class and 0
        where it is class 0. Returns ``None`` if *image* is ``None``.
    """
    # Gradio hands over None for an empty input; image.size would raise.
    if image is None:
        return None

    orig_w, orig_h = image.size
    # Force three channels so RGBA / greyscale / palette images yield the same
    # tensor layout as the RGB images gr.Image delivers by default.
    img = transform(image.convert("RGB")).unsqueeze(0).float()
    with torch.no_grad():
        logits = model(img)

    # Highest-scoring class per pixel: (H, W) map at network resolution.
    class_map = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    # Collapse every non-zero class into the foreground of the mask.
    binary_mask = (class_map != 0).astype(np.uint8) * 255

    # NEAREST resampling keeps the upscaled mask strictly 0/255.
    return (
        Image.fromarray(binary_mask)
        .resize((orig_w, orig_h), Image.NEAREST)
        .convert("L")
    )
|
|
|
|
def _build_demo():
    """Assemble the Gradio UI: one PIL image in, one PIL mask image out."""
    source_image = gr.Image(type="pil")
    mask_image = gr.Image(type="pil")
    return gr.Interface(
        fn=predict,
        inputs=source_image,
        outputs=mask_image,
        title="UNet Crack Segmentation",
        description="Upload a concrete surface image to get predicted crack mask",
    )


# Module-level handle to the interface, built once at import time.
demo = _build_demo()


if __name__ == "__main__":
    # Serve the UI only when the file is executed as a script.
    demo.launch()
|
|