Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from model import SegmentationModel | |
| DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| model = SegmentationModel() | |
| model.to(DEVICE) | |
| model.load_state_dict(torch.load('./best_model.pt')) | |
| def inference(input_img): | |
| image = torch.from_numpy(input_img).permute(2,0,1).float() | |
| logits_mask = model(image.to(DEVICE).unsqueeze(0)) # (C, H, W) -> (1, C, H, W) | |
| pred_mask = torch.sigmoid(logits_mask) | |
| return pred_mask.squeeze().detach().cpu().numpy() | |
| demo = gr.Interface(inference, gr.Image(shape=(224, 224)), "image") | |
| demo.launch() |