kitooo's picture
updated app
aaaedd1 verified
raw
history blame
1.65 kB
import torch
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
# model = SamModel(config=model_config)
model = SamModel.from_pretrained('kitooo/sidewalk-seg-base')
model.to(device)
def segment_sidewalk(image):
width, height = image.size
prompt = [0, 0, width, height]
inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs, multimask_output=False)
prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
prediction = (prob_map > 0.5).float()
prob_map, prediction = prob_map.numpy(), prediction.numpy()
save_image(image, 'image.png')
save_image(prob_map, 'prob.png', cmap='jet')
save_image(prediction, 'mask.png', cmap='gray')
return Image.open('image.png'), Image.open('mask.png'), Image.open('prob.png')
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
image_input = gr.Image(type='pil', label='Original TIFF Image')
button = gr.Button('Get Sidewalk Mask')
with gr.Column():
mask = gr.Image(type='pil', label='Predicted Mask')
prob_map = gr.Image(type='pil', label='Predicted Probability Map')
button.click(
segment_sidewalk,
inputs=[image_input],
outputs=[image_input, mask, prob_map]
)
demo.launch(debug=True, show_error=True)