Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import List, Dict, Tuple | |
| import matplotlib.colors as mpl_colors | |
| import pandas as pd | |
| import seaborn as sns | |
| import shinyswatch | |
| from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui | |
| import os | |
| from transformers import SamModel, SamConfig, SamProcessor | |
| import torch | |
| from PIL import Image | |
| import io | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Apply seaborn's default theme to the matplotlib figures drawn by this app.
sns.set_theme()

# Folder containing this script; load_model() reads "checkpoint.pth" from it.
# NOTE: the name shadows the builtin dir(); it is kept because load_model()
# refers to it.
dir = Path(__file__).resolve().parent

# "www" folder next to this script, passed to App(...) as the static assets.
www_dir = Path(__file__).parent.resolve().joinpath("www")
| ### UI ### | |
# Sidebar: a single-file upload restricted to TIFF / PNG images.
_upload_sidebar = ui.sidebar(
    ui.input_file(
        "tile_image",
        "Choose an Image",
        accept=[".tif", ".tiff", ".png"],
        multiple=False,
    ),
)

# Author credit rendered with the "text-end" CSS class.
_credit = ui.help_text(
    "Project by ",
    ui.a("@agoluoglu", href="https://github.com/agoluoglu"),
    class_="text-end",
)

# Page layout: minty theme, upload sidebar, then the output placeholders.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        _upload_sidebar,
        # NOTE: a ui.output_image("uploaded_image") preview of the uploaded
        # tile used to sit here; it was disabled because it did not work on
        # all accepted files.
        ui.output_plot("prediction_plots", fill=True),
        # NOTE(review): no server-side output named "value_boxes" or
        # "scatter" is visible in this file -- confirm they are still needed.
        ui.output_ui("value_boxes"),
        ui.output_plot("scatter", fill=True),
        _credit,
    ),
)
| ### HELPER FUNCTIONS ### | |
def bytes_to_pil_image(bytes):
    """Decode raw image bytes into a 256x256 RGB PIL image.

    The decoded picture is converted to RGB, cropped to a square anchored at
    the top-left corner (side length = the shorter of width and height), and
    then resized to 256x256 pixels.

    Args:
        bytes: Encoded image file contents. The name shadows the builtin
            ``bytes`` type; it is kept so existing callers are unaffected.

    Returns:
        The cropped and resized image returned by ``Image.resize``.
    """
    # Wrap the raw data in a file-like buffer so PIL can decode it.
    rgb = Image.open(io.BytesIO(bytes)).convert("RGB")
    # PIL's ``size`` is (width, height); the shorter one bounds the square.
    side = min(rgb.size)
    square = rgb.crop((0, 0, side, side))
    return square.resize((256, 256))
def load_model():
    """Build the SAM model, load its saved weights, and prepare it for inference.

    The architecture and processor come from the "facebook/sam-vit-base"
    configuration; the weights are read from "checkpoint.pth" in the folder
    that contains this file.

    Returns:
        A ``(model, processor, device)`` tuple, where ``device`` is the string
        "cuda" when a GPU is available and "cpu" otherwise. The model has been
        moved to that device and switched to evaluation mode.
    """
    # Configuration and processor of the base SAM checkpoint.
    config = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    # Instantiate the bare architecture, then fill in the saved weights.
    model = SamModel(config=config)
    # Load onto the CPU first so the checkpoint can be read on machines
    # without a GPU. NOTE: torch.load unpickles the file, so the checkpoint
    # must come from a trusted source.
    state_dict = torch.load(
        str(dir / "checkpoint.pth"), map_location=torch.device("cpu")
    )
    model.load_state_dict(state_dict)

    # Use the GPU when one is available, otherwise stay on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # A module constructed directly (rather than via from_pretrained) starts
    # in training mode; switch to evaluation mode because the app only runs
    # inference.
    model.eval()
    return model, processor, device
def show_mask(mask, ax, random_color=False):
    """Draw a mask on an axes as a translucent coloured overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before colouring.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes) that receives
            the RGBA overlay.
        random_color: When true, use a random RGB colour; otherwise use a
            fixed blue. The alpha channel is 0.6 in both cases.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    # Append the fixed 0.6 alpha to obtain an RGBA colour.
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4): coloured where mask != 0.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def generate_input_points(image, grid_size=10):
    """Build a regular grid of point prompts covering an image.

    The grid has ``grid_size`` evenly spaced positions along each axis, from
    pixel 0 to the last pixel of that axis, giving ``grid_size ** 2`` points.
    The result is laid out as
    ``(batch_size, point_batch_size, num_points_per_image, 2)``: one image,
    one point set, and all grid points in that set.

    Args:
        image: Object exposing ``width`` and ``height`` in pixels (for example
            a PIL image).
        grid_size: Number of grid positions along each axis.

    Returns:
        An int64 ``torch.Tensor`` of shape
        ``(1, 1, grid_size * grid_size, 2)``. The last dimension holds the
        ``(x, y)`` coordinates of a point; points are ordered row by row,
        with x varying fastest.
    """
    # Span each axis with its own extent so that every point lies inside the
    # image even when it is not square (for a square image this is the same
    # grid as spanning both axes with the larger dimension).
    xs = np.linspace(0, image.width - 1, grid_size)
    ys = np.linspace(0, image.height - 1, grid_size)

    # grid_x[i, j] == xs[j] and grid_y[i, j] == ys[i].
    grid_x, grid_y = np.meshgrid(xs, ys)

    # Pair up the coordinates and truncate them to whole pixel indices.
    points = np.stack([grid_x, grid_y], axis=-1).astype(np.int64)

    # Flatten the (grid_size, grid_size) grid into a single set of points.
    return torch.from_numpy(points).reshape(1, 1, grid_size * grid_size, 2)
| ### SERVER ### | |
def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server function: segment the uploaded image and plot the result.

    Registers the ``prediction_plots`` output declared in the UI. Until a
    file is uploaded through the ``tile_image`` input, a placeholder message
    is drawn instead.

    Args:
        input: Shiny inputs; only ``tile_image`` is read.
        output: Shiny outputs; ``prediction_plots`` is registered on it.
        session: The Shiny session (not used directly).
    """
    # Load the model, processor and device when the server function runs, so
    # they are shared by every prediction made inside it.
    model, processor, device = load_model()

    def uploaded_image_path() -> str:
        """Return the path of the uploaded file, or "" if nothing is uploaded."""
        upload = input.tile_image()
        if upload is None:
            return ""
        # The input is declared with multiple=False, so use the first entry.
        return upload[0]['datapath']

    def process_image():
        """Run the model on the uploaded image.

        Returns:
            ``(image, prob, prediction)``: the preprocessed PIL image, the
            sigmoid of the predicted mask as a numpy array, and that array
            thresholded at 0.5 as uint8.
        """
        # Read the uploaded file and turn it into a 256x256 RGB image.
        with open(uploaded_image_path(), 'rb') as f:
            image = bytes_to_pil_image(f.read())

        # Prompt the model with a regular grid of points over the image.
        input_points = generate_input_points(image)
        inputs = processor(image, input_points=input_points, return_tensors="pt")
        # Move every input tensor to the device the model lives on.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass without gradient tracking (inference only).
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)

        # Logits -> probabilities, then drop the singleton dimensions.
        prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        prob = prob.cpu().numpy().squeeze()
        # Soft mask -> hard mask.
        prediction = (prob > 0.5).astype(np.uint8)
        return image, prob, prediction

    # A reactive calc caches its value until the upload changes, so the model
    # is not re-run every time the plot is redrawn.
    @reactive.Calc
    def get_predictions():
        """Return predictions for the uploaded image, or three Nones if none."""
        if input.tile_image() is not None:
            return process_image()
        return None, None, None

    # Without these decorators the function is never registered with Shiny and
    # the "prediction_plots" output declared in the UI stays empty.
    @output
    @render.plot
    def prediction_plots():
        """Plot the image, probability map, hard mask and mask overlay."""
        image, prob, prediction = get_predictions()

        # Nothing uploaded yet: show a placeholder message instead.
        if image is None or prob is None or prediction is None:
            fig, ax = plt.subplots()
            ax.text(0.5, 0.5, "Upload an image to see predictions. Predictions will take a few moments to load.", fontsize=12, ha="center")
            ax.axis("off")
            fig.tight_layout()
            return fig

        fig, axes = plt.subplots(1, 4, figsize=(15, 30))
        panels = [
            ("Image", image),
            ("Probability Map", prob),
            ("Prediction", prediction),
            ("Predicted Mask", image),  # the mask overlay is added below
        ]
        for ax, (title, data) in zip(axes, panels):
            ax.imshow(data)
            ax.set_title(title)
            # Hide axis ticks and labels.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        # Last panel: translucent predicted mask on top of the image.
        show_mask(prediction, axes[3])

        fig.tight_layout()
        return fig
# Assemble the Shiny application: UI tree, server function, and the "www"
# folder passed as static assets.
app = App(app_ui, server, static_assets=str(www_dir))