Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
# Checkpoint file loaded at import time (path is relative to the working directory).
model_path = 'medsam_lite/lite_medsam.pth'


def _load_segmentation_model(path):
    """Load a pickled ``torch.nn.Module`` from *path* onto the CPU in eval mode.

    Args:
        path: filesystem path of the checkpoint passed to ``torch.load``.

    Returns:
        The loaded module, switched to evaluation mode via ``.eval()``.

    Raises:
        TypeError: if the file does not contain a ``torch.nn.Module`` (for
            example a plain ``state_dict`` mapping parameter names to tensors),
            since such an object has no ``eval()`` and cannot be called.
    """
    # NOTE(review): checkpoints like this are often distributed as a state_dict
    # rather than a whole pickled module; if so, build the model architecture
    # first and call ``model.load_state_dict(...)``. Also note that on recent
    # PyTorch releases torch.load defaults to weights_only=True, which refuses
    # to unpickle arbitrary module objects — confirm against the checkpoint.
    loaded = torch.load(path, map_location=torch.device('cpu'))
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path!r} contains a {type(loaded).__name__}, not a torch.nn.Module; "
            "if it is a state_dict, construct the model architecture and call "
            "load_state_dict() on it instead"
        )
    loaded.eval()
    return loaded


segmentation_model = _load_segmentation_model(model_path)
# Transform applied to every input: resize to a fixed 256x256 and convert to a
# float tensor. Built once at import rather than on every request.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def preprocess(image):
    """Convert an input image into a batched tensor for the segmentation model.

    Args:
        image: either a ``PIL.Image.Image`` (what ``gr.Image(type="pil")``
            delivers) or an array accepted by ``PIL.Image.fromarray``.

    Returns:
        A float tensor of shape ``(1, C, 256, 256)``, where ``C`` is the number
        of image bands. Pixel scaling follows torchvision ``ToTensor`` rules
        (8-bit images are mapped to [0, 1]).
    """
    # Image.fromarray expects an array-like buffer, so PIL images are used as-is.
    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # unsqueeze(0) adds the leading batch dimension of size 1.
    return _PREPROCESS_TRANSFORM(img).unsqueeze(0)
def segment_image(input_image):
    """Segment *input_image* and return a mask suitable for display.

    Args:
        input_image: the image handed over by the Gradio input component; it is
            passed unchanged to ``preprocess``.

    Returns:
        A 2-D ``uint8`` NumPy array at the model's output resolution. Each class
        index is spread evenly over 0..255 (background 0), so a binary mask is
        0/255.
    """
    input_tensor = preprocess(input_image)

    with torch.no_grad():
        # NOTE(review): only the image tensor is passed; prompt-driven models may
        # also require box/point prompts — confirm against the model's forward().
        output = segmentation_model(input_tensor)

    # NOTE(review): assumes output is a single tensor of shape (1, C, H, W)
    # holding raw per-pixel logits — confirm for the loaded model.
    if output.shape[1] == 1:
        # argmax over a single channel is always 0, which would yield an empty
        # mask; treat the channel as one binary logit map and threshold at 0.
        labels = (output[:, 0] > 0).long()
        num_classes = 2
    else:
        labels = torch.argmax(output, dim=1)
        num_classes = output.shape[1]

    # Raw class indices (0, 1, 2, ...) are almost black once rendered as an
    # 8-bit image, so stretch them across the 0..255 range.
    scale = max(255 // (num_classes - 1), 1)
    mask = (labels.squeeze(0) * scale).clamp_(max=255)
    return mask.to(torch.uint8).numpy()
# Gradio UI: one image in, one mask image out. The input component hands
# segment_image a PIL image, and segment_image runs preprocess itself, so no
# extra preprocessing hook is attached to the component (gr.Image has no
# documented `preprocess` argument).
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
)

# Start the Gradio web server.
iface.launch()