Spaces:
Running
Running
| import torch | |
| import tensorflow as tf | |
# Inference is pinned to the CPU for both frameworks: PyTorch gets an explicit
# CPU device, and TensorFlow is told that no GPU is visible.
# A GPU alternative is intentionally left disabled here; it would select CUDA
# when available, cap PyTorch at 30% of the device memory and enable
# TensorFlow memory growth on the first GPU.
device = torch.device("cpu")
print(f"Torch device: {device}")

# An empty device list hides every GPU from TensorFlow.
tf.config.set_visible_devices([], 'GPU')
| from segment_anything import SamPredictor, sam_model_registry | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import numpy as np | |
| from math import ceil | |
| import os | |
| from huggingface_hub import snapshot_download | |
# Fetch the model files from the Hugging Face Hub unless a local 'model'
# directory already exists (its presence alone skips the download).
if not os.path.exists('model'):
    REPO_ID = 'Serrelab/SAM_Leaves'
    token = os.environ.get('READ_TOKEN')
    # Never print the token itself: stdout ends up in logs, and the token is a
    # credential. Reporting whether it is set is enough for debugging.
    print(f"Read token set: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='model', local_dir='model')
# The registry builder below is given only a checkpoint path, so there is no
# argument through which a map_location could be passed. torch.load is
# therefore wrapped temporarily so that the checkpoint is mapped onto `device`.
original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the real ``torch.load`` with ``map_location`` forced to ``device``."""
    kwargs['map_location'] = device
    return original_torch_load(*args, **kwargs)


torch.load = patched_torch_load
try:
    model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
    sam = sam_model_registry["default"](model_path)
    sam.to(device)
    predictor = SamPredictor(sam)
finally:
    # Always undo the monkeypatch, even if building the model raised, so the
    # rest of the process never sees the patched torch.load.
    torch.load = original_torch_load
| from torch.nn import functional as F | |
def pad_gt(x):
    """Pad the last two axes of ``x`` up to the SAM image-encoder input size.

    Padding is appended after the existing data on both axes; nothing is added
    in front. The target side length is ``sam.image_encoder.img_size``.

    Args:
        x: tensor whose last two axes are the ones to pad.

    Returns:
        The padded tensor.
    """
    side = sam.image_encoder.img_size
    rows, cols = x.shape[-2:]
    # F.pad takes the padding for the last axis first: (before, after) for the
    # last axis, then (before, after) for the second-to-last axis.
    return F.pad(x, (0, side - cols, 0, side - rows))
def preprocess(img):
    """Turn an image into a model-ready batch for the SAM image encoder.

    The image is cast to uint8, passed through the predictor's resize
    transform, moved to ``device`` as a channels-first batch of one, and
    finally handed to ``sam.preprocess``.

    Args:
        img: array-like image; assumed to be channels-last with values in the
            0-255 range, since it is cast to uint8 — TODO confirm with callers.

    Returns:
        A ``(batch, spatial_shape)`` tuple, where ``spatial_shape`` holds the
        two spatial dimensions of the image right after the resize transform
        (before ``sam.preprocess`` is applied).
    """
    pixels = np.array(img).astype(np.uint8)
    resized = predictor.transform.apply_image(pixels)
    resized_shape = resized.shape

    # Channels-last array -> channels-first tensor with a leading batch axis.
    batch = torch.as_tensor(resized).to(device)
    batch = batch.permute(2, 0, 1).contiguous()[None, :, :, :]
    batch = sam.preprocess(batch)

    # Keep only the spatial part of the post-resize shape.
    rank = len(resized_shape)
    if rank == 3:
        resized_shape = resized_shape[:2]
    elif rank == 4:
        resized_shape = resized_shape[1:3]
    return batch, resized_shape
def normalize(img):
    """Linearly rescale ``img`` so its minimum maps to -1 and its maximum to 1.

    NOTE: there is no guard for a constant input. After the minimum is
    subtracted the maximum is zero, so the division below divides by zero.
    """
    shifted = img - tf.math.reduce_min(img)
    unit_range = shifted / tf.math.reduce_max(shifted)
    return unit_range * 2.0 - 1.0
def resize(img):
    """Bicubic-resize ``img`` to ``SIZE x SIZE``; the default for all ``pi`` outputs.

    ``SIZE`` is looked up in module scope at call time.
    """
    target = (SIZE, SIZE)
    return tf.image.resize(img, target, method="bicubic")
def smooth_mask(mask, ds=20):
    """Smooth ``mask`` by bicubic down-sampling to ``ds x ds`` and back up.

    Args:
        mask: image-like tensor/array; its first two axes are taken as the
            size to restore after down-sampling.
        ds: side length of the intermediate low-resolution grid.

    Returns:
        The smoothed mask, resized back to the original first two dimensions.
    """
    dims = tf.shape(mask)
    full_size = (dims[0], dims[1])
    coarse = tf.image.resize(mask, (ds, ds), method="bicubic")
    return tf.image.resize(coarse, full_size, method="bicubic")
def pi(img, mask):
    """Build four ``SIZE x SIZE`` views of ``img`` guided by ``mask``.

    The mask is smoothed and averaged over its last axis, the image is zeroed
    wherever the smoothed mask is <= 0.01, and four views are produced:

    1. the whole masked image resized to ``SIZE x SIZE``;
    2. a crop to the bounding box of the region where the mask is > 0.15;
    3. the same bounding box shrunk by a quarter of its extent on every side;
    4. a window of up to ``2 * SIZE`` per axis centred on the first location
       of the mask maximum, clamped to the image borders.

    When the bounding box is not larger than 50 pixels along both axes, views
    2 and 3 fall back to the plain resize and to an aspect-preserving padded
    resize of the whole image.

    ``SIZE`` is looked up in module scope at call time.

    Args:
        img: image indexed as ``[axis0, axis1, channel]``.
        mask: torch tensor (it is converted with ``.cpu().numpy()``); assumed
            to share the first two axes with ``img`` and to carry a trailing
            channel axis — TODO confirm with callers.

    Returns:
        A float32 tensor stacking the four views in the order listed above.
    """
    img = tf.cast(img, tf.float32)
    shape = tf.shape(img)
    # int64 so the sizes can be compared with the int64 indices from tf.where.
    dim0, dim1 = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)

    mask = smooth_mask(mask.cpu().numpy().astype(float))
    mask = tf.reduce_mean(mask, -1)

    # Zero out everything outside the (smoothed) mask.
    img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]
    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)

    # Bounding box of the confidently masked region.
    anchors = tf.where(mask > 0.15)
    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
    anchor_ymax = tf.math.reduce_max(anchors[:, 1])

    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])
        # Second anchor: the same box shrunk by a quarter on every side.
        delta_x = (anchor_xmax - anchor_xmin) // 4
        delta_y = (anchor_ymax - anchor_ymin) // 4
        img_anchor_2 = resize(img[anchor_xmin + delta_x:anchor_xmax - delta_x,
                                  anchor_ymin + delta_y:anchor_ymax - delta_y])
    else:
        # Box too small to crop: fall back to whole-image views. The padded
        # resize is only needed here, so it is only computed here.
        img_anchor_1 = img_resize
        img_anchor_2 = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # Window around the first location where the mask reaches its maximum.
    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]
    img_max_zoom1 = img[tf.math.maximum(anchor_max_x - SIZE, 0): tf.math.minimum(anchor_max_x + SIZE, dim0),
                        tf.math.maximum(anchor_max_y - SIZE, 0): tf.math.minimum(anchor_max_y + SIZE, dim1)]
    img_max_zoom1 = resize(img_max_zoom1)

    # A tighter SIZE//2 window used to be sliced here with unclamped (possibly
    # negative) bounds, but its result was never returned; it has been removed.
    return tf.cast([
        img_resize,
        img_anchor_1,
        img_anchor_2,
        img_max_zoom1,
    ], tf.float32)
def one_step_inference(x):
    """Run SAM once on ``x`` without any prompt and return the mask logits.

    The image is preprocessed, encoded, and decoded with empty point/box/mask
    prompts. The low-resolution output is then upsampled to the encoder input
    size, cropped to the area the resized image occupied, and resized back to
    the spatial size of the input.

    Args:
        x: image with 3 axes (spatial size read from the first two) or 4 axes
            (spatial size read from axes 1 and 2).

    Returns:
        A torch tensor on ``device`` holding a single mask per image
        (``multimask_output=False``) at the input's spatial size. Values are
        the raw decoder outputs; no thresholding is applied here.

    Raises:
        ValueError: if ``x`` has neither 3 nor 4 axes.
    """
    rank = len(x.shape)
    if rank == 3:
        original_size = x.shape[:2]
    elif rank == 4:
        original_size = x.shape[1:3]
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"expected an image with 3 or 4 axes, got shape {tuple(x.shape)}")

    x, intermediate_shape = preprocess(x)

    with torch.no_grad():
        image_embedding = sam.image_encoder(x)
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(points=None, boxes=None, masks=None)
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

    # Upsample to the encoder input size (the same attribute pad_gt uses), then
    # drop the region beyond the resized image before restoring the original size.
    encoder_side = sam.image_encoder.img_size
    mask = F.interpolate(low_res_masks, (encoder_side, encoder_side))
    mask = mask[:, :, :intermediate_shape[0], :intermediate_shape[1]]
    mask = F.interpolate(mask, (int(original_size[0]), int(original_size[1])))
    return mask.to(device)
def segmentation_sam(x, SIZE=384):
    """Segment ``x`` with SAM and return a rendering of the mask overlay.

    The image is resized with padding to ``SIZE x SIZE``, segmented with
    ``one_step_inference``, and drawn with the thresholded mask (> 0.2) on top
    in the 'jet' colormap at 40% opacity. The figure is also written to
    'test.png' (saved before the axes are hidden).

    Args:
        x: input image; assumed to hold pixel values in the 0-255 range, as
            the downstream preprocessing casts it to uint8 — TODO confirm.
        SIZE: side length the image is resized/padded to before inference.

    Returns:
        A uint8 numpy array of shape ``(height, width, 3)`` with the RGB
        pixels of the rendered figure.
    """
    x = tf.image.resize_with_pad(x, SIZE, SIZE)
    predicted_mask = one_step_inference(x)

    img = x.numpy()
    mask = predicted_mask.cpu().numpy()[0][0] > 0.2

    fig, ax = plt.subplots()
    try:
        # resize_with_pad returns floats, and matplotlib clips float RGB data
        # to [0, 1]; convert to uint8 so 0-255 pixel values display correctly.
        ax.imshow(np.clip(img, 0, 255).astype(np.uint8))
        ax.imshow(mask, cmap='jet', alpha=0.4)
        fig.savefig('test.png')
        ax.axis('off')
        fig.canvas.draw()
        # tostring_rgb() is deprecated/removed in recent Matplotlib releases;
        # read the RGBA buffer instead and drop the alpha channel. np.array
        # copies, so the result does not alias the canvas after it is closed.
        rgba = np.array(fig.canvas.buffer_rgba())
        data = np.ascontiguousarray(rgba[..., :3])
    finally:
        # Close this specific figure, even on error, so figures do not pile up.
        plt.close(fig)
    return data