Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import torchvision | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import pandas as pd | |
| import segmentation_models_pytorch as smp | |
| import gradio as gr | |
# Number of segmentation classes each network predicts.
num_classes = 2

# Checkpoint and example-image locations (relative paths, resolved against the
# current working directory).
model_unet_path = "unet_model.pth"
model_fpn_path = "fpn_model.pth"
model_deeplab_path = "deeplabv3_model.pth"
image_path = "leaf11.jpg"

# Choose the device used for inference: CUDA if present, otherwise Apple MPS,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# Instantiate the three segmentation networks, in the order U-Net, FPN,
# DeepLabV3. All of them:
#   * start from randomly initialised encoders (encoder_weights=None, so no
#     pretrained weights are downloaded); the trained checkpoints are loaded
#     into these models later in the file,
#   * take 3-channel (RGB) input,
#   * produce one output channel per class (num_classes).
model_unet, model_fpn, model_deeplab = (
    architecture(
        encoder_name=encoder,
        encoder_weights=None,
        in_channels=3,
        classes=num_classes,
    )
    for architecture, encoder in (
        (smp.Unet, "resnet18"),
        (smp.FPN, "resnet18"),
        (smp.DeepLabV3, "resnet34"),
    )
)
def pred_one_image(inp, option):
    """Segment one image with the selected network and return the label mask.

    Args:
        inp: The input image as a ``PIL.Image`` (the Gradio input component is
            declared with ``type="pil"``).
        option: Which network to use: ``"unet"``, ``"fpn"`` or ``"deeplab"``.

    Returns:
        A 256x256 single-channel ``PIL.Image`` whose pixels are the predicted
        class indices scaled into 0..255 (class 0 -> 0, highest class -> 255).

    Raises:
        ValueError: if ``inp`` is None or ``option`` is not a known model name.
    """
    if inp is None:
        raise ValueError("No input image was provided.")

    models = {"unet": model_unet, "fpn": model_fpn, "deeplab": model_deeplab}
    model_load = models.get(option)
    if model_load is None:
        # Without this check an unknown/None option left `model_load` unbound.
        raise ValueError(
            f"Unknown model option {option!r}; expected one of {sorted(models)}"
        )

    # Resize to the fixed 256x256 input size; convert("RGB") also normalises
    # RGBA / greyscale uploads to 3 channels.
    one_image = np.array(inp.resize((256, 256)).convert("RGB"))
    # HWC -> CHW, then add a batch dimension of 1.
    one_image = torch.tensor(np.moveaxis(one_image, -1, 0)).float().unsqueeze(0)
    one_image = one_image.to(device)

    # The checkpoints are loaded with map_location="cpu"; the model must sit on
    # the same device as the input tensor or inference fails on CUDA/MPS.
    # Module.to() is a cheap no-op once the parameters are already there.
    model_load.to(device)
    model_load.eval()
    with torch.no_grad():
        output = model_load(one_image)

    # Channel dim 1 holds one score map per class; argmax gives the label map.
    predictions = torch.argmax(output, dim=1)
    n_classes = output.shape[1]
    # Stretch labels so the highest class maps to 255. Dividing by 2 (as the
    # original did) rendered class 1 of a 2-class model as gray 127.
    scale = 255 / max(n_classes - 1, 1)
    pred_array = (predictions[0].cpu().numpy() * scale).astype(np.uint8)
    return Image.fromarray(pred_array)
# Load the trained weights. Checkpoints are deserialised onto the CPU so they
# load on any machine. NOTE: torch.load unpickles the file -- only load
# trusted checkpoints.
model_unet.load_state_dict(torch.load(model_unet_path, map_location=torch.device("cpu")))
model_fpn.load_state_dict(torch.load(model_fpn_path, map_location=torch.device("cpu")))
model_deeplab.load_state_dict(torch.load(model_deeplab_path, map_location=torch.device("cpu")))

# pred_one_image sends its input tensor to `device`, so the models must live on
# that device too (otherwise inference fails on CUDA/MPS). They are also
# switched to evaluation mode once here.
model_unet.to(device).eval()
model_fpn.to(device).eval()
model_deeplab.to(device).eval()

# Default to "unet" so the dropdown never submits None, which pred_one_image
# cannot map to a model.
dropdown = gr.Dropdown(["unet", "fpn", "deeplab"], value="unet")
interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    examples=[[image_path, "unet"]],
)
interface.launch(debug=False)