File size: 11,444 Bytes
efd5df3
 
 
 
 
 
 
 
4e5d881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efd5df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e5d881
 
efd5df3
 
 
 
 
 
 
 
 
 
 
 
4e5d881
efd5df3
 
4e5d881
efd5df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e5d881
 
efd5df3
 
 
 
 
 
4e5d881
 
 
 
efd5df3
 
c895613
efd5df3
4e5d881
 
 
 
 
 
 
 
 
 
 
 
 
efd5df3
 
 
dc83630
c895613
efd5df3
 
4e5d881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d08406f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import io
import math

# --- Documentation Strings ---

# Markdown shown in the "Tips, Guidelines, and Model Selection" accordion.
# NOTE: these are runtime UI strings rendered by gr.Markdown — edit with care.

# Section 1: step-by-step usage instructions.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask
This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask.

1.  **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG).
2.  **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences).
3.  **Run**: Click the **"Submit"** button.
4.  **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image).
"""

# Section 2: accepted inputs and the 256x256 patching strategy.
INPUT_EXPLANATION = """
## 2. Input Requirements 

| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. |
| **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). |

### Technical Note: Patching
This application uses a patch-based approach:
1.  The uploaded image is broken into non-overlapping **256x256 pixel patches**.
2.  Each patch is analyzed individually by the U-Net.
3.  The predicted patches are **stitched back together** to form the final segmentation map. This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs.
"""

# Section 3: which of the three weight sets to pick.
MODEL_GUIDANCE = """
## 3. Model Selection Guidance (Without Noise vs. With Noise)

The application provides three distinct model weights, reflecting different training strategies:

| Model Option | Training Strategy | Recommended Use Case |
| :--- | :--- | :--- |
| **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. |
| **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. |
| **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. |
"""

# Section 4: how to read the grayscale mask and the color overlay outputs.
OUTPUT_INTERPRETATION = """
## 4. Expected Outputs 

The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure.

### A. Segmentation Mask (Grayscale)
This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity.
*   Class 0 is represented by **Black**.
*   Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**.

### B. Overlay (Colored)
This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha).

| Color | Underlying Class Index | Possible MoS2 Region |
| :--- | :--- | :--- |
| **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background |
| **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) |
| **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) |
| **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) |
"""
# --------------------
# Core Pipeline Functions (Kept AS IS)
# --------------------

def initialize_model(model_path):
    """Load EnhancedUNet weights from `model_path` and return (model, device).

    Picks CUDA when available, otherwise CPU. The checkpoint is expected to
    be a plain state_dict produced by torch.save (as load_state_dict implies).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    # weights_only=True restricts unpickling to tensor data, preventing
    # arbitrary-code execution from a tampered checkpoint file; safe here
    # because a pure state_dict contains only tensors.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Inference-only: switch off dropout / batch-norm training behavior.
    model.eval()
    return model, device

def extract_patches(image, patch_size=256):
    """Split a PIL image into non-overlapping patch_size x patch_size tiles.

    The image is pasted onto a black grayscale canvas rounded up to a whole
    number of tiles, so edge tiles are zero-padded rather than dropped.

    Returns:
        patches: list of PIL patch images, row-major order.
        positions: matching (left, top, right, bottom) boxes on the canvas.
        padded_size: (width, height) of the padded canvas.
        original_size: (width, height) of the input image.
    """
    width, height = image.size

    # Round the tile grid up so partial edge regions are still covered.
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)
    padded_width, padded_height = n_cols * patch_size, n_rows * patch_size

    # Black-padded single-channel canvas; original goes in the top-left.
    canvas = Image.new('L', (padded_width, padded_height), 0)
    canvas.paste(image, (0, 0))

    patches, positions = [], []
    for row in range(n_rows):
        for col in range(n_cols):
            box = (
                col * patch_size,
                row * patch_size,
                col * patch_size + patch_size,
                row * patch_size + patch_size,
            )
            patches.append(canvas.crop(box))
            positions.append(box)

    return patches, positions, (padded_width, padded_height), (width, height)

def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
    """Reassemble per-patch class maps into one full-size mask.

    Args:
        patches: iterable of 2-D arrays of class indices, one per position.
        positions: matching (left, top, right, bottom) pixel boxes.
        padded_size: (width, height) of the padded canvas the patches tile.
        original_size: (width, height) to crop the result back down to.
        n_classes: unused here; kept for interface compatibility.

    Returns:
        2-D uint8 numpy array of shape (original_height, original_width).
    """
    canvas = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for tile, box in zip(patches, positions):
        left, top, right, bottom = box
        canvas[top:bottom, left:right] = tile
    # Discard the zero padding added to make the image tile evenly.
    return canvas[:original_size[1], :original_size[0]]

def process_patch(patch, device):
    """Convert a PIL patch to a normalized (1, 1, H, W) float32 tensor on device."""
    # Force single-channel grayscale to match the model's n_channels=1 input.
    gray = np.array(patch.convert("L"), dtype=np.float32)
    # Scale pixel values from [0, 255] into [0, 1].
    tensor = torch.from_numpy(gray / 255.0).float()
    # Prepend batch and channel axes: (H, W) -> (1, 1, H, W).
    return tensor[None, None, :, :].to(device)

