Spaces:
Sleeping
Sleeping
File size: 8,313 Bytes
d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 31126d3 f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d0c072e f23f0f3 d3519ad f23f0f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from typing import Tuple
from ultralytics import YOLO
from ultralytics.engine.results import Boxes
from ultralytics.utils.plotting import Annotator
import gradio as gr
import os
# --- Model Loading ---
class _EmptyCoords:
    """Placeholder for a detector's ``boxes.xyxy``: converts to an empty box list.

    ``inference`` calls ``boxes.xyxy.tolist()``; a plain ``[]`` has no ``tolist``
    method, which is why the placeholder needs this tiny class.
    """

    def tolist(self):
        return []


class _DummyBoxes:
    """Placeholder ``Boxes`` object holding no detections."""

    xyxy = _EmptyCoords()


class _DummyResult:
    """Placeholder per-image result exposing an empty ``boxes`` attribute."""

    boxes = _DummyBoxes()


class DummyYOLO:
    """Stand-in detector used when a weights file cannot be loaded (for UI development).

    ``predict`` mirrors the call shape used by ``inference`` and always returns a
    single result with zero boxes.
    """

    def predict(self, image, conf=0.5):
        return [_DummyResult()]


def _load_detector(weights_path):
    """Load the model at ``weights_path``, or return a ``DummyYOLO`` if loading fails.

    Each model is loaded independently, so one missing/broken weights file does not
    disable the models that did load.
    """
    try:
        return YOLO(weights_path)
    except Exception as e:  # deliberate best-effort boundary: keep the UI usable without weights
        print(
            f"Warning: Model loading failed for {weights_path}. "
            f"Ensure weights files are in ./weights/ directory. Error: {e}"
        )
        return DummyYOLO()


cell_detector = _load_detector("./weights/yolo_uninfected_cells.pt")
yolo_detector = _load_detector("./weights/yolo_infected_cells.pt")
redetr_detector = _load_detector("./weights/redetr_infected_cells.pt")
# Dropdown label -> detector used for infected cells.
models = {"Yolo V11": yolo_detector, "Real Time Detection Transformer": redetr_detector}
# --- Documentation Strings ---
# Markdown rendered inside the "Tips & Guidelines" accordion. Keep it in sync with the
# UI: the button label, box colours, and the fixed healthy-cell threshold used by inference.
USAGE_GUIDELINES = """
## 1. Quick Start Guide: Cell Detection and Counting

This application uses two specialized Artificial Intelligence models to analyze a blood smear image, simultaneously detecting both healthy and potentially infected (unhealthy) cells.

1. **Upload**: Upload a clear blood smear image (JPG or PNG) using the 'Input Image' box.
2. **Select Model**: Choose between the two detection models: `Yolo V11` (often fast and accurate for common objects) or `Real Time Detection Transformer`.
3. **Adjust Confidence**: Use the slider to set the **Confidence Threshold** for infected-cell detection. (A higher value means the model must be more certain of a detection.)
4. **Run**: Click the **"Analyze Image"** button.
5. **Review**: The output image will show bounding boxes around detected cells (green for healthy cells, red for infected cells), and the counts will be displayed below.

### Key Requirement:

* The system uses **two independent models**: one strictly for **Healthy Cells**, and one (the selected model) for **Infected Cells**.
"""

INPUT_EXPLANATION = """
## 2. Expected Inputs

| Parameter | Purpose | Range/Options | Guidance for Non-Tech Users |
| :--- | :--- | :--- | :--- |
| **Input Image** | The microscopic blood smear image to be analyzed. | JPG, PNG format. | Ensure the image is clear and focused. |
| **Model Selection** | Chooses the AI architecture used for detecting **Infected Cells**. | Yolo V11, Real Time Detection Transformer | Start with the default (`Yolo V11`) unless specific performance is required. |
| **Confidence Threshold** | The minimum probability required for an **infected-cell** box to be shown (healthy cells always use a fixed threshold of 0.4). | 0.01 to 1.00 | Setting this too low (e.g., 0.1) may show many false positives. Setting it too high (e.g., 0.9) may miss real cells. Start around 0.5. |
"""

OUTPUT_EXPLANATION = """
## 3. Expected Outputs

| Output Field | Description | Interpretation |
| :--- | :--- | :--- |
| **Output Image** | The input image with colored bounding boxes drawn around every detected cell (green = healthy, red = infected). | Visually confirms the location and classification of each cell. |
| **Healthy Cells Count** | The total number of cells detected by the dedicated *uninfected* cell model. | Provides a baseline count of normal cells. |
| **Infected Cells Count** | The total number of cells detected by the *selected* model (Yolo V11 or RT DETR). | This represents the count of potentially cancerous/abnormal cells. |
"""

