import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import io
import math

# --- Documentation Strings ---
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask

This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask.

1. **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG).
2. **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences).
3. **Run**: Click the **"Submit"** button.
4. **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image).
"""

INPUT_EXPLANATION = """
## 2. Input Requirements

| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. |
| **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). |

### Technical Note: Patching

This application uses a patch-based approach:

1. The uploaded image is broken into non-overlapping **256x256 pixel patches**.
2. Each patch is analyzed individually by the U-Net.
3. The predicted patches are **stitched back together** to form the final segmentation map.

This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs.
"""

MODEL_GUIDANCE = """
## 3. Model Selection Guidance (Without Noise vs. With Noise)

The application provides three distinct model weights, reflecting different training strategies:

| Model Option | Training Strategy | Recommended Use Case |
| :--- | :--- | :--- |
| **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. |
| **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. |
| **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. |
"""

OUTPUT_INTERPRETATION = """
## 4. Expected Outputs

The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure.

### A. Segmentation Mask (Grayscale)

This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity.

* Class 0 is represented by **Black**.
* Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**.

### B. Overlay (Colored)

This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha).

| Color | Underlying Class Index | Possible MoS2 Region |
| :--- | :--- | :--- |
| **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background |
| **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) |
| **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) |
| **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) |
"""

# --------------------
# Core Pipeline Functions
# --------------------

def initialize_model(model_path):
    """Load EnhancedUNet weights from ``model_path``.

    Returns:
        (model, device): the model in eval mode on the selected device
        (CUDA when available, otherwise CPU).

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; on PyTorch >= 1.13
    # prefer torch.load(..., weights_only=True) for untrusted checkpoint files.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device


def extract_patches(image, patch_size=256):
    """Split ``image`` into non-overlapping ``patch_size`` square patches.

    The image is zero-padded (black) on the right/bottom so both dimensions
    become exact multiples of ``patch_size``.

    Args:
        image: input PIL image; pasted onto a grayscale ('L') canvas.
        patch_size: side length of the square patches in pixels.

    Returns:
        (patches, positions, padded_size, original_size) where ``positions``
        holds (left, top, right, bottom) boxes into the padded canvas and
        both sizes are (width, height) tuples.
    """
    width, height = image.size
    patches = []
    positions = []

    # Number of patches needed to cover each dimension (round up).
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)

    # Pad with black so partial edge patches become full-sized.
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new('L', (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))

    # Extract patches row-major (top-to-bottom, left-to-right).
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patches.append(padded_image.crop((left, top, right, bottom)))
            positions.append((left, top, right, bottom))

    return patches, positions, (padded_width, padded_height), (width, height)


def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
    """Reassemble per-patch class masks into one full-image mask.

    Args:
        patches: iterable of 2-D arrays of per-pixel class indices.
        positions: matching (left, top, right, bottom) boxes produced by
            ``extract_patches``.
        padded_size: (width, height) of the padded canvas.
        original_size: (width, height) to crop the result back to.
        n_classes: unused; kept for backward compatibility with callers.

    Returns:
        2-D uint8 array of shape (original_height, original_width).
    """
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch

    # Drop the zero padding that was added before patch extraction.
    return full_mask[:original_size[1], :original_size[0]]


def process_patch(patch, device):
    """Convert a PIL patch into a normalized (1, 1, H, W) tensor on ``device``."""
    # Ensure single-channel input (model expects n_channels=1).
    patch_gray = patch.convert("L")
    # Scale pixel values from [0, 255] to [0.0, 1.0].
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)


def create_overlay(original_image, mask, alpha=0.5):
    """Blend the original image with a color-coded class mask.

    Args:
        original_image: input PIL image (resized to the mask's size).
        mask: 2-D array of class indices in {0, 1, 2, 3}.
        alpha: weight of the color mask in the blend (0..1).

    Returns:
        PIL RGB image of the blended overlay.
    """
    # Class colors: 0=black, 1=red, 2=green, 3=blue.
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color

    # Match the original image to the mask's spatial size (mask.shape is
    # (height, width); PIL resize takes (width, height)).
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))

    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)


