# Hugging Face Spaces page residue (non-code): Space status "Sleeping".
# Third-party imports: Gradio (UI), Hugging Face Transformers (SegFormer model
# + image processor), Pillow (image handling), PyTorch (inference), NumPy.
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch
import numpy as np
import os  # NOTE(review): `os` appears unused in this file — confirm before removing
# --- Documentation Strings ---
# Markdown rendered verbatim in the UI ("Tips & Guidelines" accordion) via
# gr.Markdown — these strings are user-facing content, not comments.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask
This tool uses a powerful AI model (SegFormer) to analyze an image and identify different object boundaries, creating a map known as a segmentation mask.
1. **Upload**: Click the 'Upload Blood Smear Image' box and select your image (JPG or PNG).
2. **Test**: Optionally, click one of the 'Example Images' below to load sample data instantly.
3. **Run**: Click the **"Submit"** button.
4. **Review**: The 'Predicted Grayscale Mask' will appear, showing the boundaries the model detected.
"""
# Input documentation shown as a markdown table in the accordion.
INPUT_EXPLANATION = """
## 2. Expected Inputs
| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Upload Image** | The image containing the sample you want to analyze. | Must be a single image file (JPG, PNG). The input image is automatically resized to 512x512 pixels before processing. |
"""
# Output documentation, including a pointer to the clickable example images.
OUTPUT_EXPLANATION = """
## 3. Expected Outputs (The Segmentation Mask)
The output is the **Predicted Grayscale Mask**. This is a map where every pixel is assigned a color based on the object class the model identified.
* **Grayscale Representation:** Unlike color segmentation maps, this output uses different *shades of gray* to represent the 150 possible object categories (classes) the underlying model was trained on.
* **Purpose:** The mask visually separates the detected boundaries (e.g., cell outlines, background, debris) from one another.
* **Clarity:** The mask is automatically enlarged by **300% (3x)** to make the details easier to see.
### Sample Data for Testing
You can quickly test the application using the following provided example files. Click on the thumbnails (Sample blood smear image) below to upload the image automatically:
"""
# --------------------
# Core Pipeline Functions
# --------------------
# Load the pretrained segmentation pipeline once at import time.
# Model: SegFormer B0 fine-tuned on ADE20K (150 classes).
# do_reduce_labels=False keeps class index 0 intact (treated as background
# by segment_image below).
processor = SegformerImageProcessor(do_reduce_labels=False)
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
).eval()  # inference mode: disables dropout/batch-norm updates
# Prediction function
def segment_image(input_image):
    """Run SegFormer semantic segmentation on an uploaded image.

    Args:
        input_image: PIL image from the Gradio input component, or None when
            nothing has been uploaded.

    Returns:
        A PIL grayscale image where every pixel predicted as a non-background
        class (argmax index > 0) is white (255) and everything else is black
        (0), upscaled 3x with nearest-neighbour resampling for visibility.
        Returns None (after emitting a UI warning) when no image was given.
    """
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None

    # Preprocess: the processor resizes/normalizes to the model's 512x512 input.
    inputs = processor(images=input_image, return_tensors="pt")

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (1, num_classes, H', W')

    # Per-pixel argmax over the class dimension -> 2-D class-index map.
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Binarize: class 0 (background) -> black, any detected class -> white.
    binary_mask = np.where(pred_mask > 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(binary_mask)

    # Enlarge 3x with nearest-neighbour so class boundaries stay crisp
    # (no interpolation blurring between the two mask values).
    scale_factor = 3
    new_size = (mask_image.width * scale_factor, mask_image.height * scale_factor)
    return mask_image.resize(new_size, resample=Image.NEAREST)
# --------------------
# Gradio UI
# --------------------
with gr.Blocks(title="Semantic Segmentation Tool") as demo:
    gr.Markdown("<h1 style='text-align: center;'> Malaria Blood Smear Segmentation (SegFormer) </h1>")
    gr.Markdown("Tool for analyzing image boundaries using a general purpose semantic segmentation model.")

    # 1. Collapsible usage documentation (strings defined at module top).
    with gr.Accordion(" Tips & Guidelines", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_EXPLANATION)

    # 2. Input -> action -> output layout.
    gr.Markdown("## Step 1: Upload the image")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Step 1: Upload Blood Smear Image")

    gr.Markdown("## Step 2: Click Submit for Segmentation")
    with gr.Row():
        submit_button = gr.Button("Submit for Segmentation", variant="primary")

    gr.Markdown("## Output")
    # NOTE(review): not every Gradio release accepts `scale` on gr.Row —
    # kept as-is; confirm against the pinned Gradio version.
    with gr.Row(scale=2):
        output_image = gr.Image(type="pil", label="Step 3: Predicted Grayscale Mask")

    # 3. Examples section. Fix: with cache_examples=False, Gradio only loads
    # the example into the input on click and never calls `fn` unless
    # run_on_click=True — without it, the label "load and run" was wrong.
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=["data/1.png", "data/2.png", "data/3.png", "data/211.png"],
        inputs=[input_image],
        outputs=[output_image],
        fn=segment_image,
        cache_examples=False,
        run_on_click=True,
        label="Click to load and run a sample image",
    )

    # Event handler: wire the submit button to the segmentation function.
    submit_button.click(
        fn=segment_image,
        inputs=input_image,
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()