# NOTE: "Spaces: Paused" below is Hugging Face Spaces page chrome captured
# during extraction, not part of the source — kept here as comments only.
# Spaces: Paused / Paused
| import gradio as gr | |
| import torch | |
| from unet import EnhancedUNet | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import math | |
| # --- Documentation Strings --- | |
| USAGE_GUIDELINES = """ | |
| ## 1. Quick Start Guide: Generating a Segmentation Mask | |
| This tool analyzes your uploaded MoS2 image, breaking it down into small patches, classifying those patches using a U-Net model, and stitching the results back into a full segmentation mask. | |
| 1. **Upload Image**: Click the image box and upload your MoS2 micrograph (PNG or JPG). | |
| 2. **Select Model**: Choose the appropriate model weight from the dropdown (see Section 3 for differences). | |
| 3. **Run**: Click the **"Submit"** button. | |
| 4. **Review**: Two outputs will appear: the raw grayscale **Segmentation Mask** and the color **Overlay** (which combines the mask with the original image). | |
| """ | |
| INPUT_EXPLANATION = """ | |
| ## 2. Input Requirements | |
| | Input Field | Purpose | Requirement | | |
| | :--- | :--- | :--- | | |
| | **Input Image** | The MoS2 micrograph to be segmented. | Must be a single image file (JPG, PNG). The system automatically converts the image to **grayscale (1 channel)** before processing. | | |
| | **Model Choice** | Selects the specific set of U-Net weights to use for inference. | Required choice among the three available options (see Model Guide below). | | |
| ### Technical Note: Patching | |
| This application uses a patch-based approach: | |
| 1. The uploaded image is broken into non-overlapping **256x256 pixel patches**. | |
| 2. Each patch is analyzed individually by the U-Net. | |
| 3. The predicted patches are **stitched back together** to form the final segmentation map. This technique allows high-resolution images to be processed efficiently by a model trained on smaller inputs. | |
| """ | |
| MODEL_GUIDANCE = """ | |
| ## 3. Model Selection Guidance (Without Noise vs. With Noise) | |
| The application provides three distinct model weights, reflecting different training strategies: | |
| | Model Option | Training Strategy | Recommended Use Case | | |
| | :--- | :--- | :--- | | |
| | **Without Noise** | Trained on clean, standard dataset images. | Use for high-quality, clear micrographs. Expect highly precise boundaries where the data matches the training set. | | |
| | **With Noise** | Trained with artificial noise augmentation (e.g., Gaussian, Salt-and-Pepper). | Use for real-world images that may contain artifacts, varying light, or complex background interference. Provides better **generalization** and robustness. | | |
| | **With Noise V2** | An updated version of the 'With Noise' model, potentially offering improved boundary definition or accuracy. | Recommended as the default choice for robust, high-performance segmentation across varied image quality. | | |
| """ | |
| OUTPUT_INTERPRETATION = """ | |
| ## 4. Expected Outputs | |
| The output provides two results: the raw segmentation mask and a visual overlay. The model classifies every pixel into one of **4 distinct classes (0-3)**, likely corresponding to different layers or regions of the MoS2 structure. | |
| ### A. Segmentation Mask (Grayscale) | |
| This image shows the raw classification output. The class index (0, 1, 2, or 3) is mapped to a grayscale intensity. | |
| * Class 0 is represented by **Black**. | |
| * Higher classes (1, 2, 3) are represented by progressively **lighter shades of gray**. | |
| ### B. Overlay (Colored) | |
| This is the most straightforward visual output, blending the original image with the color-coded mask using a default transparency (alpha). | |
| | Color | Underlying Class Index | Possible MoS2 Region | | |
| | :--- | :--- | :--- | | |
| | **Black** (0, 0, 0) | Class 0 | Unlabeled Region / Background | | |
| | **Red** (255, 0, 0) | Class 1 | Region A (e.g., Monolayer) | | |
| | **Green** (0, 255, 0) | Class 2 | Region B (e.g., Bilayer) | | |
| | **Blue** (0, 0, 255) | Class 3 | Region C (e.g., Bulk/Debris) | | |
| """ | |
| # -------------------- | |
| # Core Pipeline Functions (Kept AS IS) | |
| # -------------------- | |
def initialize_model(model_path):
    """Build a 1-channel / 4-class EnhancedUNet, load weights, set eval mode.

    Returns the ``(model, device)`` pair; device is CUDA when available,
    otherwise CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state)
    net.eval()  # inference only: disable dropout / batch-norm updates
    return net, device
def extract_patches(image, patch_size=256):
    """Split *image* into non-overlapping patch_size x patch_size tiles.

    The image is pasted onto a black grayscale canvas whose sides are
    rounded up to multiples of *patch_size*, so edge tiles are zero-padded.

    Returns ``(patches, positions, padded_size, original_size)`` where
    positions are ``(left, top, right, bottom)`` boxes in padded coordinates.
    """
    width, height = image.size
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size

    # Zero-padded grayscale canvas; paste converts *image* to 'L' mode.
    canvas = Image.new('L', (padded_width, padded_height), 0)
    canvas.paste(image, (0, 0))

    patches = []
    positions = []
    for row in range(n_rows):
        for col in range(n_cols):
            box = (
                col * patch_size,
                row * patch_size,
                col * patch_size + patch_size,
                row * patch_size + patch_size,
            )
            patches.append(canvas.crop(box))
            positions.append(box)
    return patches, positions, (padded_width, padded_height), (width, height)
def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
    """Reassemble per-patch class masks into a single full-size mask.

    Each patch is written into its ``(left, top, right, bottom)`` box on a
    padded canvas, which is then cropped back to ``original_size``.
    Sizes are ``(width, height)`` pairs; the returned array is ``(H, W)``.

    NOTE(review): *n_classes* is accepted for interface compatibility but
    is not used by this function.
    """
    canvas = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for tile, (left, top, right, bottom) in zip(patches, positions):
        canvas[top:bottom, left:right] = tile
    # Drop the zero padding introduced during patch extraction.
    return canvas[:original_size[1], :original_size[0]]
def process_patch(patch, device):
    """Turn a PIL patch into a normalized (1, 1, H, W) float32 tensor on *device*.

    Pixel values are scaled from [0, 255] to [0.0, 1.0].
    """
    gray = patch.convert("L")  # ensure a single channel
    arr = np.array(gray, dtype=np.float32) / 255.0
    # (H, W) -> (1, 1, H, W): add batch and channel axes for the U-Net.
    tensor = torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0)
    return tensor.to(device)
def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask over *original_image*.

    Palette: class 0 = black, 1 = red, 2 = green, 3 = blue.
    *alpha* is the weight of the mask in the blend (0.5 by default).
    """
    palette = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255))
    colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for cls, rgb in enumerate(palette):
        colored[mask == cls] = rgb

    # Match the source image to the mask dimensions (PIL size is (w, h)).
    resized = original_image.resize((mask.shape[1], mask.shape[0]))
    base = np.array(resized.convert("RGB"))

    blended = (alpha * colored + (1 - alpha) * base).astype(np.uint8)
    return Image.fromarray(blended)
