# Page header left over from the web export (not part of the program):
# triton7777's picture
# files
# a8e4304
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from mmseg.apis import inference_model, init_model
from PIL import Image
# Model setup: config and checkpoint are read from paths relative to the
# working directory.
config_file = 'config.py'
checkpoint_file = 'checkpoint.pth'

# Run on the first CUDA device when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    _device = 'cuda:0'
else:
    _device = 'cpu'

model = init_model(config_file, checkpoint_file, device=_device)
def segment_image(image):
    """
    Segment the input image using the DeepLabV3 model.

    The image is written to a uniquely named temporary PNG, inference is run
    on that path, and the file is removed again even if saving or inference
    raises.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (as delivered by the
            Gradio image input), or ``None`` when nothing was uploaded.

    Returns:
        The predicted segmentation map converted to a ``numpy.uint8`` array,
        or ``None`` when ``image`` is ``None``.
    """
    # Nothing uploaded: return an empty result instead of failing in PIL.
    if image is None:
        return None

    # A unique file per call keeps concurrent requests from overwriting or
    # deleting each other's input (a fixed file name would be shared).
    fd, temp_image_path = tempfile.mkstemp(suffix=".png")
    # Only the path is needed; close the descriptor so PIL can reopen the file.
    os.close(fd)
    try:
        Image.fromarray(image).save(temp_image_path)
        # Perform inference on the saved file
        result = inference_model(model, temp_image_path)
        # The segmentation map is read from the result's `pred_sem_seg` field
        segmentation_map = result.pred_sem_seg.data.squeeze().cpu().numpy()
    finally:
        # Clean up the temporary file whether or not inference succeeded
        os.remove(temp_image_path)
    return segmentation_map.astype(np.uint8)
# Create a dummy test image for the example.
# This runs before the interface is built so that the file referenced by
# `examples` below already exists when Gradio reads it.
if not os.path.exists("test_image.jpg"):
    # randint's upper bound is exclusive, so 256 covers the full 0-255 range.
    dummy_image = np.random.randint(0, 256, size=(512, 1024, 3), dtype=np.uint8)
    Image.fromarray(dummy_image).save("test_image.jpg")

# Create the Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(),
    outputs=gr.Image(type="numpy", label="Segmentation Map"),
    title="DeepLabV3 Sea Ice Segmentation",
    description="A DeepLabV3 model trained on the seaicergb0 dataset for sea ice segmentation. Upload an image to see the segmentation map.",
    examples=[
        ["test_image.jpg"],
    ],
)

# Start the web app.
iface.launch()