# Provenance (GitHub paste artifact, commented out so the module parses):
# author: muhammadhamza-stack — commit "add dockerfile" (d08406f)
import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import io
import math
# --- Documentation Strings ---
# Markdown bodies rendered inside the "Tips, Guidelines, and Model Selection"
# accordion of the Gradio UI. These are user-facing display text, not logic.

# Section 1: how to run the tool end-to-end.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask
This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask.
1. **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG).
2. **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences).
3. **Run**: Click the **"Submit"** button.
4. **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image).
"""
# Section 2: input requirements and the 256x256 patching scheme.
INPUT_EXPLANATION = """
## 2. Input Requirements
| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. |
| **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). |
### Technical Note: Patching
This application uses a patch-based approach:
1. The uploaded image is broken into non-overlapping **256x256 pixel patches**.
2. Each patch is analyzed individually by the U-Net.
3. The predicted patches are **stitched back together** to form the final segmentation map. This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs.
"""
# Section 3: which of the three weight files to pick.
MODEL_GUIDANCE = """
## 3. Model Selection Guidance (Without Noise vs. With Noise)
The application provides three distinct model weights, reflecting different training strategies:
| Model Option | Training Strategy | Recommended Use Case |
| :--- | :--- | :--- |
| **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. |
| **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. |
| **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. |
"""
# Section 4: how to read the grayscale mask and the RGB overlay.
OUTPUT_INTERPRETATION = """
## 4. Expected Outputs
The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure.
### A. Segmentation Mask (Grayscale)
This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity.
* Class 0 is represented by **Black**.
* Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**.
### B. Overlay (Colored)
This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha).
| Color | Underlying Class Index | Possible MoS2 Region |
| :--- | :--- | :--- |
| **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background |
| **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) |
| **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) |
| **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) |
"""
# --------------------
# Core Pipeline Functions (Kept AS IS)
# --------------------
def initialize_model(model_path):
    """Load EnhancedUNet weights from *model_path* and return (model, device).

    The model is placed on CUDA when available, otherwise on CPU, and is
    switched to eval mode ready for inference.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model, device
def extract_patches(image, patch_size=256):
    """Split *image* into non-overlapping grayscale patches.

    The image is zero-padded on the right/bottom so both dimensions become
    multiples of *patch_size*, then cropped into a row-major grid of patches.

    Returns:
        (patches, positions, padded_size, original_size) where positions are
        (left, top, right, bottom) boxes in padded-image coordinates.
    """
    orig_w, orig_h = image.size
    n_cols = math.ceil(orig_w / patch_size)
    n_rows = math.ceil(orig_h / patch_size)
    pad_w = n_cols * patch_size
    pad_h = n_rows * patch_size
    # Paste onto a black grayscale canvas; areas past the image edge stay 0.
    canvas = Image.new('L', (pad_w, pad_h), 0)
    canvas.paste(image, (0, 0))
    patches = []
    positions = []
    for row in range(n_rows):
        top = row * patch_size
        for col in range(n_cols):
            left = col * patch_size
            box = (left, top, left + patch_size, top + patch_size)
            patches.append(canvas.crop(box))
            positions.append(box)
    return patches, positions, (pad_w, pad_h), (orig_w, orig_h)
def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
    """Reassemble per-patch masks into one full mask, cropped to the original size.

    Args:
        patches: per-patch 2D class-index arrays.
        positions: matching (left, top, right, bottom) boxes in padded coords.
        padded_size: (width, height) of the padded canvas.
        original_size: (width, height) of the original image.

    Returns:
        uint8 array of shape (original_height, original_width).
    """
    pad_w, pad_h = padded_size
    canvas = np.zeros((pad_h, pad_w), dtype=np.uint8)
    for patch_mask, (left, top, right, bottom) in zip(patches, positions):
        canvas[top:bottom, left:right] = patch_mask
    orig_w, orig_h = original_size
    # Drop the zero padding added during patch extraction.
    return canvas[:orig_h, :orig_w]
def process_patch(patch, device):
    """Turn a PIL patch into a normalized float tensor on *device*.

    Returns a tensor of shape (1, 1, H, W) with values scaled to [0, 1].
    """
    # Force single-channel input regardless of the patch's current mode.
    gray = patch.convert("L")
    normalized = np.array(gray, dtype=np.float32) / 255.0
    # Prepend batch and channel axes expected by the U-Net.
    tensor = torch.from_numpy(normalized).float()[None, None, :, :]
    return tensor.to(device)
def create_overlay(original_image, mask, alpha=0.5):
    """Blend the original image with a color-coded class mask.

    Class indices 0-3 map to black, red, green, and blue respectively;
    *alpha* is the weight of the mask in the blend.
    """
    # Per-class RGB palette: Black, Red, Green, Blue.
    palette = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255))
    height, width = mask.shape
    colored = np.zeros((height, width, 3), dtype=np.uint8)
    for class_idx, rgb in enumerate(palette):
        colored[mask == class_idx] = rgb
    # Match the original image to the mask's dimensions before blending.
    resized = original_image.resize((width, height))
    base = np.array(resized.convert("RGB"))
    blended = (alpha * colored + (1 - alpha) * base).astype(np.uint8)
    return Image.fromarray(blended)
# Gradio event handler: full patch-based inference pipeline.
def predict(input_image, model_choice):
    """Segment an uploaded MoS2 micrograph with the selected model.

    Args:
        input_image: PIL image from the Gradio image component, or None.
        model_choice: Key into the module-level ``models`` dict.

    Returns:
        (mask_image, overlay_image) as PIL images, or (None, None) when no
        image was supplied (a Gradio warning is shown in that case).
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None, None
    model = models[model_choice]
    patch_size = 256
    n_classes = 4  # must match EnhancedUNet(n_classes=4)
    # Extract patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    # Classify each patch independently
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        # Inference only: no gradients needed
        with torch.no_grad():
            output = model(patch_tensor)
        # argmax over the class dimension -> per-pixel class indices (H, W)
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)
    # Stitch patches back together and crop to the original size
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
    # Scale class indices to the full grayscale range for visibility.
    # Fix: with class indices 0..n_classes-1, the factor is
    # 255 // (n_classes - 1) = 85 so the top class renders as white;
    # the previous 255 // 4 = 63 capped the display at 189.
    mask_image = Image.fromarray((full_mask * (255 // (n_classes - 1))).astype(np.uint8))
    # Color-coded overlay on top of the original image
    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image
# --------------------
# Model Initialization
# --------------------
# Paths to the three pretrained U-Net checkpoints (see MODEL_GUIDANCE for
# the differences between them).
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"
# Initialize models (assuming files exist)
try:
    w_noise_model, device = initialize_model(w_noise_model_path)
    wo_noise_model, device = initialize_model(wo_noise_model_path)
    w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
except FileNotFoundError as e:
    # NOTE(review): only FileNotFoundError is handled; a corrupt checkpoint or
    # an architecture mismatch inside torch.load/load_state_dict would still
    # crash at import time — confirm whether that is intended.
    print(f"Warning: Model files not found. Using dummy initialization. Error: {e}")
    # Fallback dummy models for interface setup if files are missing
    # (untrained weights, so predictions will be meaningless).
    device = torch.device("cpu")
    w_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    wo_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    w_noise_model_v2 = EnhancedUNet(n_channels=1, n_classes=4).to(device)
# Dropdown label -> loaded model; keys must match the Dropdown choices below.
models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2
}
# --------------------
# Gradio UI (Blocks Structure for Guidelines)
# --------------------
with gr.Blocks(title="MoS2 Image Segmentation") as demo:
    gr.Markdown("<h1 style='text-align: center;'> MoS2 Micrograph Segmentation (U-Net Patch-Based) </h1>")
    gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.")
    # 1. Guidelines Accordion (collapsed by default)
    with gr.Accordion("Tips, Guidelines, and Model Selection", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(MODEL_GUIDANCE)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_INTERPRETATION)
    gr.Markdown("## Segmentation Input and Configuration")
    with gr.Row():
        # Input Column
        with gr.Column(scale=1):
            gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ")
            input_image = gr.Image(type="pil", label=" MoS2 Micrograph")
            gr.Markdown("## Step 2: Select Model Weights ")
            model_choice = gr.Dropdown(
                choices=["Without Noise", "With Noise", "With Noise V2"],
                value="With Noise V2",
                label=" Model Weights"
            )
            # Fix: typo "Sugmentation" -> "Segmentation"
            gr.Markdown("## Step 3: Click Submit for Segmentation ")
            submit_button = gr.Button("Submit for Segmentation", variant="primary")
    gr.Markdown("## Segmentation Outputs")
    # Output Row. Fix: dropped the "Step 3:"/"Step 4:" label prefixes, which
    # clashed with the input column's "Step 3: Click Submit".
    with gr.Row():
        output_mask = gr.Image(type="pil", label="Segmentation Mask (Grayscale)")
        output_overlay = gr.Image(type="pil", label="Segmentation Overlay (Color-Coded)")
    # Event Handler: run the full pipeline on click
    submit_button.click(
        fn=predict,
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay]
    )
    # Examples Section (Must come after component definition)
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=[
            ["./examples/image_000003.png", "With Noise"],
            ["./examples/image_000005.png", "Without Noise"]
        ],
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay],
        fn=predict,
        cache_examples=False,
        label="Click to load and run a sample image with predefined model weights.",
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the host
    # (presumably for the Docker deployment — confirm); standard Gradio port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )