Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,16 +16,12 @@ def load_bounding_boxes(csv_file):
|
|
| 16 |
df = pd.read_csv(csv_file)
|
| 17 |
return df
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
ds = pydicom.dcmread(filepath)
|
| 26 |
-
img = ds.pixel_array
|
| 27 |
-
images.append(img)
|
| 28 |
-
return np.array(images)
|
| 29 |
|
| 30 |
# MedSAM inference function
|
| 31 |
def medsam_inference(medsam_model, img, box, H, W, target_size):
|
|
@@ -64,42 +60,34 @@ def visualize(images, masks, box):
|
|
| 64 |
return buf
|
| 65 |
|
| 66 |
# Main function for Gradio app
|
| 67 |
-
def process_images(csv_file,
|
| 68 |
bounding_boxes = load_bounding_boxes(csv_file)
|
| 69 |
-
|
| 70 |
|
| 71 |
# Initialize MedSAM model
|
| 72 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 73 |
-
medsam_model = sam_model_registry['
|
| 74 |
medsam_model = medsam_model.to(device)
|
| 75 |
medsam_model.eval()
|
| 76 |
|
| 77 |
masks = []
|
|
|
|
| 78 |
for index, row in bounding_boxes.iterrows():
|
| 79 |
-
if index >= len(dicom_images):
|
| 80 |
-
continue # Skip if the index exceeds the number of images
|
| 81 |
-
|
| 82 |
-
image = dicom_images[index]
|
| 83 |
-
H, W = image.shape
|
| 84 |
box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
|
| 85 |
-
|
| 86 |
-
mask = medsam_inference(medsam_model, image, box, H, W, target_size)
|
| 87 |
masks.append(mask)
|
|
|
|
| 88 |
|
| 89 |
-
visualizations = visualize(
|
| 90 |
-
|
| 91 |
-
return visualizations, np.array(masks)
|
| 92 |
|
| 93 |
# Set up Gradio interface
|
| 94 |
iface = gr.Interface(
|
| 95 |
fn=process_images,
|
| 96 |
inputs=[
|
| 97 |
gr.File(label="CSV File"),
|
| 98 |
-
gr.File(label="
|
| 99 |
-
outputs=
|
| 100 |
-
gr.Image(type="pil"),
|
| 101 |
-
gr.File(type="numpy")
|
| 102 |
-
]
|
| 103 |
)
|
| 104 |
|
| 105 |
iface.launch()
|
|
|
|
| 16 |
df = pd.read_csv(csv_file)
|
| 17 |
return df
|
| 18 |
|
| 19 |
+
def load_dicom_image(filename):
|
| 20 |
+
if filename.endswith(".dcm"):
|
| 21 |
+
ds = pydicom.dcmread(filename)
|
| 22 |
+
img = ds.pixel_array
|
| 23 |
+
H, W = img.shape
|
| 24 |
+
return np.array(img), H, W
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# MedSAM inference function
|
| 27 |
def medsam_inference(medsam_model, img, box, H, W, target_size):
|
|
|
|
| 60 |
return buf
|
| 61 |
|
| 62 |
# Main function for Gradio app
|
| 63 |
+
def process_images(csv_file, dicom_file):
|
| 64 |
bounding_boxes = load_bounding_boxes(csv_file)
|
| 65 |
+
image, H, W = load_dicom_image(dicom_file)
|
| 66 |
|
| 67 |
# Initialize MedSAM model
|
| 68 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 69 |
+
medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth") # Ensure the correct path
|
| 70 |
medsam_model = medsam_model.to(device)
|
| 71 |
medsam_model.eval()
|
| 72 |
|
| 73 |
masks = []
|
| 74 |
+
boxes = []
|
| 75 |
for index, row in bounding_boxes.iterrows():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
|
| 77 |
+
mask = medsam_inference(medsam_model, image, box, H, W, H) # Assuming target size is the same as the image height
|
|
|
|
| 78 |
masks.append(mask)
|
| 79 |
+
boxes.append(box)
|
| 80 |
|
| 81 |
+
visualizations = visualize([image] * len(masks), masks, boxes)
|
| 82 |
+
return visualizations.getvalue()
|
|
|
|
| 83 |
|
| 84 |
# Set up Gradio interface
|
| 85 |
iface = gr.Interface(
|
| 86 |
fn=process_images,
|
| 87 |
inputs=[
|
| 88 |
gr.File(label="CSV File"),
|
| 89 |
+
gr.File(label="DICOM File")],
|
| 90 |
+
outputs="plot"
|
|
|
|
|
|
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
iface.launch()
|