"""Gradio demo app for a DeepLabV3 sea ice segmentation model (MMSegmentation)."""

import os
import tempfile

import gradio as gr
import numpy as np
import torch
from mmseg.apis import inference_model, init_model
from PIL import Image

# Load the model once at import time; every request reuses this instance.
config_file = 'config.py'
checkpoint_file = 'checkpoint.pth'
model = init_model(
    config_file,
    checkpoint_file,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

# Path of the example image offered in the Gradio UI.
EXAMPLE_IMAGE_PATH = "test_image.jpg"


def segment_image(image):
    """
    Segment the input image using the DeepLabV3 model.

    Args:
        image: The uploaded image as a NumPy array (what ``gr.Image`` with
            ``type="numpy"`` delivers), or ``None`` when the user submits
            without uploading an image.

    Returns:
        The predicted segmentation map (the squeezed ``pred_sem_seg`` data)
        cast to ``uint8``, or ``None`` when no image was provided.
        NOTE(review): values are raw class indices, so with few classes the
        map renders nearly black; class ids above 255 would wrap on the cast.
    """
    if image is None:
        # Gradio passes None when Submit is pressed with an empty input.
        return None

    # Inference is run on a file path, as in the original app, so the image is
    # decoded by the model's own loading pipeline. Presumably this matters for
    # channel order (mmcv decodes files as BGR, while Gradio arrays are RGB).
    # A unique temp file keeps concurrent requests from clobbering each other.
    fd, temp_image_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # PIL reopens the path itself; the raw descriptor is unused.
    try:
        Image.fromarray(image).save(temp_image_path)
        # Perform inference
        result = inference_model(model, temp_image_path)
    finally:
        # Clean up the temporary file even if saving or inference fails.
        os.remove(temp_image_path)

    # The result is a data-sample object (not a plain dict); the predicted
    # segmentation map is its `pred_sem_seg.data` tensor.
    segmentation_map = result.pred_sem_seg.data.squeeze().cpu().numpy()
    return segmentation_map.astype(np.uint8)


# Create a dummy test image for the example *before* building the interface,
# since the interface references this file in `examples`.
if not os.path.exists(EXAMPLE_IMAGE_PATH):
    dummy_image = np.random.randint(0, 255, size=(512, 1024, 3), dtype=np.uint8)
    Image.fromarray(dummy_image).save(EXAMPLE_IMAGE_PATH)

# Create the Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy", label="Segmentation Map"),
    title="DeepLabV3 Sea Ice Segmentation",
    description="A DeepLabV3 model trained on the seaicergb0 dataset for sea ice segmentation. Upload an image to see the segmentation map.",
    examples=[
        [EXAMPLE_IMAGE_PATH],
    ],
)

iface.launch()