# --- Example Data Setup ---
# Rows are [image path, model dropdown value, confidence slider value].
SAMPLE_EXAMPLES = [
    ["./blood_smear_1.jpg", "Yolo V11", 0.5],
    ["./blood_smear_2.jpg", "Real Time Detection Transformer", 0.45],
]
# ----------------- Core Inference Function -----------------
# Confidence threshold for the healthy-cell detector; unlike the infected-cell
# threshold it is not exposed in the UI.
HEALTHY_CELL_CONF = 0.4
# RGB box colour per label drawn on the output image.
LABEL_COLORS = {"healthy": (0, 255, 0), "unhealthy": (255, 0, 0)}


def _detect_boxes(detector, image, conf):
    """Run ``detector.predict`` on ``image`` and return every box as ``[x1, y1, x2, y2]``.

    ``predict`` returns a list of result objects; boxes from all of them are concatenated.
    """
    coords = []
    for result in detector.predict(image, conf=conf):
        boxes: Boxes = result.boxes
        coords.extend(boxes.xyxy.tolist())
    return coords


def inference(image, model, conf) -> Tuple[object, str, str]:
    """Detect healthy and infected cells in ``image`` and draw them on the image.

    Args:
        image: The uploaded image, or None when nothing was uploaded.
        model: Key into ``models`` choosing the infected-cell detector.
        conf: Confidence threshold for the infected-cell detector only; healthy
            cells always use ``HEALTHY_CELL_CONF``.

    Returns:
        ``(annotated_image, healthy_count, infected_count)``. The image is whatever
        ``Annotator.result()`` returns. The counts are strings for the Textbox outputs.

    Raises:
        gr.Error: if no image was uploaded or ``model`` is unknown. Gradio shows the
            message to the user; it must be *raised*, since merely constructing it does nothing.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model not in models:
        raise gr.Error(f"Selected model '{model}' is not available.")

    # 1. Healthy cells: fixed model, fixed confidence.
    healthy_boxes = _detect_boxes(cell_detector, image, HEALTHY_CELL_CONF)
    # 2. Infected cells: user-selected model and user-defined confidence.
    infected_boxes = _detect_boxes(models[model], image, conf)

    # 3. Annotation (large font/line width for visibility). Healthy boxes are drawn first.
    annotator = Annotator(image, font_size=30, line_width=4, pil=True)
    for label, label_boxes in (("healthy", healthy_boxes), ("unhealthy", infected_boxes)):
        for box in label_boxes:
            annotator.box_label(box, label, color=LABEL_COLORS[label])
    img = annotator.result()

    return img, str(len(healthy_boxes)), str(len(infected_boxes))
# ----------------- Gradio Interface (Blocks) -----------------
# Only offer example rows whose image file is present, so the interface still builds
# in a checkout that lacks the sample images.
AVAILABLE_EXAMPLES = [row for row in SAMPLE_EXAMPLES if os.path.isfile(row[0])]

with gr.Blocks(title="Blood Cell Detection") as ifer:
    gr.Markdown("<h1 style='text-align: center;'> Blood Cell Cancer Detection and Counting </h1>")
    gr.Markdown("Uses specialized object detection models to count healthy and infected cells in blood smear images.")

    # 1. Documentation
    with gr.Accordion(" Tips & Guidelines ", open=False):
        gr.Markdown(USAGE_GUIDELINES)
        gr.Markdown("---")
        gr.Markdown(INPUT_EXPLANATION)
        gr.Markdown("---")
        gr.Markdown(OUTPUT_EXPLANATION)

    # 2. Interface Inputs
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Step 1: Upload Image ")
            image_input = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            gr.Markdown("## Step 2: Set Parameters")
            model_selection = gr.Dropdown(
                label="Select Detection Model (for Infected Cells)",
                # Derived from ``models`` so the dropdown never offers a key inference() rejects.
                choices=list(models),
                multiselect=False,
                value="Yolo V11",
            )
            conf_slider = gr.Slider(
                minimum=0.01,
                maximum=1,
                value=0.5,
                step=0.01,
                label="Confidence Threshold (Min. certainty required)",
            )

    gr.Markdown("## Step 3: Click Analyze Image")
    with gr.Row():
        submit_button = gr.Button("Analyze Image", variant="primary")

    # 3. Interface Outputs
    gr.Markdown("## Results")
    output_image = gr.Image(label="Output Image (Detected Cells)", type="numpy")
    with gr.Row():
        healthy_count = gr.Textbox(label="Healthy Cells Count")
        unhealthy_count = gr.Textbox(label="Infected Cells Count")

    # 4. Examples (skipped entirely when none of the sample images are present)
    if AVAILABLE_EXAMPLES:
        gr.Markdown("---")
        gr.Markdown("## Example Inputs")
        gr.Examples(
            examples=AVAILABLE_EXAMPLES,
            inputs=[image_input, model_selection, conf_slider],
            outputs=[output_image, healthy_count, unhealthy_count],
            fn=inference,
            cache_examples=False,
            label="Click a row to load the image and parameters",
        )

    # Event Handler
    submit_button.click(
        fn=inference,
        inputs=[image_input, model_selection, conf_slider],
        outputs=[output_image, healthy_count, unhealthy_count],
    )

if __name__ == "__main__":
    ifer.launch(share=True)