muhammadhamza-stack committed on
Commit
4e5d881
·
1 Parent(s): dc83630

refine the gradio app

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py CHANGED
@@ -6,6 +6,69 @@ from PIL import Image
6
  import io
7
  import math
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def initialize_model(model_path):
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
@@ -64,7 +127,8 @@ def process_patch(patch, device):
64
  return patch_tensor.to(device)
65
 
66
  def create_overlay(original_image, mask, alpha=0.5):
67
- colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)] # Define colors for each class
 
68
  mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
69
  for i, color in enumerate(colors):
70
  mask_rgb[mask == i] = color
@@ -77,8 +141,10 @@ def create_overlay(original_image, mask, alpha=0.5):
77
  overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
78
  return Image.fromarray(overlay)
79
 
 
80
  def predict(input_image, model_choice):
81
  if input_image is None:
 
82
  return None, None
83
 
84
  model = models[model_choice]
@@ -105,21 +171,35 @@ def predict(input_image, model_choice):
105
  full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
106
 
107
  # Create mask image
108
- mask_image = Image.fromarray((full_mask * 63).astype(np.uint8)) # Scale for better visibility
 
109
 
110
  # Create overlay image
111
  overlay_image = create_overlay(input_image, full_mask)
112
 
113
  return mask_image, overlay_image
114
 
115
- # Initialize model (do this outside the inference function for better performance)
 
 
 
116
  w_noise_model_path = "./models/best_model_w_noise.pth"
117
  wo_noise_model_path = "./models/best_model_wo_noise.pth"
118
  w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"
119
 
120
- w_noise_model, device = initialize_model(w_noise_model_path)
121
- wo_noise_model, device = initialize_model(wo_noise_model_path)
122
- w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
 
 
 
 
 
 
 
 
 
 
123
 
