Spaces:
Runtime error
Runtime error
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image
# Load the trained segmentation network once, at import time, so every
# request reuses it. The architecture arguments below must match the ones the
# checkpoint was trained with; otherwise load_state_dict() (strict by default)
# rejects the weights.
MODEL_PATH = "Floodnet_model_e5.pt"

model = smp.Unet(
    encoder_name="resnet34",  # must be the encoder the checkpoint was trained with
    encoder_weights=None,  # random init: all trained weights come from MODEL_PATH below
    in_channels=3,  # RGB input
    classes=10,  # output channels, i.e. number of segmentation classes
)
# map_location remaps any GPU-saved tensors onto the CPU so this also runs on
# CPU-only hosts.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
model.eval()  # inference mode for dropout / batch-norm layers
| #handle input: | |
| # output = lbm(sample.unsqueeze(dim=0).float()).detach().type(torch.int64) | |
| # show(output.argmax(dim=1).squeeze()) | |
| def predict_segmentation(image: Image.Image): | |
| image = image.resize((256, 256)) | |
| input_data = np.asarray(image) | |
| # Assuming the model expects a 4D input array | |
| input_data = input_data[np.newaxis, ...] | |
| # Get the prediction from the model | |
| output_data = model.predict(torch.from_numpy(input_data).float()) | |
| # Assuming the output is a 3D array | |
| output_mask = output.argmax(dim=1).squeeze() | |
| # Convert the output_mask to an Image object | |
| output_image = output_mask#Image.fromarray(np.uint8(output_mask.numpy())) | |
| return output_image | |
| image_input = gr.components.Image(shape=(256, 256), source="upload") | |
| image_output = gr.components.Image(type="pil") | |
| iface = gr.Interface(predict_segmentation, 'image', 'image') | |
| iface.launch() |