def predict(input_image, model_choice):
    """Gradio handler: segment ``input_image`` with the selected model.

    Relies on the module-level ``models`` dict and ``device`` set up at
    import time.

    Returns:
        (mask_image, overlay_image) as PIL images, or (None, None) when no
        image was provided.
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Break the image into model-sized patches.
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Run inference patch by patch.
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            output = model(patch_tensor)
        # Per-pixel class index for this patch: argmax over the class dim.
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)

    # Stitch patches back together.
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Map class indices 0..3 to grayscale. FIX: scale by 255 // (n_classes - 1)
    # = 85 so the highest class renders as white; the previous 255 // 4 = 63
    # capped the brightest class at 189 and never used the full range.
    mask_image = Image.fromarray((full_mask * (255 // 3)).astype(np.uint8))

    # Create the color-coded overlay image.
    overlay_image = create_overlay(input_image, full_mask)

    return mask_image, overlay_image

# --------------------
# Model Initialization
# --------------------
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

# Initialize models; fall back to untrained weights so the UI can still be
# built (and inspected) when the checkpoint files are missing.
try:
    w_noise_model, device = initialize_model(w_noise_model_path)
    wo_noise_model, device = initialize_model(wo_noise_model_path)
    w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
except FileNotFoundError as e:
    print(f"Warning: Model files not found. Using dummy initialization. Error: {e}")
    # Fallback dummy models for interface setup if files are missing.
    device = torch.device("cpu")
    w_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    wo_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    w_noise_model_v2 = EnhancedUNet(n_channels=1, n_classes=4).to(device)

# Dropdown label -> model instance; consumed by predict().
models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2,
}

# --------------------
# Gradio UI (Blocks Structure for Guidelines)
# --------------------
with gr.Blocks(title="MoS2 Image Segmentation") as demo:
    # FIX: the title string was split across physical lines (unterminated
    # string literal, likely a stripped HTML heading); reconstructed as a
    # single centered heading.
    gr.Markdown("<h1 style='text-align: center;'>MoS2 Micrograph Segmentation (U-Net Patch-Based)</h1>")
    gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.")

    # 1. Guidelines Accordion
    with gr.Accordion("Tips, Guidelines, and Model Selection", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(MODEL_GUIDANCE)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_INTERPRETATION)

    gr.Markdown("## Segmentation Input and Configuration")

    with gr.Row():
        # Input Column
        with gr.Column(scale=1):
            gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ")
            input_image = gr.Image(type="pil", label=" MoS2 Micrograph")
            gr.Markdown("## Step 2: Select Model Weights ")
            model_choice = gr.Dropdown(
                choices=["Without Noise", "With Noise", "With Noise V2"],
                value="With Noise V2",
                label=" Model Weights"
            )
            # FIX: typo "Sugmentation" -> "Segmentation".
            gr.Markdown("## Step 3: Click Submit for Segmentation ")
            submit_button = gr.Button("Submit for Segmentation", variant="primary")

    gr.Markdown("## Segmentation Outputs")

    # Output Row. FIX: labels previously said "Step 3"/"Step 4", conflicting
    # with the 3-step flow above (Step 3 is the submit button).
    with gr.Row():
        output_mask = gr.Image(type="pil", label="Segmentation Mask (Grayscale)")
        output_overlay = gr.Image(type="pil", label="Segmentation Overlay (Color-Coded)")

    # Event Handler
    submit_button.click(
        fn=predict,
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay]
    )

    # Examples Section (must come after component definition).
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=[
            ["./examples/image_000003.png", "With Noise"],
            ["./examples/image_000005.png", "Without Noise"]
        ],
        inputs=[input_image, model_choice],
        outputs=[output_mask, output_overlay],
        fn=predict,
        cache_examples=False,
        label="Click to load and run a sample image with predefined model weights.",
    )

if __name__ == "__main__":
    # Bind to all interfaces (container/remote access) on the default port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )