Spaces:
Paused
Paused
File size: 11,444 Bytes
efd5df3 4e5d881 efd5df3 4e5d881 efd5df3 4e5d881 efd5df3 4e5d881 efd5df3 4e5d881 efd5df3 4e5d881 efd5df3 c895613 efd5df3 4e5d881 efd5df3 dc83630 c895613 efd5df3 4e5d881 d08406f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 |
import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import io
import math
# --- Documentation Strings ---
# Markdown blocks rendered inside the collapsible "Tips, Guidelines, and
# Model Selection" accordion of the Gradio UI. These are runtime strings:
# editing them changes what the user sees.
# Quick-start walkthrough of the upload -> model select -> submit flow.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask
This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask.
1. **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG).
2. **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences).
3. **Run**: Click the **"Submit"** button.
4. **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image).
"""
# Describes the two input widgets and the 256x256 non-overlapping
# patching strategy implemented by extract_patches/stitch_patches.
INPUT_EXPLANATION = """
## 2. Input Requirements
| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. |
| **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). |
### Technical Note: Patching
This application uses a patch-based approach:
1. The uploaded image is broken into non-overlapping **256x256 pixel patches**.
2. Each patch is analyzed individually by the U-Net.
3. The predicted patches are **stitched back together** to form the final segmentation map. This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs.
"""
# Compares the three selectable weight files (see Model Initialization
# section below for the corresponding .pth paths).
MODEL_GUIDANCE = """
## 3. Model Selection Guidance (Without Noise vs. With Noise)
The application provides three distinct model weights, reflecting different training strategies:
| Model Option | Training Strategy | Recommended Use Case |
| :--- | :--- | :--- |
| **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. |
| **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. |
| **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. |
"""
# Explains the two outputs: the grayscale class mask and the color
# overlay produced by create_overlay (black/red/green/blue palette).
OUTPUT_INTERPRETATION = """
## 4. Expected Outputs
The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure.
### A. Segmentation Mask (Grayscale)
This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity.
* Class 0 is represented by **Black**.
* Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**.
### B. Overlay (Colored)
This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha).
| Color | Underlying Class Index | Possible MoS2 Region |
| :--- | :--- | :--- |
| **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background |
| **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) |
| **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) |
| **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) |
"""
# --------------------
# Core Pipeline Functions (Kept AS IS)
# --------------------
def initialize_model(model_path):
    """Load a trained EnhancedUNet checkpoint and return (model, device).

    The model is configured for 1-channel (grayscale) input and 4 output
    classes, moved to GPU when available, and switched to eval mode.
    """
    # map_location keeps checkpoints loadable on CPU-only hosts.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = EnhancedUNet(n_channels=1, n_classes=4).to(run_device)
    state = torch.load(model_path, map_location=run_device)
    net.load_state_dict(state)
    net.eval()
    return net, run_device
def extract_patches(image, patch_size=256):
    """Split a PIL image into non-overlapping square patches.

    The image is pasted onto a black (0-filled) grayscale canvas whose
    sides are rounded up to the next multiple of patch_size, so edge
    patches are zero-padded rather than truncated.

    Returns:
        (patches, positions, padded_size, original_size) where positions
        holds (left, top, right, bottom) boxes matching each patch.
    """
    width, height = image.size
    # Round each dimension up to a whole number of patches.
    padded_w = math.ceil(width / patch_size) * patch_size
    padded_h = math.ceil(height / patch_size) * patch_size
    canvas = Image.new('L', (padded_w, padded_h), 0)
    canvas.paste(image, (0, 0))
    patches, positions = [], []
    # Row-major scan: top-to-bottom, then left-to-right within each row.
    for top in range(0, padded_h, patch_size):
        for left in range(0, padded_w, patch_size):
            box = (left, top, left + patch_size, top + patch_size)
            patches.append(canvas.crop(box))
            positions.append(box)
    return patches, positions, (padded_w, padded_h), (width, height)
def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
    """Reassemble per-patch masks into one full-size mask.

    Args:
        patches: iterable of 2-D uint8-compatible arrays (class indices).
        positions: matching (left, top, right, bottom) boxes.
        padded_size: (width, height) of the padded canvas.
        original_size: (width, height) to crop the result back to.
        n_classes: unused; kept for interface compatibility.

    Returns:
        A uint8 array of shape (original_height, original_width).
    """
    padded_w, padded_h = padded_size
    canvas = np.zeros((padded_h, padded_w), dtype=np.uint8)
    for tile, (left, top, right, bottom) in zip(patches, positions):
        canvas[top:bottom, left:right] = tile
    # Drop the zero padding added during extraction.
    return canvas[:original_size[1], :original_size[0]]
def process_patch(patch, device):
    """Turn a PIL patch into a normalized float tensor on the target device.

    Returns a tensor of shape (1, 1, H, W) with values scaled to [0, 1].
    """
    gray = patch.convert("L")
    # uint8 [0, 255] -> float32 [0, 1]
    scaled = np.array(gray, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(scaled).float()
    # Add the batch and channel axes the U-Net expects: (H, W) -> (1, 1, H, W).
    return tensor.unsqueeze(0).unsqueeze(0).to(device)
def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask over the original image.

    Args:
        original_image: PIL image; resized to the mask's dimensions.
        mask: 2-D array of class indices in {0, 1, 2, 3}.
        alpha: weight of the mask colors in the blend (0..1).

    Returns:
        A PIL RGB image of the blended overlay.
    """
    # Class palette: 0=black, 1=red, 2=green, 3=blue.
    palette = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255))
    colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for cls, rgb in enumerate(palette):
        colored[mask == cls] = rgb
    # PIL sizes are (width, height); numpy shapes are (height, width).
    base = original_image.resize((mask.shape[1], mask.shape[0]))
    base_rgb = np.array(base.convert("RGB"))
    blended = (alpha * colored + (1 - alpha) * base_rgb).astype(np.uint8)
    return Image.fromarray(blended)
# Initialization function required for the interface handler
def predict(input_image, model_choice):
    """Run patch-based segmentation on an uploaded image.

    Args:
        input_image: PIL image from the Gradio Image component, or None.
        model_choice: key into the module-level `models` dict.

    Returns:
        (mask_image, overlay_image) as PIL images, or (None, None) when
        no image was supplied (a Gradio warning is shown instead).
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None, None
    chosen_model = models[model_choice]
    patch_size = 256
    # Break the image into 256x256 tiles (zero-padded at the edges).
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    # Classify every tile; argmax over the class dimension yields a
    # 2-D map of class indices per patch.
    predicted = []
    with torch.no_grad():
        for tile in patches:
            tile_tensor = process_patch(tile, device)
            logits = chosen_model(tile_tensor)
            predicted.append(torch.argmax(logits, dim=1).cpu().numpy()[0])
    # Reassemble tiles and crop back to the uploaded image's size.
    full_mask = stitch_patches(predicted, positions, padded_size, original_size)
    # Spread class indices {0..3} across gray levels {0, 63, 126, 189}
    # so the raw mask is visible to the eye.
    mask_image = Image.fromarray((full_mask * (255 // 4)).astype(np.uint8))
    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image
# --------------------
# Model Initialization
# --------------------
# Checkpoint locations for the three selectable weight sets.
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"
# Initialize models (assuming files exist)
# All three calls return the same device; only the last assignment of
# `device` is kept, which is fine because initialize_model always picks
# the same device on a given host.
try:
    w_noise_model, device = initialize_model(w_noise_model_path)
    wo_noise_model, device = initialize_model(wo_noise_model_path)
    w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
except FileNotFoundError as e:
    print(f"Warning: Model files not found. Using dummy initialization. Error: {e}")
    # Fallback dummy models for interface setup if files are missing
    # (randomly initialized weights — predictions will be meaningless,
    # but the UI can still be launched and exercised).
    device = torch.device("cpu")
    w_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    wo_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    w_noise_model_v2 = EnhancedUNet(n_channels=1, n_classes=4).to(device)
# Maps the dropdown labels used in the UI to loaded model instances;
# keys must match the gr.Dropdown choices below.
models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2
}
# --------------------
# Gradio UI (Blocks Structure for Guidelines)
# --------------------
with gr.Blocks(title="MoS2 Image Segmentation") as demo:
    gr.Markdown("<h1 style='text-align: center;'> MoS2 Micrograph Segmentation (U-Net Patch-Based) </h1>")
    gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.")
    # 1. Guidelines Accordion (collapsed by default to keep the UI compact)
    with gr.Accordion("Tips, Guidelines, and Model Selection", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(MODEL_GUIDANCE)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_INTERPRETATION)
    gr.Markdown("## Segmentation Input and Configuration")
    with gr.Row():
        # Input Column
        with gr.Column(scale=1):
            gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ")
            input_image = gr.Image(type="pil", label=" MoS2 Micrograph")
            gr.Markdown("## Step 2: Select Model Weights ")
            # Choices must match the keys of the module-level `models` dict.
            model_choice = gr.Dropdown(
                choices=["Without Noise", "With Noise", "With Noise V2"],
                value="With Noise V2",
                label=" Model Weights"
            )
            # FIX: corrected typo "Sugmentation" -> "Segmentation".
            gr.Markdown("## Step 3: Click Submit for Segmentation ")
            submit_button = gr.Button("Submit for Segmentation", variant="primary")
    gr.Markdown("## Segmentation Outputs")
    # Output Row.
    # FIX: labels previously said "Step 3"/"Step 4", clashing with the
    # "Step 3" submit button above; renamed to plain output labels.
    with gr.Row():
        output_mask = gr.Image(type="pil", label="Segmentation Mask (Grayscale)")
        output_overlay = gr.Image(type="pil", label="Segmentation Overlay (Color-Coded)")
    # Event Handler: run inference on click.
    submit_button.click(
        fn=predict,
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay]
    )
    # Examples Section (must come after the components it references)
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=[
            ["./examples/image_000003.png", "With Noise"],
            ["./examples/image_000005.png", "Without Noise"]
        ],
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay],
        fn=predict,
        cache_examples=False,
        label="Click to load and run a sample image with predefined model weights.",
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a container.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )