import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from residual_unet import ResidualUNet
| |
|
| | |
# Hub repository hosting the trained checkpoint. Defined once so the weight
# download and the model construction always point at the same repository
# (the model construction previously referenced an undefined `repo_id`).
REPO_ID = "keysun89/resunet_1"

# Fetch the checkpoint file from the Hub. `weights_path` is not read by the
# rest of this script; it is kept so the file is still downloaded at start-up
# and the name stays available to anything that relies on it.
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="best_residual_unet_model.pth",
)

# NOTE(review): assumes ResidualUNet.from_pretrained accepts a Hub repo id
# (e.g. through huggingface_hub's PyTorchModelHubMixin) -- confirm against
# residual_unet.py.
model = ResidualUNet.from_pretrained(REPO_ID)
# This script only runs inference, so leave training mode right away.
model.eval()
| |
|
| | |
# Spatial size every input image is resized to before it reaches the model.
IMG_HEIGHT = 128
IMG_WIDTH = 128

# Preprocessing pipeline: resize to (IMG_HEIGHT, IMG_WIDTH), then convert the
# PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
    ]
)
| |
|
| | |
def predict(image):
    """Run the segmentation model on one image and return its mask.

    Args:
        image: PIL image delivered by the Gradio input component, or ``None``
            when the form is submitted without an uploaded image.

    Returns:
        A PIL image with the same width and height as ``image`` containing
        the predicted mask scaled to the 0-255 range, or ``None`` when no
        image was provided.
    """
    # Gradio passes None for an empty image input; return an empty result
    # instead of failing on `image.size`.
    if image is None:
        return None

    # Remember the original size so the mask can be mapped back onto it.
    orig_w, orig_h = image.size
    # Preprocess and add a leading batch axis.
    batch = transform(image).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        pred = model(batch)

    # Drop the two leading axes when they have size 1.
    # NOTE(review): assumes the model returns a single-channel map with
    # values in [0, 1] (e.g. sigmoid output) -- confirm against ResidualUNet.
    mask = pred.squeeze(0).squeeze(0).cpu().numpy()
    # Clip before scaling so values outside [0, 1] cannot wrap around when
    # cast to uint8; in-range values are unaffected.
    mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    # Nearest-neighbour resampling keeps mask values unblended when scaling
    # back to the original resolution.
    return Image.fromarray(mask).resize((orig_w, orig_h), Image.NEAREST)
| |
|
| | |
# Input and output components both exchange PIL images with `predict`.
_input_image = gr.Image(type="pil")
_output_mask = gr.Image(type="pil")

# Web UI wiring: one uploaded image in, one predicted mask image out.
demo = gr.Interface(
    fn=predict,
    inputs=_input_image,
    outputs=_output_mask,
    title="Residual-UNet Segmentation",
    description="Upload an image to get the predicted segmentation mask.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()