import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from segment_anything import sam_model_registry, SamPredictor
# SAM model setup: load the checkpoint once when the module is imported and
# wrap it in a predictor that the detection code reuses for every image.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h.pth"  # relative path, resolved against the CWD

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
predictor = SamPredictor(sam)


def preprocess_image(image):
    """Return *image* as a C-contiguous (H, W, 3) array for SAM.

    Accepted inputs:
      * (H, W) or (H, W, 1) grayscale: the single channel is replicated into
        three identical channels (the same result as cv2.COLOR_GRAY2RGB).
      * (H, W, 3): returned unchanged (no copy if already contiguous).
      * (H, W, 4): the fourth (alpha) channel is dropped.

    Raises:
        ValueError: if *image* is not 2-D or 3-D with 1, 3 or 4 channels.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim != 3 or image.shape[2] not in (1, 3, 4):
        raise ValueError(
            f"expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image, got shape {image.shape}"
        )
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]
    # Slicing off alpha yields a strided view; downstream consumers such as
    # OpenCV drawing calls need a contiguous buffer.
    return np.ascontiguousarray(image)


def detect_blood_cells(image, min_area=100, max_area=5000, min_circularity=0.7):
    """Segment *image* with SAM and keep contours that pass the cell filters.

    Args:
        image: image array; normalized with preprocess_image before SAM.
        min_area: exclusive lower bound on contour area in pixels.
        max_area: exclusive upper bound on contour area in pixels.
        min_circularity: exclusive lower bound on 4*pi*area / perimeter**2
            (1.0 for a perfect circle).

    Returns:
        (contours_list, features, masks) where contours_list[k] is the contour
        described by features[k]. Each feature dict has keys 'label', 'area',
        'perimeter', 'circularity', 'centroid_x', 'centroid_y'. The label is
        "<mask index>-<1-based contour index within that mask>". masks are the
        raw masks returned by predictor.predict.
    """
    image = preprocess_image(image)
    predictor.set_image(image)

    # NOTE(review): no point/box/mask prompt is supplied, so SAM predicts from
    # an empty prompt rather than segmenting every object in the image;
    # segment_anything's SamAutomaticMaskGenerator targets whole-image
    # segmentation — confirm these masks are what is intended.
    # NOTE(review): multimask_output=True yields several alternative masks for
    # the same prompt, so one region may be reported once per mask.
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        multimask_output=True,
    )

    contours_list = []
    features = []
    for mask_idx, mask in enumerate(masks):
        mask_u8 = mask.astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour_idx, contour in enumerate(contours, 1):
            feature = _cell_features(
                contour, f"{mask_idx}-{contour_idx}", min_area, max_area, min_circularity
            )
            if feature is not None:
                # Append both together so the two lists stay index-aligned.
                features.append(feature)
                contours_list.append(contour)

    return contours_list, features, masks


def _cell_features(contour, label, min_area, max_area, min_circularity):
    """Return the feature dict for *contour*, or None if it fails the filters."""
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
    if not (min_area < area < max_area and circularity > min_circularity):
        return None
    moments = cv2.moments(contour)
    if moments["m00"] == 0:
        # Degenerate contour: the centroid is undefined.
        return None
    return {
        'label': label,
        'area': area,
        'perimeter': perimeter,
        'circularity': circularity,
        'centroid_x': int(moments["m10"] / moments["m00"]),
        'centroid_y': int(moments["m01"] / moments["m00"]),
    }


def process_image(image):
    """Detect cells in *image* and draw their contours and labels.

    Returns:
        (vis_img, mask, df): a 3-channel copy of the image with contours and
        labels drawn, the first SAM mask, and a DataFrame with one row per
        drawn contour. For ``image=None`` returns (None, None, empty DataFrame)
        — always a 3-tuple, matching the normal path.
    """
    columns = ['label', 'area', 'perimeter', 'circularity', 'centroid_x', 'centroid_y']
    if image is None:
        return None, None, pd.DataFrame(columns=columns)

    contours, features, masks = detect_blood_cells(image)
    # Draw on a normalized 3-channel copy so colored overlays also work when
    # the input is grayscale; the caller's array is never modified.
    vis_img = preprocess_image(image).copy()

    # detect_blood_cells appends each contour together with its feature row,
    # so pair them positionally. The second number in a label indexes the
    # contours of a single mask and is not a valid index into `contours`.
    for contour, feature in zip(contours, features):
        cv2.drawContours(vis_img, [contour], -1, (0, 255, 0), 2)
        cv2.putText(vis_img, str(feature['label']), (feature['centroid_x'], feature['centroid_y']),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

    # Fixed columns keep the table schema stable even when no cells are found.
    df = pd.DataFrame(features, columns=columns)
    return vis_img, masks[0], df


def analyze(image):
    """Gradio callback: detect cells in *image* and plot summary statistics.

    Args:
        image: image array from the ``gr.Image(type="numpy")`` input, or None
            when the input is cleared.

    Returns:
        (vis_img, mask, fig, df): the annotated image, the first SAM mask, a
        two-panel Figure (area histogram, area-vs-circularity scatter) and the
        per-cell feature DataFrame. For ``image=None`` all four are None so
        the UI clears its outputs instead of raising.
    """
    if image is None:
        # Nothing to analyze; clear every output.
        return None, None, None, None

    vis_img, mask, df = process_image(image)

    # Scope the dark style to this figure instead of mutating the global
    # rcParams on every request.
    with plt.style.context('dark_background'):
        # A bare Figure is not registered with pyplot, so it is garbage-collected
        # after Gradio renders it; plt.subplots would keep one open figure per
        # request alive until plt.close is called.
        fig = Figure(figsize=(12, 5))
        axes = fig.subplots(1, 2)
        axes[0].set_title('Cell Size Distribution')
        axes[1].set_title('Area vs Circularity')
        if df is not None and not df.empty:
            axes[0].hist(df['area'], bins=20, color='cyan', edgecolor='black')
            axes[1].scatter(df['area'], df['circularity'], alpha=0.6, c='magenta')
        else:
            for ax in axes:
                ax.text(0.5, 0.5, 'No cells detected', ha='center', va='center',
                        transform=ax.transAxes)
        axes[0].set_xlabel('Area (px)')
        axes[0].set_ylabel('Count')
        axes[1].set_xlabel('Area (px)')
        axes[1].set_ylabel('Circularity')

    return vis_img, mask, fig, df


# Gradio UI: one image in; the four values returned by analyze out
# (annotated image, first SAM mask, summary plots, per-cell feature table).
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="numpy", label="Input image"),
    outputs=[
        gr.Image(label="Detected cells"),
        gr.Image(label="SAM mask"),
        gr.Plot(label="Statistics"),
        gr.Dataframe(label="Cell features"),
    ],
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (from tests or another app) does not block inside demo.launch().
    demo.launch()