# Author: muhammadhamza-stack — commit 4a76e27 ("change mask color")
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch
import numpy as np
import os
# --- Documentation Strings (markdown rendered inside the Gradio UI) ---

# Shown in the "Tips & Guidelines" accordion: step-by-step usage.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Generating a Segmentation Mask
This tool uses a powerful AI model (SegFormer) to analyze an image and identify different object boundaries, creating a map known as a segmentation mask.
1. **Upload**: Click the 'Upload Blood Smear Image' box and select your image (JPG or PNG).
2. **Test**: Optionally, click one of the 'Example Images' below to load sample data instantly.
3. **Run**: Click the **"Submit"** button.
4. **Review**: The 'Predicted Grayscale Mask' will appear, showing the boundaries the model detected.
"""

# Input contract table for the single image input.
INPUT_EXPLANATION = """
## 2. Expected Inputs
| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Upload Image** | The image containing the sample you want to analyze. | Must be a single image file (JPG, PNG). The input image is automatically resized to 512x512 pixels before processing. |
"""

# Output description. NOTE: updated to match the current pipeline, which
# binarizes the 150-class prediction into a white-on-black mask (see
# segment_image) instead of the earlier 150-shade grayscale rendering.
OUTPUT_EXPLANATION = """
## 3. Expected Outputs (The Segmentation Mask)
The output is the **Predicted Mask**. This is a map where every pixel is assigned a value based on whether the model identified an object there.
* **Binary Representation:** Every pixel the model assigns to any of the 150 object categories (other than the background class) is shown in white (255), while the background is shown in black (0).
* **Purpose:** The mask visually separates the detected boundaries (e.g., cell outlines, background, debris) from one another.
* **Clarity:** The mask is automatically enlarged by **300% (3x)** to make the details easier to see.
### Sample Data for Testing
You can quickly test the application using the following provided example files. Click on the thumbnails (Sample blood smear image) below to upload the image automatically:
"""
# --------------------
# Core Pipeline Functions
# --------------------

# Load pretrained model
# Model: SegFormer B0 fine-tuned on ADE20K (150 classes)
# do_reduce_labels=False keeps label ids unchanged (no shift of the
# background class to 255 during preprocessing).
processor = SegformerImageProcessor(do_reduce_labels=False)
# NOTE: downloads the checkpoint from the Hugging Face Hub on first run.
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
# Inference only: switch off dropout / batch-norm running-stat updates.
model.eval()
# Prediction function
def segment_image(input_image):
    """Segment *input_image* with SegFormer and return a binary mask image.

    Args:
        input_image: PIL image from the Gradio input component, or None
            when the user pressed Submit without uploading anything.

    Returns:
        A PIL grayscale image where every pixel the model assigned to a
        non-zero class is white (255) and everything else is black (0),
        upscaled 3x with nearest-neighbor for easier viewing. Returns
        None (after emitting a Gradio warning) when no image was given.
    """
    # Guard: nothing to do without an image.
    if input_image is None:
        gr.Warning("Please upload an image or select an example.")
        return None

    # The processor resizes/normalizes to the model's 512x512 input.
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (1, num_classes, H', W') class scores

    # 1. Per-pixel predicted class index.
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # 2. Binarize: any class index > 0 is treated as foreground (white),
    #    class 0 as background (black).
    normalized_mask = np.where(pred_mask > 0, 255, 0).astype(np.uint8)
    output_image = Image.fromarray(normalized_mask)

    # 3. Enlarge 3x; NEAREST keeps the hard mask edges crisp.
    scale_factor = 3
    new_size = (output_image.width * scale_factor, output_image.height * scale_factor)
    return output_image.resize(new_size, resample=Image.NEAREST)
# --------------------
# Gradio UI
# --------------------
with gr.Blocks(title="Semantic Segmentation Tool") as demo:
    # Page header
    gr.Markdown("<h1 style='text-align: center;'> Malaria Blood Smear Segmentation (SegFormer) </h1>")
    gr.Markdown("Tool for analyzing image boundaries using a general purpose semantic segmentation model.")

    # 1. Guidelines Accordion — collapsible documentation panel.
    with gr.Accordion(" Tips & Guidelines", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_EXPLANATION)

    # 2. Interface Definition (Embedded)
    gr.Markdown("## Step 1: Upload the image")
    with gr.Row():
        with gr.Column(scale=1):
            # Define Input component directly inside the column (No .render() needed)
            input_image = gr.Image(type="pil", label="Step 1: Upload Blood Smear Image")

    gr.Markdown("## Step 2: Click Submit for Segmentation")
    with gr.Row():
        submit_button = gr.Button("Submit for Segmentation", variant="primary")

    gr.Markdown("## Output")
    # NOTE(review): `scale` is normally a Column/component argument, not a
    # Row one — confirm this Gradio version accepts gr.Row(scale=2).
    with gr.Row(scale=2):
        # Define Output component directly inside the column (No .render() needed)
        output_image = gr.Image(type="pil", label="Step 3: Predicted Grayscale Mask")

    # 3. Examples Section — presumably these paths exist relative to the
    # app's working directory; verify the data/ folder ships with the app.
    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=["data/1.png", "data/2.png", "data/3.png", "data/211.png"],
        inputs=[input_image],
        outputs=[output_image],
        fn=segment_image,
        cache_examples=False,  # run the model on click rather than pre-caching
        label="Click to load and run a sample image",
    )

    # Event Handler: wire the button to the prediction function.
    submit_button.click(
        fn=segment_image,
        inputs=input_image,
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()