File size: 3,439 Bytes
e69f86c
 
 
 
 
 
 
 
 
d49a4fc
45b4aee
d49a4fc
45b4aee
d49a4fc
 
 
45b4aee
 
 
 
 
d49a4fc
 
 
45b4aee
d49a4fc
e69f86c
 
45b4aee
d49a4fc
45b4aee
d49a4fc
 
 
 
45b4aee
e69f86c
 
45b4aee
e69f86c
 
45b4aee
 
 
e69f86c
 
 
 
45b4aee
e69f86c
45b4aee
e69f86c
45b4aee
e69f86c
 
 
45b4aee
e69f86c
 
d49a4fc
 
45b4aee
d49a4fc
45b4aee
 
e69f86c
 
 
 
 
45b4aee
e69f86c
8784ab7
d49a4fc
e69f86c
d49a4fc
e69f86c
45b4aee
e69f86c
45b4aee
e69f86c
 
d49a4fc
 
1281dba
d49a4fc
 
 
8784ab7
 
 
 
1281dba
8784ab7
 
 
e69f86c
ea2f994
 
e69f86c
8784ab7
45b4aee
 
e69f86c
45b4aee
e69f86c
 
45b4aee
e69f86c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch
import numpy as np

# --- Documentation Strings ---
# Markdown rendered inside the UI's "Documentation" accordion. The step
# instructions must stay in sync with the button/label text of the UI below.

USAGE_GUIDELINES = """
## 1. Quick Start Guide: HemaScan Pro (Binary Mask)

HemaScan Pro generates a high-contrast black & white segmentation mask.

1. Upload a blood smear image (JPG/PNG).
2. Click "Submit for Segmentation".
3. View the generated binary mask.
"""

INPUT_EXPLANATION = """
## 2. Expected Inputs

| Field | Requirement |
|-------|------------|
| Upload Image | JPG / PNG blood smear image |

✔ Automatically resized to 512×512.
"""

OUTPUT_EXPLANATION = """
## 3. Output Description (Black & White Mask)

• Background = White  
• Detected Regions = Black  
• Enlarged by 400% (4×) for clarity  
• Clean binary medical-style visualization
"""

# --------------------
# Model
# --------------------
# SegFormer-B0 checkpoint fine-tuned on ADE20K at 512x512, fetched from the
# Hugging Face Hub by `from_pretrained`. The image processor is constructed
# with library defaults rather than loaded from the checkpoint.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

processor = SegformerImageProcessor(do_reduce_labels=False)

# `.eval()` returns the module itself, so the call can be chained; the model
# is only ever used for inference in this file.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT).eval()

def segment_image(input_image):
    if input_image is None:
        gr.Warning("Please upload an image.")
        return None

    inputs = processor(images=input_image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Convert to binary mask (object vs background)
    binary_mask = np.where(pred_mask == 0, 255, 0).astype(np.uint8)

    output_image = Image.fromarray(binary_mask)

    # Scale 4x
    scale_factor = 4
    new_size = (output_image.width * scale_factor, output_image.height * scale_factor)
    return output_image.resize(new_size, resample=Image.NEAREST)

# --------------------
# UI
# --------------------
# Example inputs, given as paths relative to the working directory.
_EXAMPLE_FILES = ["data/1.png", "data/2.png", "data/3.png"]

# Gradient banner rendered at the top of the page.
_HEADER_HTML = (
    "<h1 style='text-align:center; background:linear-gradient(90deg,#4facfe,#00f2fe); "
    "color:white; padding:10px;'>HemaScan Pro - Binary Segmentation</h1>"
)

with gr.Blocks(title="Malaria Cell Segmentation Tool") as demo:
    gr.Markdown(_HEADER_HTML)

    # Collapsed reference docs: the three sections separated by horizontal rules.
    with gr.Accordion(" Documentation", open=False):
        _doc_sections = (USAGE_GUIDELINES, INPUT_EXPLANATION, OUTPUT_EXPLANATION)
        for _position, _section in enumerate(_doc_sections):
            if _position:
                gr.Markdown("---")
            gr.Markdown(_section)

    with gr.Row():
        # Left column: image upload followed by the trigger button.
        with gr.Column(scale=1):
            gr.Markdown("## Step 1: Upload Blood Smear Image")
            input_image = gr.Image(
                type="pil",
                label="Step 1: Upload Blood Smear Image",
                width=600,
                height=600,
            )

            gr.Markdown("## Step 2: Click Submit for Segmentation")
            with gr.Row():
                submit_button = gr.Button("Submit for Segmentation", variant="primary")

        # Right column: the predicted mask.
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            output_image = gr.Image(
                type="pil",
                label="Step 3: Predicted Masks",
                width=600,
                height=600,
            )

    gr.Markdown("---")
    gr.Markdown("## Example Images")
    gr.Examples(
        examples=_EXAMPLE_FILES,
        inputs=input_image,
        outputs=output_image,
        fn=segment_image,
        cache_examples=False,
    )

    submit_button.click(fn=segment_image, inputs=input_image, outputs=output_image)

if __name__ == "__main__":
    demo.launch()