# Inference handler wired to the Submit button and the Examples section.
def predict(input_image, model_choice):
    """Run patch-based segmentation on *input_image* with the chosen weights.

    Parameters
    ----------
    input_image : PIL.Image or None
        Uploaded micrograph (None when nothing was uploaded).
    model_choice : str
        Key into the module-level ``models`` dict.

    Returns
    -------
    tuple
        ``(mask_image, overlay_image)`` as PIL images, or ``(None, None)``
        when no image was supplied.
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Tile the image into 256x256 patches (zero-padded at the edges).
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Classify each patch independently.
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            output = model(patch_tensor)
        # Per-pixel class index: argmax over the 4 class channels.
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)

    # Reassemble the per-patch masks and crop away the padding.
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Spread class indices 0..3 across the full 0..255 grayscale range.
    # Fix: the previous factor of 255 // 4 == 63 capped class 3 at gray 189;
    # dividing by (n_classes - 1) maps the top class to white (255).
    n_classes = 4
    mask_image = Image.fromarray((full_mask * (255 // (n_classes - 1))).astype(np.uint8))

    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image
| # -------------------- | |
| # Model Initialization | |
| # -------------------- | |
| w_noise_model_path = "./models/best_model_w_noise.pth" | |
| wo_noise_model_path = "./models/best_model_wo_noise.pth" | |
| w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth" | |
| # Initialize models (assuming files exist) | |
| try: | |
| w_noise_model, device = initialize_model(w_noise_model_path) | |
| wo_noise_model, device = initialize_model(wo_noise_model_path) | |
| w_noise_model_v2, device = initialize_model(w_noise_model_v2_path) | |
| except FileNotFoundError as e: | |
| print(f"Warning: Model files not found. Using dummy initialization. Error: {e}") | |
| # Fallback dummy models for interface setup if files are missing | |
| device = torch.device("cpu") | |
| w_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device) | |
| wo_noise_model = EnhancedUNet(n_channels=1, n_classes=4).to(device) | |
| w_noise_model_v2 = EnhancedUNet(n_channels=1, n_classes=4).to(device) | |
| models = { | |
| "Without Noise": wo_noise_model, | |
| "With Noise": w_noise_model, | |
| "With Noise V2": w_noise_model_v2 | |
| } | |
| # -------------------- | |
| # Gradio UI (Blocks Structure for Guidelines) | |
| # -------------------- | |
| with gr.Blocks(title="MoS2 Image Segmentation") as demo: | |
| gr.Markdown("<h1 style='text-align: center;'> MoS2 Micrograph Segmentation (U-Net Patch-Based) </h1>") | |
| gr.Markdown("Tool for analyzing and segmenting layered Molybdenum Disulfide (MoS2) structures into 4 defined regions.") | |
| # 1. Guidelines Accordion | |
| with gr.Accordion("Tips, Guidelines, and Model Selection", open=False): | |
| gr.Markdown(USAGE_GUIDELINES) | |
| gr.Markdown("---") | |
| gr.Markdown(INPUT_EXPLANATION) | |
| gr.Markdown("---") | |
| gr.Markdown(MODEL_GUIDANCE) | |
| gr.Markdown("---") | |
| gr.Markdown(OUTPUT_INTERPRETATION) | |
| gr.Markdown("## Segmentation Input and Configuration") | |
| with gr.Row(): | |
| # Input Column | |
| with gr.Column(scale=1): | |
| gr.Markdown("## Step 1: Upload a MoS2 Micrograph image ") | |
| input_image = gr.Image(type="pil", label=" MoS2 Micrograph") | |
| gr.Markdown("## Step 2: Select Model Weights ") | |
| model_choice = gr.Dropdown( | |
| choices=["Without Noise", "With Noise", "With Noise V2"], | |
| value="With Noise V2", | |
| label=" Model Weights" | |
| ) | |
| gr.Markdown("## Step 3: Click Submit for Sugmentation ") | |
| submit_button = gr.Button("Submit for Segmentation", variant="primary") | |
| gr.Markdown("## Segmentation Outputs") | |
| # Output Row | |
| with gr.Row(): | |
| output_mask = gr.Image(type="pil", label="Step 3: Segmentation Mask (Grayscale)") | |
| output_overlay = gr.Image(type="pil", label="Step 4: Segmentation Overlay (Color-Coded)") | |
| # Event Handler | |
| submit_button.click( | |
| fn=predict, | |
| inputs=[input_image, model_choice], | |
| outputs=[output_mask, output_overlay] | |
| ) | |
| # Examples Section (Must come after component definition) | |
| gr.Markdown("---") | |
| gr.Markdown("## Example Images") | |
| gr.Examples( | |
| examples=[ | |
| ["./examples/image_000003.png", "With Noise"], | |
| ["./examples/image_000005.png", "Without Noise"] | |
| ], | |
| inputs=[input_image, model_choice], | |
| outputs=[output_mask, output_overlay], | |
| fn=predict, | |
| cache_examples=False, | |
| label="Click to load and run a sample image with predefined model weights.", | |
| ) | |
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional Gradio port,
    # required for Hugging Face Spaces / containerized deployment).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )