# resunet_1 / app.py — Hugging Face Space by keysun89
# Last commit: "Update app.py" (fa8b1b9, verified)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from residual_unet import ResidualUNet
from huggingface_hub import hf_hub_download
def _load_model(checkpoint_path):
    """Build the ResidualUNet and restore its trained weights on the CPU.

    Args:
        checkpoint_path: local path of the saved ``state_dict`` file.

    Returns:
        The network with the weights loaded and switched to eval mode.
    """
    net = ResidualUNet(out_channels=7)  # 7 output channels, one per class
    # NOTE(review): torch.load unpickles the checkpoint; since the file is
    # fetched from the Hub, consider weights_only=True if the installed torch
    # version supports it.
    state = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(state)
    net.eval()  # inference mode before serving predictions
    return net


# Fetch the trained checkpoint from the Hugging Face Hub; repo_id and filename
# must match what was uploaded to the model repository.
weights_path = hf_hub_download(
    repo_id="keysun89/resunet_1",
    filename="best_residual_unet_model.pth",
)
model = _load_model(weights_path)
# Input preprocessing, kept identical to the training pipeline: resize every
# image to the fixed network resolution, then convert it to a tensor.
IMG_HEIGHT = 128
IMG_WIDTH = 128
transform = transforms.Compose(
    [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
    ]
)
def predict(image):
    """Segment cracks in ``image`` and return a binary mask.

    Args:
        image: the uploaded picture as a PIL image (the Gradio input is
            configured with ``type="pil"``), or ``None`` if nothing was
            uploaded.

    Returns:
        A single-channel (mode ``"L"``) PIL image with the same size as
        ``image``. Pixels predicted as any non-background class are 255 and
        background (class 0) pixels are 0. Returns ``None`` when ``image``
        is ``None``.
    """
    # The input can be empty (e.g. Submit pressed without an upload); return
    # no mask instead of crashing on ``image.size`` below.
    if image is None:
        return None

    orig_w, orig_h = image.size  # PIL reports (width, height)
    # Resize to the network resolution, add a batch dimension, ensure float32.
    img = transform(image).unsqueeze(0).float()

    with torch.no_grad():
        pred = model(img)  # [1, 7, H, W] per-class outputs

    # Most likely class per pixel -> (H, W) uint8 class map.
    class_map = pred.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    # Collapse classes into crack / no-crack (class 0 treated as background).
    binary_mask = (class_map != 0).astype(np.uint8) * 255

    # Scale back to the input size; NEAREST keeps the mask strictly 0/255.
    return (
        Image.fromarray(binary_mask)
        .resize((orig_w, orig_h), Image.NEAREST)
        .convert("L")
    )
# Gradio UI: one PIL image in, the predicted crack mask (PIL image) out.
_image_input = gr.Image(type="pil")
_mask_output = gr.Image(type="pil")

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_mask_output,
    title="UNet Crack Segmentation",
    description="Upload a concrete surface image to get predicted crack mask",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()