Spaces:
Build error
Build error
File size: 5,611 Bytes
0f9608b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from experiments.otsu_segmenter import otsu_threshold
from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE
def ensemble_segmentation(image_path, model_path="segnet_efficientnet_voc.pth", boundary_weight=0.3):
    """
    Ensemble segmentation combining Otsu thresholding and SegNet.

    The SegNet class map is the base prediction. Edges extracted from the Otsu
    mask define a soft "boundary band"; inside that band the binary SegNet
    foreground mask is blended with the Otsu mask, and pixels whose blended
    value falls below 128 are reset to the background class (0).

    Args:
        image_path: Path to input image (anything ``PIL.Image.open`` accepts).
        model_path: Path to SegNet model weights (a ``state_dict`` file).
        boundary_weight: Weight for boundary refinement (0-1). Values outside
            that range are clamped, because a weight above 1 would make the
            SegNet coefficient negative and wrap around in the uint8 cast.

    Returns:
        original_image: Original input image (PIL)
        ensemble_result: Colorized ensemble segmentation, resized back to the
            original image size (PIL)
        method_comparison: Visualization of all methods side by side (PIL)
    """
    boundary_weight = float(np.clip(boundary_weight, 0.0, 1.0))

    # 1. Load the image
    image = Image.open(image_path).convert('RGB')
    original = image.copy()
    image_np = np.array(image)

    # 2. Run Otsu thresholding for boundary detection
    # PIL delivers RGB, so convert straight to grayscale (no BGR round trip),
    # then blur to suppress noise before thresholding.
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    otsu_threshold_value, otsu_mask = otsu_threshold(blurred)

    # 3. Run SegNet for semantic segmentation
    model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((360, 480)),  # Or larger if needed
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    # Per-pixel class index, as a 2-D numpy array.
    segnet_pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    # cv2.resize takes (width, height), i.e. the reverse of numpy's shape.
    target_size = (segnet_pred.shape[1], segnet_pred.shape[0])

    # 4. Create edge map from Otsu result
    # NOTE(review): cv2.Canny needs an 8-bit image; this assumes otsu_mask is
    # a uint8 0/255 mask -- confirm against otsu_threshold.
    edges = cv2.Canny(otsu_mask, 50, 150)
    # Resize to match SegNet output size
    edges_resized = cv2.resize(edges, target_size, interpolation=cv2.INTER_NEAREST)

    # 5. Ensemble: Use Otsu edges to refine SegNet boundaries
    # Distance of every pixel to the nearest edge pixel (edges become zeros).
    dist_transform = cv2.distanceTransform(255 - edges_resized, cv2.DIST_L2, 5)
    max_dist = dist_transform.max()
    if max_dist > 0:
        # Normalize to 0-1; skipped when every distance is 0 so we never
        # divide by zero and poison the weight map with NaN.
        dist_transform = dist_transform / max_dist
    # Areas close to edges get more influence from Otsu
    edge_weight_map = np.exp(-dist_transform * 5) * boundary_weight

    # Create binary mask from SegNet (foreground = any class other than background)
    segnet_binary = (segnet_pred > 0).astype(np.uint8) * 255
    # Resize Otsu mask to match SegNet output
    otsu_resized = cv2.resize(otsu_mask, target_size, interpolation=cv2.INTER_NEAREST)

    # Combine: keep SegNet classes but blend the binary masks near Otsu edges.
    refined_binary = segnet_binary.copy()
    boundary_region = edge_weight_map > 0.1
    band_weight = edge_weight_map[boundary_region]
    refined_binary[boundary_region] = (
        (1 - band_weight) * segnet_binary[boundary_region] +
        band_weight * otsu_resized[boundary_region]
    ).astype(np.uint8)

    # Apply the refined binary mask to the original SegNet prediction.
    # This step can only remove foreground: pixels SegNet already labelled
    # background stay background regardless of the Otsu mask.
    ensemble_result = segnet_pred.copy()
    ensemble_result[refined_binary < 128] = 0

    # 6. Visualize results
    import io
    import matplotlib.pyplot as plt

    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap
    # was removed in Matplotlib 3.9.
    colormap = plt.get_cmap('nipy_spectral')
    # Convert semantic maps to color visualizations (drop the alpha channel).
    segnet_colored = colormap(segnet_pred / (NUM_CLASSES - 1))
    segnet_colored = (segnet_colored[:, :, :3] * 255).astype(np.uint8)
    ensemble_colored = colormap(ensemble_result / (NUM_CLASSES - 1))
    ensemble_colored = (ensemble_colored[:, :, :3] * 255).astype(np.uint8)

    # Resize original image to match the segmentation size
    original_resized = original.resize(target_size)

    # Create side-by-side comparison
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    try:
        axes[0].imshow(original_resized)
        axes[0].set_title("Original Image")
        axes[0].axis('off')
        axes[1].imshow(otsu_mask, cmap='gray')
        axes[1].set_title(f"Otsu (t={otsu_threshold_value})")
        axes[1].axis('off')
        axes[2].imshow(segnet_colored)
        axes[2].set_title("SegNet Prediction")
        axes[2].axis('off')
        axes[3].imshow(ensemble_colored)
        axes[3].set_title("Ensemble Result")
        axes[3].axis('off')
        fig.tight_layout()
        # Convert the plot to an image
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    comparison_image = Image.open(buf)

    # Return results
    ensemble_pil = Image.fromarray(ensemble_colored)
    ensemble_pil = ensemble_pil.resize(original.size, Image.NEAREST)
    return original, ensemble_pil, comparison_image
# Add this function to your app.py
def generate_ensemble_segmentation(image_path, boundary_weight=0.3):
    """Wrapper for Gradio interface.

    Runs ``ensemble_segmentation`` with the CamVid-trained SegNet weights
    stored under ``saved_models/``.

    Args:
        image_path: Path to the input image.
        boundary_weight: Weight for boundary refinement, forwarded unchanged.

    Returns:
        The ``(original, ensemble_result, comparison)`` tuple produced by
        ``ensemble_segmentation``.
    """
    return ensemble_segmentation(
        image_path,
        model_path="saved_models/segnet_efficientnet_camvid.pth",
        boundary_weight=boundary_weight,
    )