124
  models = {
125
  "Without Noise": wo_noise_model,
@@ -127,21 +207,69 @@ models = {
127
  "With Noise V2": w_noise_model_v2
128
  }
129
 
130
- # Create Gradio interface
131
- iface = gr.Interface(
132
- fn=predict,
133
- inputs=[
134
- gr.Image(type="pil"),
135
- gr.Dropdown(choices=["Without Noise", "With Noise", "With Noise V2"], value="With Noise V2"),
136
- ],
137
- outputs=[
138
- gr.Image(type="pil", label="Segmentation Mask"),
139
- gr.Image(type="pil", label="Overlay"),
140
- ],
141
- title="MoS2 Image Segmentation",
142
- description="Upload an image to get the segmentation mask and overlay visualization.",
143
- examples=[["./examples/image_000003.png", "With Noise"], ["./examples/image_000005.png", "Without Noise"]],
144
- )
145
-
146
- # Launch the interface
147
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import io
7
  import math
8
 
9
+ # --- Documentation Strings ---
10
+
11
+ USAGE_GUIDELINES = """
12
+ ## 1. Quick Start Guide: Generating a Segmentation Mask
13
+ This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask.
14
+
15
+ 1. **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG).
16
+ 2. **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences).
17
+ 3. **Run**: Click the **"Submit"** button.
18
+ 4. **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image).
19
+ """
20
+
21
+ INPUT_EXPLANATION = """
22
+ ## 2. Input Requirements
23
+
24
+ | Input Field | Purpose | Requirement |
25
+ | :--- | :--- | :--- |
26
+ | **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. |
27
+ | **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). |
28
+
29
+ ### Technical Note: Patching
30
+ This application uses a patch-based approach:
31
+ 1. The uploaded image is broken into non-overlapping **256x256 pixel patches**.
32
+ 2. Each patch is analyzed individually by the U-Net.
33
+ 3. The predicted patches are **stitched back together** to form the final segmentation map. This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs.
34
+ """
35
+
36
+ MODEL_GUIDANCE = """
37
+ ## 3. Model Selection Guidance (Without Noise vs. With Noise)
38
+
39
+ The application provides three distinct model weights, reflecting different training strategies:
40
+
41
+ | Model Option | Training Strategy | Recommended Use Case |
42
+ | :--- | :--- | :--- |
43
+ | **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. |
44
+ | **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. |
45
+ | **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. |
46
+ """
47
+
48
+ OUTPUT_INTERPRETATION = """
49
+ ## 4. Expected Outputs
50
+
51
+ The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure.
52
+
53
+ ### A. Segmentation Mask (Grayscale)
54
+ This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity.
55
+ * Class 0 is represented by **Black**.
56
+ * Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**.
57
+
58
+ ### B. Overlay (Colored)
59
+ This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha).
60
+
61
+ | Color | Underlying Class Index | Possible MoS2 Region |
62
+ | :--- | :--- | :--- |
63
+ | **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background |
64
+ | **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) |
65
+ | **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) |
66
+ | **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) |
67
+ """
68
+ # --------------------
69
+ # Core Pipeline Functions (Kept AS IS)
70
+ # --------------------
71
+
72
  def initialize_model(model_path):
73
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
  model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
 
127
  return patch_tensor.to(device)
128
 
129
  def create_overlay(original_image, mask, alpha=0.5):
130
+ # Define colors for the 4 classes: Black, Red, Green, Blue
131
+ colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
132
  mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
133
  for i, color in enumerate(colors):
134
  mask_rgb[mask == i] = color
 
141
  overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
142
  return Image.fromarray(overlay)
143
 
144
+ # Initialization function required for the interface handler
145
  def predict(input_image, model_choice):
146
  if input_image is None:
147
+ gr.Warning("Please upload an image or select an example.")
148
  return None, None
149
 
150
  model = models[model_choice]
 
171
  full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
172
 
173
  # Create mask image
174
+ # Scale for better visibility (255 / 4 classes * class_index)
175
+ mask_image = Image.fromarray((full_mask * (255 // 4)).astype(np.uint8))
176
 
177
  # Create overlay image
178
  overlay_image = create_overlay(input_image, full_mask)
179
 
180
  return mask_image, overlay_image
181
 
182
+ # --------------------
183
+ # Model Initialization
184
+ # --------------------
185
+
186
  w_noise_model_path = "./models/best_model_w_noise.pth"
187
  wo_noise_model_path = "./models/best_model_wo_noise.pth"
188
  w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"
189
 
190
+ # Initialize models (assuming files exist)
191
+ try:
192
+ w_noise_model, device = initialize_model(w_noise_model_path)
193
+ wo_noise_model, device = initialize_model(wo_noise_model_path)
194
+ w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
195
+ except FileNotFoundError as e:
196
+ print(f"Warning: Model files not found. Using dummy initialization. Error: {e}")
197
+ # Fallback dummy models for interface setup if files are missing
198
+ device = torch.device("cpu")
199
+ w_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
200
+ wo_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
201
+ w_noise_model_v2 = EnhancedUNet(n_channels=1, n_classes=4).to(device)
202
+
203
 
204
  models = {
205
  "Without Noise": wo_noise_model,
 
207
  "With Noise V2": w_noise_model_v2
208
  }
209
 
210
+ # --------------------
211
+ # Gradio UI (Blocks Structure for Guidelines)
212
+ # --------------------
213
+
214
+ with gr.Blocks(title="MoS2 Image Segmentation") as demo:
215
+ gr.Markdown("<h1 style='text-align: center;'> MoS2 Micrograph Segmentation (U-Net Patch-Based) </h1>")
216
+ gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.")
217
+
218
+ # 1. Guidelines Accordion
219
+ with gr.Accordion("Tips, Guidelines, and Model Selection", open=False):
220
+ gr.Markdown(USAGE_GUIDELINES)
221
+ gr.Markdown("---")
222
+ gr.Markdown(INPUT_EXPLANATION)
223
+ gr.Markdown("---")
224
+ gr.Markdown(MODEL_GUIDANCE)
225
+ gr.Markdown("---")
226
+ gr.Markdown(OUTPUT_INTERPRETATION)
227
+
228
+ gr.Markdown("## Segmentation Input and Configuration")
229
+
230
+ with gr.Row():
231
+ # Input Column
232
+ with gr.Column(scale=1):
233
+ gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ")
234
+ input_image = gr.Image(type="pil", label=" MoS2 Micrograph")
235
+ gr.Markdown("## Step 2: Select Model Weights ")
236
+ model_choice = gr.Dropdown(
237
+ choices=["Without Noise", "With Noise", "With Noise V2"],
238
+ value="With Noise V2",
239
+ label=" Model Weights"
240
+ )
241
+ gr.Markdown("## Step 3: Click Submit for Segmentation ")
242
+ submit_button = gr.Button("Submit for Segmentation", variant="primary")
243
+
244
+ gr.Markdown("## Segmentation Outputs")
245
+
246
+ # Output Row
247
+ with gr.Row():
248
+ output_mask = gr.Image(type="pil", label="Step 3: Segmentation Mask (Grayscale)")
249
+ output_overlay = gr.Image(type="pil", label="Step 4: Segmentation Overlay (Color-Coded)")
250
+
251
+ # Event Handler
252
+ submit_button.click(
253
+ fn=predict,
254
+ inputs=[input_image, model_choice],
255
+ outputs=[output_mask, output_overlay]
256
+ )
257
+
258
+ # Examples Section (Must come after component definition)
259
+ gr.Markdown("---")
260
+ gr.Markdown("## Example Images")
261
+ gr.Examples(
262
+ examples=[
263
+ ["./examples/image_000003.png", "With Noise"],
264
+ ["./examples/image_000005.png", "Without Noise"]
265
+ ],
266
+ inputs=[input_image, model_choice],
267
+ outputs=[output_mask, output_overlay],
268
+ fn=predict,
269
+ cache_examples=False,
270
+ label="Click to load and run a sample image with predefined model weights.",
271
+ )
272
+
273
+
274
+ if __name__ == "__main__":
275
+ demo.launch()
examples/image_000003.png CHANGED

Git LFS Details

  • SHA256: c066c42924bc8799f454b4ef68f6c01bf37bb05a5667015517fd3b158b5102ac
  • Pointer size: 130 Bytes
  • Size of remote file: 57.9 kB
examples/image_000004.png CHANGED

Git LFS Details

  • SHA256: bd76da3648aa3ac9c7d496a265b1699d5f66946aa109b60b7a7aa575362ff026
  • Pointer size: 130 Bytes
  • Size of remote file: 46.9 kB
examples/image_000005.png CHANGED

Git LFS Details

  • SHA256: bd7924d10ebec55fe955d46e1ccfe2ed5282bb8b5e4224eff41667d12e83df3c
  • Pointer size: 130 Bytes
  • Size of remote file: 58.9 kB
examples/image_000006.png CHANGED

Git LFS Details

  • SHA256: 4809730bc8b645d7ccbfecbaaa70292971b03b29b2b53c425e7bdbe33b2c1114
  • Pointer size: 130 Bytes
  • Size of remote file: 32 kB
examples/image_000007.png CHANGED

Git LFS Details

  • SHA256: 6b8964b3377d62c10e61c20aa70c34d0e66f64dd28938f0240b22b2021b0a053
  • Pointer size: 130 Bytes
  • Size of remote file: 40.8 kB
examples/image_000029.png CHANGED

Git LFS Details

  • SHA256: 804ac0f1d0e5553dd4fcda10ff49f8281369e0418c17ff599eb76f8a27ef879f
  • Pointer size: 130 Bytes
  • Size of remote file: 27.9 kB
examples/image_000030.png CHANGED

Git LFS Details

  • SHA256: 4d1a71b8b5fa7ca82bd18791dffa77c4115ba2ebe4698252230ac8ad53fbb6e1
  • Pointer size: 130 Bytes
  • Size of remote file: 27.6 kB
examples/image_000031.png CHANGED

Git LFS Details

  • SHA256: 51f33e16b6c2e339a3f0576dc3d12c7c79af88733388223086643cad046ccde2
  • Pointer size: 130 Bytes
  • Size of remote file: 38.1 kB
examples/image_000032.png CHANGED

Git LFS Details

  • SHA256: 57f37b5985ab78b97e7898d552ce80ea1c0d1ee40be12a3347fde5be724a4a27
  • Pointer size: 130 Bytes
  • Size of remote file: 25.5 kB
examples/image_000033.png CHANGED

Git LFS Details

  • SHA256: 5c899f21221c6b5ca99a08df91e6c5facc25a4e6ac8e8bd85c306a6ee8126a54
  • Pointer size: 130 Bytes
  • Size of remote file: 38.2 kB
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
- gradio
3
  pillow
 
 
 
1
  torch
 
2
  pillow
3
+ gradio==3.50.2
4
+ gradio-client==0.6.1