Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from mmseg.apis import inference_model, init_model | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
# Build the segmentation model once, at module import time; segment_image()
# reuses this single global instance for every request.
# Inference runs on the first CUDA device when one is available, else on CPU.
config_file = 'config.py'
checkpoint_file = 'checkpoint.pth'
if torch.cuda.is_available():
    _device = 'cuda:0'
else:
    _device = 'cpu'
model = init_model(config_file, checkpoint_file, device=_device)
def segment_image(image):
    """
    Segment the input image using the DeepLabV3 model.

    Args:
        image: The image delivered by the Gradio ``gr.Image`` input as a
            numpy array, or ``None`` when the user submits without an image.

    Returns:
        A ``uint8`` numpy array of per-pixel predicted class indices, or
        ``None`` when no image was provided.
    """
    import tempfile  # local import: only needed for the per-call scratch dir

    # Gradio passes None for an empty input; Image.fromarray(None) would raise.
    if image is None:
        return None

    # The image is handed to inference_model as a file path. Each call gets its
    # own temporary directory so concurrent requests cannot overwrite each
    # other's input, and the directory is removed even if inference raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_image_path = os.path.join(tmp_dir, "temp_image.png")
        Image.fromarray(image).save(temp_image_path)
        result = inference_model(model, temp_image_path)

    # `result` is accessed via attributes (not a dict); the predicted label map
    # is in `pred_sem_seg.data`, presumably shaped (1, H, W) — squeeze only
    # axis 0 so a 1-pixel-high or -wide image still yields a 2-D map.
    segmentation_map = result.pred_sem_seg.data.squeeze(0).cpu().numpy()
    return segmentation_map.astype(np.uint8)
# Create a dummy test image for the example. This must run *before* the
# interface is built: Gradio processes `examples` while constructing the
# Interface, so the example file has to exist at that point.
if not os.path.exists("test_image.jpg"):
    # randint's upper bound is exclusive; 256 covers the full uint8 range.
    dummy_image = np.random.randint(0, 256, size=(512, 1024, 3), dtype=np.uint8)
    Image.fromarray(dummy_image).save("test_image.jpg")

# Create the Gradio interface. The input is requested as a numpy array because
# segment_image() passes it to Image.fromarray().
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy", label="Segmentation Map"),
    title="DeepLabV3 Sea Ice Segmentation",
    description="A DeepLabV3 model trained on the seaicergb0 dataset for sea ice segmentation. Upload an image to see the segmentation map.",
    examples=[
        ["test_image.jpg"],
    ],
)
iface.launch()