def create_overlay(original_image, mask, alpha=0.5):
    """Blend the original image with a color-coded class mask.

    Args:
        original_image: PIL image; resized to the mask's dimensions.
        mask: 2-D array of class indices 0-3.
        alpha: mask weight in the blend (0 = image only, 1 = mask only).

    Returns:
        PIL RGB image of the blended overlay.
    """
    # Class index -> RGB: 0 black, 1 red, 2 green, 3 blue.
    palette = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255))
    colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_idx, rgb in enumerate(palette):
        colored[mask == class_idx] = rgb

    # Match the image to the mask's (width, height) before blending.
    resized = original_image.resize((mask.shape[1], mask.shape[0]))
    base = np.array(resized.convert("RGB"))

    blended = (alpha * colored + (1 - alpha) * base).astype(np.uint8)
    return Image.fromarray(blended)

# Gradio event handler: runs the full patch-based inference pipeline.
def predict(input_image, model_choice):
    """Segment an uploaded micrograph with the selected model.

    Args:
        input_image: PIL image from the Gradio image component, or None.
        model_choice: key into the module-level `models` dict.

    Returns:
        (mask_image, overlay_image) as PIL images, or (None, None) when no
        image was supplied (a Gradio warning is shown in that case).
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Tile the image into model-sized inputs.
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Run inference patch by patch, keeping the per-pixel argmax class map.
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            logits = model(patch_tensor)
        predicted_patches.append(torch.argmax(logits, dim=1).cpu().numpy()[0])

    # Reassemble per-patch maps into one mask at the original resolution.
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Spread class indices 0-3 over grayscale (step 255 // 4 == 63) so the
    # regions are distinguishable by eye in the raw mask output.
    mask_image = Image.fromarray((full_mask * (255 // 4)).astype(np.uint8))

    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image

# --------------------
# Model Initialization
# --------------------

# Checkpoint locations for the three published weight sets.
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

try:
    # Load each checkpoint; all three resolve to the same device.
    w_noise_model, device = initialize_model(w_noise_model_path)
    wo_noise_model, device = initialize_model(wo_noise_model_path)
    w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
except FileNotFoundError as e:
    print(f"Warning: Model files not found. Using dummy initialization. Error: {e}")
    # Keep the UI renderable without weights: untrained stand-ins on CPU.
    device = torch.device("cpu")

    def _dummy_unet():
        # Fresh randomly-initialized network matching the real architecture.
        return EnhancedUNet(n_channels=1, n_classes=4).to(device)

    w_noise_model = _dummy_unet()
    wo_noise_model = _dummy_unet()
    w_noise_model_v2 = _dummy_unet()


# Dropdown label -> loaded model, consumed by predict().
models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2
}

# --------------------
# Gradio UI (Blocks Structure for Guidelines)
# --------------------

with gr.Blocks(title="MoS2 Image Segmentation") as demo:
    gr.Markdown("<h1 style='text-align: center;'> MoS2 Micrograph Segmentation (U-Net Patch-Based) </h1>")
    gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.")

    # 1. Guidelines Accordion (collapsed by default).
    with gr.Accordion("Tips, Guidelines, and Model Selection", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(MODEL_GUIDANCE)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_INTERPRETATION)

    gr.Markdown("## Segmentation Input and Configuration")

    with gr.Row():
        # Input Column: image upload, model choice, submit.
        with gr.Column(scale=1):
            gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ")
            input_image = gr.Image(type="pil", label=" MoS2 Micrograph")
            gr.Markdown("## Step 2: Select Model Weights ")
            model_choice = gr.Dropdown(
                choices=["Without Noise", "With Noise", "With Noise V2"], 
                value="With Noise V2", 
                label=" Model Weights"
            )
            # Fixed typo: "Sugmentation" -> "Segmentation".
            gr.Markdown("## Step 3: Click Submit for Segmentation ")
            submit_button = gr.Button("Submit for Segmentation", variant="primary")

    gr.Markdown("## Segmentation Outputs")

    # Output Row. Labels no longer carry "Step 3"/"Step 4" prefixes, which
    # clashed with the input column's numbering (Step 3 is the Submit button).
    with gr.Row():
        output_mask = gr.Image(type="pil", label="Segmentation Mask (Grayscale)")
        output_overlay = gr.Image(type="pil", label="Segmentation Overlay (Color-Coded)")

    # Event Handler: predict() returns (mask, overlay).
    submit_button.click(
        fn=predict,
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay]
    )

    # Examples Section (must come after the components it references).
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=[
            ["./examples/image_000003.png", "With Noise"], 
            ["./examples/image_000005.png", "Without Noise"]
        ],
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay],
        fn=predict,
        cache_examples=False,
        label="Click to load and run a sample image with predefined model weights.",
    )


if __name__ == "__main__":
    # Bind to all interfaces on Gradio's standard port (container-friendly).
    demo.launch(server_name="0.0.0.0", server_port=7860)