# Hugging Face Space: MedSAM point-prompt ROI segmentation demo.
# NOTE: the Space page reported "Build error" at capture time.
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import pydicom | |
| import os | |
| from skimage import transform | |
| import torch | |
| from segment_anything import sam_model_registry | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import io | |
| import cv2 | |
| import nrrd | |
| from gradio_image_prompter import ImagePrompter | |
class PointPromptDemo:
    """Point-prompt segmentation wrapper around a SAM/MedSAM model.

    Workflow: call ``set_image`` once per image (computes and caches the
    image-encoder embeddings), then ``infer(x, y)`` once per clicked point.
    """

    def __init__(self, model):
        # `model` is a segment_anything SAM model (vit_b in this app); it is
        # assumed to expose .device, .image_encoder, .prompt_encoder and
        # .mask_decoder — TODO(review): confirm against the checkpoint used.
        self.model = model
        self.model.eval()
        self.image = None              # HxWx3 array set via set_image()
        self.image_embeddings = None   # cached image_encoder output
        self.img_size = None           # (H, W) of the original image

    def infer(self, x, y):
        """Segment the object at pixel (x, y).

        Returns an HxW uint8 mask of {0, 1} at the original image resolution.
        """
        # Map the click from original pixel space into the 1024x1024 space
        # the encoder input was resized to.
        coords_1024 = np.array([[[
            x * 1024 / self.img_size[1],
            y * 1024 / self.img_size[0],
        ]]])
        coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
        # Label 1 marks the point as a foreground prompt.
        labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)
        point_prompt = (coords_torch, labels_torch)
        # Inference only: run under no_grad so no autograd graph is built.
        # (The original ran the decoder with gradients enabled, wasting
        # memory on every click; set_image already used no_grad.)
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=point_prompt,
                boxes=None,
                masks=None,
            )
            low_res_logits, _ = self.model.mask_decoder(
                image_embeddings=self.image_embeddings,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            low_res_probs = torch.sigmoid(low_res_logits)
            # Upsample probabilities back to the original image resolution.
            low_res_pred = F.interpolate(
                low_res_probs,
                size=self.img_size,
                mode='bilinear',
                align_corners=False,
            )
        low_res_pred = low_res_pred.detach().cpu().numpy().squeeze()
        seg = np.uint8(low_res_pred > 0.5)  # threshold -> binary mask
        return seg

    def set_image(self, image):
        """Cache `image` and its encoder embeddings for later infer() calls."""
        self.img_size = image.shape[:2]
        if len(image.shape) == 2:
            # Grayscale -> 3 channels by replication; the encoder expects RGB.
            image = np.repeat(image[:, :, None], 3, -1)
        self.image = image
        image_preprocess = self.preprocess_image(self.image)
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(image_preprocess)

    def preprocess_image(self, image):
        """Resize to 1024x1024, min-max normalize to [0, 1], and return a
        1x3x1024x1024 float tensor on the model's device."""
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC,
        )
        # Min-max normalization; the clipped denominator avoids division by
        # zero on constant-valued images.
        img_resize = (img_resize - img_resize.min()) / np.clip(
            img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None
        )
        # NOTE(review): assert is stripped under `python -O`; kept for parity
        # with the original, but an explicit raise would be more robust.
        assert np.max(img_resize) <= 1.0 and np.min(img_resize) >= 0.0, \
            'image should be normalized to [0, 1]'
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        return img_tensor
def load_image(file_path):
    """Load a DICOM, NRRD, or standard image file as an HxWx3 numpy array.

    2-D (grayscale) data is replicated to three channels so downstream code
    can assume an RGB-shaped array.
    """
    # Case-insensitive extension check so uploads named e.g. ".DCM" or
    # ".NRRD" are routed correctly (the original only matched lowercase
    # extensions, sending uppercase ones to Image.open, which fails on them).
    lowered = file_path.lower()
    if lowered.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif lowered.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    if len(img.shape) == 2:
        img = np.stack((img,) * 3, axis=-1)
    return img
def visualize(image, mask):
    """Render `image` and an image+mask overlay side by side.

    Returns the rendered figure as a PIL image.
    """
    figure, axes = plt.subplots(1, 2, figsize=(10, 5))
    plain_ax, overlay_ax = axes
    plain_ax.imshow(image)
    overlay_ax.imshow(image)
    # Semi-transparent jet colormap highlights the segmented region.
    overlay_ax.imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Round-trip the figure through an in-memory PNG to produce a PIL image.
    buffer = io.BytesIO()
    figure.savefig(buffer, format='png')
    plt.close(figure)
    buffer.seek(0)
    return Image.open(buffer)
def process_images(img_dict):
    """Gradio callback: segment the clicked point in the uploaded image.

    `img_dict` comes from gradio_image_prompter's ImagePrompter:
    {'image': HxWxC array, 'points': [[x, y, ...], ...]} — presumably one
    entry per click; only the first point is used (TODO confirm format).

    Raises ValueError when no usable point coordinates were supplied.
    Returns a PIL image showing the original and the mask overlay.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = img_dict['image']
    points = img_dict['points'][0]
    if len(points) < 2:
        raise ValueError("At least one point is required for ROI selection.")
    x, y = points[0], points[1]
    # Build the model once and reuse it across requests — the original
    # re-read the checkpoint and rebuilt the ViT on every callback, which
    # dominates per-request latency.
    demo = getattr(process_images, "_demo", None)
    if demo is None:
        model_checkpoint_path = "medsam_point_prompt_flare22.pth"
        medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
        medsam_model = medsam_model.to(device)
        medsam_model.eval()
        demo = PointPromptDemo(medsam_model)
        process_images._demo = demo
    demo.set_image(img)
    mask = demo.infer(x, y)
    return visualize(img, mask)
# Wire up the Gradio UI: one point-promptable image input, one PIL output.
iface = gr.Interface(
    fn=process_images,
    inputs=[ImagePrompter(label="Image")],
    outputs=[gr.Image(type="pil", label="Processed Image")],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select a point for ROI processing.",
)
iface.launch()