AJain1234's picture
Upload folder using huggingface_hub
4bb934b verified
raw
history blame
8.49 kB
import gradio as gr
import torch
from torchvision import transforms
from experiments.otsu_segmenter import generate_segmented_image
from experiments.kmeans_segmenter import generate_kmeans_segmented_image
from experiments.enhanced_kmeans_segmenter import slic_kmeans
from experiments.watershed_segmenter import generate_watershed
from experiments.felzenszwalb_segmentation import segment
from experiments.SegNet.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE, IMAGE_SIZE
import numpy as np
from PIL import Image
from matplotlib import cm
def generate_kmeans(image_path,k):
kmeans_image_output, kmeans_segmented_image_output,_,kmeans_threshold_text=generate_kmeans_segmented_image(image_path, k)
return kmeans_image_output, kmeans_segmented_image_output, kmeans_threshold_text
def generate_slic(image_path,k,m,max_iter):
image,seg_img, labels, centers = slic_kmeans(image_path, K=k, m=m, max_iter=max_iter)
return image,seg_img
def generate_felzenszwalb(image_path, sigma, k, min_size_factor):
image = Image.open(image_path).convert("RGB")
image_np = np.array(image)
segments_fz = segment(image_np, sigma=sigma, k=k, min_size=min_size_factor)
segments_fz = segments_fz.astype(np.uint8)
return image, segments_fz
def SegNet_efficient_b0(image_path):
model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
model.load_state_dict(torch.load("segnet_efficientnet_voc.pth", map_location=DEVICE))
model.eval()
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
output = model(input_tensor)
pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
# Convert original image for Gradio display
original_image_resized = image.resize(IMAGE_SIZE)
# Convert predicted mask to a color image using a colormap
colormap = cm.get_cmap('nipy_spectral')
colored_mask = colormap(pred_mask / pred_mask.max()) # Normalize
colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8) # Drop alpha and convert to uint8
mask_pil = Image.fromarray(colored_mask)
return original_image_resized, mask_pil
with gr.Blocks() as demo:
gr.Markdown("# Image Segmentation using Classical CV")
with gr.Tabs() as tabs:
with gr.TabItem("Otsu's Method"):
with gr.Row():
with gr.Column(scale=1):
file_input = gr.File(label="Upload Image File")
display_btn = gr.Button("Segment this image")
threshold_text = gr.Textbox(label="Threshold Comparison", value="", interactive=False)
with gr.Column(scale=2):
image_output = gr.Image(label="Original Image", container=False)
histogram_output = gr.Image(label="Histogram", container=False)
segmented_image_output = gr.Image(label="Our Segmented Image", container=False)
opencv_segmented_image_output = gr.Image(label="OpenCV Segmented Image", container=False)
display_btn.click(
fn=generate_segmented_image,
inputs=file_input,
outputs=[image_output, segmented_image_output, opencv_segmented_image_output, histogram_output, threshold_text]
)
with gr.TabItem("K-means Segmentation"):
with gr.Row():
with gr.Column(scale=1):
kmeans_file_input = gr.File(label="Upload Image File")
kmeans_k_value = gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Number of Clusters (K)")
kmeans_display_btn = gr.Button("Segment this image")
kmeans_threshold_text = gr.Textbox(label="K-means Info", value="", interactive=False)
with gr.Column(scale=2):
kmeans_image_output = gr.Image(label="Original Image", container=False)
kmeans_segmented_image_output = gr.Image(label="K-means Segmented Image", container=False)
kmeans_display_btn.click(
fn=generate_kmeans,
inputs=[kmeans_file_input, kmeans_k_value],
outputs=[kmeans_image_output, kmeans_segmented_image_output, kmeans_threshold_text]
)
with gr.TabItem("SLIC Segmentation"):
with gr.Row():
with gr.Column(scale=1):
slic_file_input = gr.File(label="Upload Image File")
slic_k_value = gr.Slider(minimum=2, maximum=200, value=3, step=1, label="Number of superpixels")
slic_m_value = gr.Slider(minimum=1, maximum=40, value=3, step=1, label="Compactness factor")
slic_max_iter_value = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of iterations")
slic_display_btn = gr.Button("Segment this image")
with gr.Column(scale=2):
slic_image_output = gr.Image(label="Original Image", container=False)
slic_segmented_image_output = gr.Image(label="SLIC Segmented Image", container=False)
slic_display_btn.click(
fn=generate_slic,
inputs=[slic_file_input, slic_k_value,slic_m_value,slic_max_iter_value],
outputs=[slic_image_output,slic_segmented_image_output]
)
with gr.TabItem("Watershed Algorithm Segmentation"):
with gr.Row():
with gr.Column(scale=1):
watershed_file_input = gr.File(label="Upload Image File")
watershed_display_btn = gr.Button("Segment this image")
with gr.Column(scale=2):
watershed_image_output = gr.Image(label="Original Image", container=False)
watershed_segmented_image_output = gr.Image(label="watershed Segmented Image", container=False)
watershed_display_btn.click(
fn=generate_watershed,
inputs=[watershed_file_input],
outputs=[watershed_image_output,watershed_segmented_image_output]
)
with gr.TabItem("Felzenszwalb Algorithm Segmentation"):
with gr.Row():
with gr.Column(scale=1):
felzenszwalb_file_input = gr.File(label="Upload Image File")
sigma_value = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label="Sigma")
K_value = gr.Slider(minimum=2, maximum=1000, value=2, step=1, label="K value")
min_size_value = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Min Size Factor")
felzenszwalb_display_btn = gr.Button("Segment this image")
with gr.Column(scale=2):
felzenszwalb_image_output = gr.Image(label="Original Image", container=False)
felzenszwalb_segmented_image_output = gr.Image(label="felzenszwalb Segmented Image", container=False)
felzenszwalb_display_btn.click(
fn=generate_felzenszwalb,
inputs=[felzenszwalb_file_input,sigma_value,K_value,min_size_value],
outputs=[felzenszwalb_image_output,felzenszwalb_segmented_image_output]
)
with gr.TabItem("SegNet EfficientNet B0 Segmentation"):
with gr.Row():
with gr.Column(scale=1):
segnet_file_input = gr.File(label="Upload Image File")
segnet_display_btn = gr.Button("Segment this image")
with gr.Column(scale=2):
segnet_image_output = gr.Image(label="Original Image", container=False)
segnet_segmented_image_output = gr.Image(label="SegNet Segmented Image", container=False)
segnet_display_btn.click(
fn=SegNet_efficient_b0,
inputs=[segnet_file_input],
outputs=[segnet_image_output,segnet_segmented_image_output]
)
if __name__ == "__main__":
demo.launch(server_name="172.31.100.127")