Spaces:
Running
Running
:initial commit
Browse files- .gitignore +8 -0
- app.py +121 -0
- inference_beit.py +0 -0
- inference_diffuser.py +0 -0
- inference_resnet.py +167 -0
- inference_sam.py +175 -0
- labels.py +175 -0
- pre-requirements.txt +6 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
venv/
|
| 3 |
+
images/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.pyd
|
| 7 |
+
*.swp
|
| 8 |
+
__pycache__/
|
app.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import subprocess
import os

# On HF Spaces, heavyweight/runtime-only dependencies are installed at startup.
if os.getenv('SYSTEM') == 'spaces':
    subprocess.call('pip install tensorflow==2.9'.split())
    subprocess.call('pip install keras==2.9'.split())
    # BUG FIX: the original passed this whole command as one string (no
    # .split()), so subprocess.call looked for an executable literally named
    # "pip install git+...", and the install silently never happened.
    subprocess.call('pip install git+https://github.com/facebookresearch/segment-anything.git'.split())
    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
    subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())

import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import glob
from inference_sam import segmentation_sam

import pathlib

# Fetch the example images once; skipped when the folder already exists.
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='dataset', local_dir='images')
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def segment_image(input_image):
|
| 29 |
+
img = segmentation_sam(input_image)
|
| 30 |
+
return img
|
| 31 |
+
|
| 32 |
+
def classify_image(input_image, model_name):
|
| 33 |
+
if 'Rock 170' ==model_name:
|
| 34 |
+
from inference_resnet import inference_resnet_finer
|
| 35 |
+
result = inference_resnet_finer(input_image,model_name,n_classes=171)
|
| 36 |
+
return result
|
| 37 |
+
elif 'Mummified 170' ==model_name:
|
| 38 |
+
from inference_resnet import inference_resnet_finer
|
| 39 |
+
result = inference_resnet_finer(input_image,model_name,n_classes=170)
|
| 40 |
+
return result
|
| 41 |
+
if 'Fossils 19' ==model_name:
|
| 42 |
+
from inference_beit import inference_dino
|
| 43 |
+
return inference_dino(input_image,model_name)
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
def find_closest(input_image):
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
| 51 |
+
|
| 52 |
+
with gr.Tab(" 19 Classes Support"):
|
| 53 |
+
|
| 54 |
+
with gr.Row():
|
| 55 |
+
with gr.Column():
|
| 56 |
+
input_image = gr.Image(label="Input")
|
| 57 |
+
classify_image_button = gr.Button("Classify Image")
|
| 58 |
+
|
| 59 |
+
with gr.Column():
|
| 60 |
+
segmented_image = gr.outputs.Image(label="SAM output",type='numpy')
|
| 61 |
+
segment_button = gr.Button("Segment Image")
|
| 62 |
+
#classify_segmented_button = gr.Button("Classify Segmented Image")
|
| 63 |
+
|
| 64 |
+
with gr.Column():
|
| 65 |
+
drop_2 = gr.Dropdown(
|
| 66 |
+
["Mummified 170", "Rock 170", "Fossils 19"],
|
| 67 |
+
multiselect=False,
|
| 68 |
+
value=["Rock 170"],
|
| 69 |
+
label="Model",
|
| 70 |
+
interactive=True,
|
| 71 |
+
)
|
| 72 |
+
class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
|
| 73 |
+
|
| 74 |
+
with gr.Row():
|
| 75 |
+
|
| 76 |
+
paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
|
| 77 |
+
samples=[[path.as_posix()] for path in paths if 'fossils' in str(path) ][:19]
|
| 78 |
+
examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=10,label='Fossils Examples from the dataset')
|
| 79 |
+
samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
|
| 80 |
+
examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
|
| 81 |
+
|
| 82 |
+
with gr.Accordion("Using Diffuser"):
|
| 83 |
+
with gr.Column():
|
| 84 |
+
prompt = gr.Textbox(lines=1, label="Prompt")
|
| 85 |
+
output_image = gr.Image(label="Output")
|
| 86 |
+
generate_button = gr.Button("Generate Leave")
|
| 87 |
+
with gr.Column():
|
| 88 |
+
class_predicted2 = gr.Label(label='Class Predicted from diffuser')
|
| 89 |
+
classify_button = gr.Button("Classify Image")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
with gr.Accordion("Explanations "):
|
| 93 |
+
gr.Markdown("Computing Explanations from the model")
|
| 94 |
+
with gr.Row():
|
| 95 |
+
original_input = gr.Image(label="Original Frame")
|
| 96 |
+
saliency = gr.Image(label="saliency")
|
| 97 |
+
gradcam = gr.Image(label='gradcam')
|
| 98 |
+
guided_gradcam = gr.Image(label='guided gradcam')
|
| 99 |
+
guided_backprop = gr.Image(label='guided backprop')
|
| 100 |
+
generate_explanations = gr.Button("Generate Explanations")
|
| 101 |
+
|
| 102 |
+
with gr.Accordion('Closest Images'):
|
| 103 |
+
gr.Markdown("Finding the closest images in the dataset")
|
| 104 |
+
with gr.Row():
|
| 105 |
+
closest_image_0 = gr.Image(label='Closest Image')
|
| 106 |
+
closest_image_1 = gr.Image(label='Second Closest Image')
|
| 107 |
+
closest_image_2 = gr.Image(label='Third Closest Image')
|
| 108 |
+
closest_image_3 = gr.Image(label='Forth Closest Image')
|
| 109 |
+
closest_image_4 = gr.Image(label='Fifth Closest Image')
|
| 110 |
+
find_closest_btn = gr.Button("Find Closest Images")
|
| 111 |
+
|
| 112 |
+
segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
|
| 113 |
+
classify_image_button.click(classify_image, inputs=[input_image,drop_2], outputs=class_predicted)
|
| 114 |
+
#classify_segmented_button.click(classify_image, inputs=[segmented_image,drop_2], outputs=class_predicted)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
demo.launch(debug=True)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
inference_beit.py
ADDED
|
File without changes
|
inference_diffuser.py
ADDED
|
File without changes
|
inference_resnet.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import tensorflow as tf

# Enable on-demand GPU memory allocation so TF does not grab the whole card.
# BUG FIX: the original indexed gpu_devices[0] unconditionally, which raises
# IndexError on CPU-only hosts; guard on the (possibly empty) device list.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)
from keras.applications import resnet
import tensorflow.keras.layers as L
import os

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_170
import numpy as np


# Download the classifier weights from the Hub once, at import time.
REPO_ID = 'Serrelab/fossil_classification_models'
snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='model', local_dir='model_classification')
def get_model(base_arch='Nasnet',weights='imagenet',input_shape=(600,600,3),classes=64500):
|
| 21 |
+
|
| 22 |
+
if base_arch == 'Nasnet':
|
| 23 |
+
base_model = tf.keras.applications.NASNetLarge(
|
| 24 |
+
input_shape=input_shape,
|
| 25 |
+
include_top=False,
|
| 26 |
+
weights=weights,
|
| 27 |
+
input_tensor=None,
|
| 28 |
+
pooling=None,
|
| 29 |
+
|
| 30 |
+
)
|
| 31 |
+
elif base_arch == 'Resnet50v2':
|
| 32 |
+
base_model = tf.keras.applications.ResNet50V2(weights=weights,
|
| 33 |
+
include_top=False,
|
| 34 |
+
pooling='avg',
|
| 35 |
+
input_shape=input_shape)
|
| 36 |
+
elif base_arch == 'Resnet50v2_finer':
|
| 37 |
+
base_model = tf.keras.applications.ResNet50V2(weights=weights,
|
| 38 |
+
include_top=False,
|
| 39 |
+
pooling='avg',
|
| 40 |
+
input_shape=input_shape)
|
| 41 |
+
base_model = resnet.stack2(base_model.output, 512, 2, name="conv6")
|
| 42 |
+
base_model = resnet.stack2(base_model, 512, 2, name="conv7")
|
| 43 |
+
base_model = tf.keras.Model(base_model.input,base_model)
|
| 44 |
+
|
| 45 |
+
model = tf.keras.Sequential([
|
| 46 |
+
base_model,
|
| 47 |
+
L.Dense(classes,activation='softmax')
|
| 48 |
+
])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
model.compile(optimizer='adam',
|
| 53 |
+
loss='categorical_crossentropy',
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
return model
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_triplet_model(input_shape = (600, 600, 3),
|
| 60 |
+
embedding_units = 256,
|
| 61 |
+
embedding_depth = 2,
|
| 62 |
+
backbone_class=tf.keras.applications.ResNet50V2,
|
| 63 |
+
nb_classes = 19,load_weights=False,finer_model=False,backbone_name ='Resnet50v2'):
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
backbone = backbone_class(input_shape=input_shape, include_top=False)
|
| 67 |
+
if load_weights:
|
| 68 |
+
model = get_model(backbone_name,input_shape=input_shape)
|
| 69 |
+
model.load_weights('/users/irodri15/data/irodri15/Fossils/Models/pretrained-herbarium/Resnet50v2_NO_imagenet_None_best_1600.h5')
|
| 70 |
+
trw = model.layers[0].get_weights()
|
| 71 |
+
backbone.set_weights(trw)
|
| 72 |
+
if finer_model:
|
| 73 |
+
base_model = resnet.stack2(backbone.output, 512, 2, name="conv6")
|
| 74 |
+
base_model = resnet.stack2(base_model, 512, 2, name="conv7")
|
| 75 |
+
backbone = tf.keras.Model(backbone.input,base_model)
|
| 76 |
+
|
| 77 |
+
features = GlobalAveragePooling2D()(backbone.output)
|
| 78 |
+
|
| 79 |
+
embedding_head = features
|
| 80 |
+
for embed_i in range(embedding_depth):
|
| 81 |
+
embedding_head = Dense(embedding_units, activation="relu" if embed_i < embedding_depth-1 else "linear")(embedding_head)
|
| 82 |
+
embedding_head = tf.nn.l2_normalize(embedding_head, -1, epsilon=1e-5)
|
| 83 |
+
|
| 84 |
+
logits_head = Dense(nb_classes)(features)
|
| 85 |
+
|
| 86 |
+
model = tf.keras.Model(backbone.input, [embedding_head, logits_head])
|
| 87 |
+
model.compile(loss='cce',metrics=['accuracy'])
|
| 88 |
+
#model.summary()
|
| 89 |
+
|
| 90 |
+
return model
|
| 91 |
+
|
| 92 |
+
load_size = 600
|
| 93 |
+
crop_size = 600
|
| 94 |
+
def _clever_crop(img: tf.Tensor,
|
| 95 |
+
target_size: Tuple[int]=(128,128),
|
| 96 |
+
grayscale: bool=False
|
| 97 |
+
) -> tf.Tensor:
|
| 98 |
+
"""[summary]
|
| 99 |
+
Args:
|
| 100 |
+
img (tf.Tensor): [description]
|
| 101 |
+
target_size (Tuple[int], optional): [description]. Defaults to (128,128).
|
| 102 |
+
grayscale (bool, optional): [description]. Defaults to False.
|
| 103 |
+
Returns:
|
| 104 |
+
tf.Tensor: [description]
|
| 105 |
+
"""
|
| 106 |
+
maxside = tf.math.maximum(tf.shape(img)[0],tf.shape(img)[1])
|
| 107 |
+
minside = tf.math.minimum(tf.shape(img)[0],tf.shape(img)[1])
|
| 108 |
+
new_img = img
|
| 109 |
+
|
| 110 |
+
if tf.math.divide(maxside,minside) > 1.2:
|
| 111 |
+
repeating = tf.math.floor(tf.math.divide(maxside,minside))
|
| 112 |
+
new_img = img
|
| 113 |
+
if tf.math.equal(tf.shape(img)[1],minside):
|
| 114 |
+
for _ in range(int(repeating)):
|
| 115 |
+
new_img = tf.concat((new_img, img), axis=1)
|
| 116 |
+
|
| 117 |
+
if tf.math.equal(tf.shape(img)[0],minside):
|
| 118 |
+
for _ in range(int(repeating)):
|
| 119 |
+
new_img = tf.concat((new_img, img), axis=0)
|
| 120 |
+
new_img = tf.image.rot90(new_img)
|
| 121 |
+
else:
|
| 122 |
+
new_img = img
|
| 123 |
+
repeating = 0
|
| 124 |
+
img = tf.image.resize(new_img, target_size)
|
| 125 |
+
if grayscale:
|
| 126 |
+
img = tf.image.rgb_to_grayscale(img)
|
| 127 |
+
img = tf.image.grayscale_to_rgb(img)
|
| 128 |
+
|
| 129 |
+
return img,repeating
|
| 130 |
+
|
| 131 |
+
def preprocess(img,size=600):
|
| 132 |
+
img = np.array(img, np.float32) / 255.0
|
| 133 |
+
img = tf.image.resize(img, (size, size))
|
| 134 |
+
return np.array(img, np.float32)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def select_top_n(preds,n=10):
|
| 138 |
+
top_n = np.argsort(preds)[-n:][::-1]
|
| 139 |
+
return top_n
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def parse_results(top_n,logits):
|
| 143 |
+
results = {}
|
| 144 |
+
for n in top_n:
|
| 145 |
+
label = lookup_170[n]
|
| 146 |
+
results[label] = float(logits[n])
|
| 147 |
+
return results
|
| 148 |
+
|
| 149 |
+
def inference_resnet_finer(x,type_model,size=576,n_classes=170,n_top=10):
|
| 150 |
+
|
| 151 |
+
model = get_triplet_model(input_shape = (size, size, 3),
|
| 152 |
+
embedding_units = 256,
|
| 153 |
+
embedding_depth = 2,
|
| 154 |
+
backbone_class=tf.keras.applications.ResNet50V2,
|
| 155 |
+
nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
|
| 156 |
+
if type_model=='Mummified 170':
|
| 157 |
+
model.load_weights('model_classification/mummified-170.h5')
|
| 158 |
+
elif type_model=='Rock 170':
|
| 159 |
+
model.load_weights('model_classification/rock-170.h5')
|
| 160 |
+
else:
|
| 161 |
+
return 'Error'
|
| 162 |
+
cropped = _clever_crop(x,(size,size))[0]
|
| 163 |
+
prep = preprocess(cropped,size=size)
|
| 164 |
+
logits = tf.nn.softmax(model.predict(np.array([prep]))[1][0]).cpu().numpy()
|
| 165 |
+
top_n = select_top_n(logits,n=n_top)
|
| 166 |
+
|
| 167 |
+
return parse_results(top_n,logits)
|
inference_sam.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
# Cap this process at 30% of GPU:0 so SAM can coexist with the TF models.
# NOTE(review): this call, and sam.cuda() below, require a CUDA device —
# the module as a whole assumes a GPU host.
torch.cuda.set_per_process_memory_fraction(0.3, device=0)
import tensorflow as tf
# BUG FIX: the original indexed gpu_devices[0] unconditionally, raising
# IndexError whenever TF saw no GPU; guard on the (possibly empty) list.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt
import cv2
import numpy as np
from math import ceil
import os
from huggingface_hub import snapshot_download

# Fetch the fine-tuned SAM checkpoint from the Hub at import time.
REPO_ID = 'Serrelab/SAM_Leaves'
snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='model', local_dir='model')

sam = sam_model_registry["default"]("model/sam_02-06_dice_mse_0.pth")
sam.cuda()
predictor = SamPredictor(sam)


from torch.nn import functional as F
| 26 |
+
def pad_gt(x):
|
| 27 |
+
h, w = x.shape[-2:]
|
| 28 |
+
padh = sam.image_encoder.img_size - h
|
| 29 |
+
padw = sam.image_encoder.img_size - w
|
| 30 |
+
x = F.pad(x, (0, padw, 0, padh))
|
| 31 |
+
return x
|
| 32 |
+
|
| 33 |
+
def preprocess(img):
|
| 34 |
+
|
| 35 |
+
img = np.array(img).astype(np.uint8)
|
| 36 |
+
|
| 37 |
+
#assert img.max() > 127.0
|
| 38 |
+
|
| 39 |
+
img_preprocess = predictor.transform.apply_image(img)
|
| 40 |
+
intermediate_shape = img_preprocess.shape
|
| 41 |
+
|
| 42 |
+
img_preprocess = torch.as_tensor(img_preprocess).cuda()
|
| 43 |
+
img_preprocess = img_preprocess.permute(2, 0, 1).contiguous()[None, :, :, :]
|
| 44 |
+
|
| 45 |
+
img_preprocess = sam.preprocess(img_preprocess)
|
| 46 |
+
if len(intermediate_shape) == 3:
|
| 47 |
+
intermediate_shape = intermediate_shape[:2]
|
| 48 |
+
elif len(intermediate_shape) == 4:
|
| 49 |
+
intermediate_shape = intermediate_shape[1:3]
|
| 50 |
+
|
| 51 |
+
return img_preprocess, intermediate_shape
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def normalize(img):
|
| 55 |
+
img = img - tf.math.reduce_min(img)
|
| 56 |
+
img = img / tf.math.reduce_max(img)
|
| 57 |
+
img = img * 2.0 - 1.0
|
| 58 |
+
return img
|
| 59 |
+
|
| 60 |
+
def resize(img):
|
| 61 |
+
# default resize function for all pi outputs
|
| 62 |
+
return tf.image.resize(img, (SIZE, SIZE), method="bicubic")
|
| 63 |
+
|
| 64 |
+
def smooth_mask(mask, ds=20):
|
| 65 |
+
shape = tf.shape(mask)
|
| 66 |
+
w, h = shape[0], shape[1]
|
| 67 |
+
return tf.image.resize(tf.image.resize(mask, (ds, ds), method="bicubic"), (w, h), method="bicubic")
|
| 68 |
+
|
| 69 |
+
def pi(img, mask):
|
| 70 |
+
img = tf.cast(img, tf.float32)
|
| 71 |
+
|
| 72 |
+
shape = tf.shape(img)
|
| 73 |
+
w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)
|
| 74 |
+
|
| 75 |
+
mask = smooth_mask(mask.cpu().numpy().astype(float))
|
| 76 |
+
mask = tf.reduce_mean(mask, -1)
|
| 77 |
+
|
| 78 |
+
img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
|
| 82 |
+
img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
|
| 83 |
+
|
| 84 |
+
# building 2 anchors
|
| 85 |
+
anchors = tf.where(mask > 0.15)
|
| 86 |
+
anchor_xmin = tf.math.reduce_min(anchors[:, 0])
|
| 87 |
+
anchor_xmax = tf.math.reduce_max(anchors[:, 0])
|
| 88 |
+
anchor_ymin = tf.math.reduce_min(anchors[:, 1])
|
| 89 |
+
anchor_ymax = tf.math.reduce_max(anchors[:, 1])
|
| 90 |
+
|
| 91 |
+
if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
|
| 92 |
+
|
| 93 |
+
img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])
|
| 94 |
+
|
| 95 |
+
delta_x = (anchor_xmax - anchor_xmin) // 4
|
| 96 |
+
delta_y = (anchor_ymax - anchor_ymin) // 4
|
| 97 |
+
img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
|
| 98 |
+
anchor_ymin+delta_y:anchor_ymax-delta_y]
|
| 99 |
+
img_anchor_2 = resize(img_anchor_2)
|
| 100 |
+
else:
|
| 101 |
+
img_anchor_1 = img_resize
|
| 102 |
+
img_anchor_2 = img_pad
|
| 103 |
+
|
| 104 |
+
# building the anchors max
|
| 105 |
+
anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
|
| 106 |
+
anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]
|
| 107 |
+
|
| 108 |
+
img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
|
| 109 |
+
tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]
|
| 110 |
+
|
| 111 |
+
img_max_zoom1 = resize(img_max_zoom1)
|
| 112 |
+
img_max_zoom2 = img[anchor_max_x-SIZE//2:anchor_max_x+SIZE//2,
|
| 113 |
+
anchor_max_y-SIZE//2:anchor_max_y+SIZE//2]
|
| 114 |
+
#img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, w),
|
| 115 |
+
# tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, h)]
|
| 116 |
+
#tf.print(img_max_zoom2.shape)
|
| 117 |
+
#img_max_zoom2 = resize(img_max_zoom2)
|
| 118 |
+
return tf.cast([
|
| 119 |
+
img_resize,
|
| 120 |
+
#img_pad,
|
| 121 |
+
img_anchor_1,
|
| 122 |
+
img_anchor_2,
|
| 123 |
+
img_max_zoom1,
|
| 124 |
+
#img_max_zoom2,
|
| 125 |
+
], tf.float32)
|
| 126 |
+
|
| 127 |
+
def one_step_inference(x):
|
| 128 |
+
if len(x.shape) == 3:
|
| 129 |
+
original_size = x.shape[:2]
|
| 130 |
+
elif len(x.shape) == 4:
|
| 131 |
+
original_size = x.shape[1:3]
|
| 132 |
+
|
| 133 |
+
x, intermediate_shape = preprocess(x)
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
image_embedding = sam.image_encoder(x)
|
| 137 |
+
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
sparse_embeddings, dense_embeddings = sam.prompt_encoder(points = None, boxes = None,masks = None)
|
| 140 |
+
low_res_masks, iou_predictions = sam.mask_decoder(
|
| 141 |
+
image_embeddings=image_embedding,
|
| 142 |
+
image_pe=sam.prompt_encoder.get_dense_pe(),
|
| 143 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 144 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 145 |
+
multimask_output=False,
|
| 146 |
+
)
|
| 147 |
+
if len(x.shape) == 3:
|
| 148 |
+
input_size = tuple(x.shape[:2])
|
| 149 |
+
elif len(x.shape) == 4:
|
| 150 |
+
input_size = tuple(x.shape[-2:])
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
#upscaled_masks = sam.postprocess_masks(low_res_masks, input_size, original_size).cuda()
|
| 154 |
+
mask = F.interpolate(low_res_masks, (1024, 1024))[:, :, :intermediate_shape[0], :intermediate_shape[1]]
|
| 155 |
+
mask = F.interpolate(mask, (original_size[0], original_size[1]))
|
| 156 |
+
|
| 157 |
+
return mask
|
| 158 |
+
|
| 159 |
+
def segmentation_sam(x,SIZE=384):
|
| 160 |
+
|
| 161 |
+
x = tf.image.resize_with_pad(x, SIZE, SIZE)
|
| 162 |
+
predicted_mask = one_step_inference(x)
|
| 163 |
+
fig, ax = plt.subplots()
|
| 164 |
+
img = x.cpu().numpy()
|
| 165 |
+
mask = predicted_mask.cpu().numpy()[0][0]>0.2
|
| 166 |
+
ax.imshow(img)
|
| 167 |
+
ax.imshow(mask, cmap='jet', alpha=0.4)
|
| 168 |
+
plt.savefig('test.png')
|
| 169 |
+
ax.axis('off')
|
| 170 |
+
fig.canvas.draw()
|
| 171 |
+
# Now we can save it to a numpy array.
|
| 172 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 173 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
| 174 |
+
plt.close()
|
| 175 |
+
return data
|
labels.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Class-index -> plant family name for the 170/171-class fossil classifiers.
# Indices 0-18 are the 19-class subset; 19-170 extend it to the full set.
lookup_170 = {0: 'Anacardiaceae', 1: 'Betulaceae', 2: 'Cornaceae', 3: 'Cunoniaceae',
              4: 'Euphorbiaceae', 5: 'Fabaceae', 6: 'Fagaceae', 7: 'Juglandaceae',
              8: 'Lauraceae', 9: 'Malvaceae', 10: 'Meliaceae', 11: 'Menispermaceae',
              12: 'Myrtaceae', 13: 'Proteaceae', 14: 'Rhamnaceae', 15: 'Rosaceae',
              16: 'Salicaceae', 17: 'Sapindaceae', 18: 'Ulmaceae', 19: 'Acanthaceae',
              20: 'Achariaceae', 21: 'Achatocarpaceae', 22: 'Actinidiaceae', 23: 'Adoxaceae',
              24: 'Altingiaceae', 25: 'Amaranthaceae', 26: 'Ancistrocladaceae', 27: 'Anisophylleaceae',
              28: 'Annonaceae', 29: 'Apiaceae', 30: 'Apocynaceae', 31: 'Berberidaceae',
              32: 'Bignoniaceae', 33: 'Bixaceae', 34: 'Bonnetiaceae', 35: 'Boraginaceae',
              36: 'Brunelliaceae', 37: 'Burseraceae', 38: 'Buxaceae', 39: 'Calophyllaceae',
              40: 'Calycanthaceae', 41: 'Campanulaceae', 42: 'Canellaceae', 43: 'Cannabaceae',
              44: 'Capparaceae', 45: 'Caprifoliaceae', 46: 'Cardiopteridaceae', 47: 'Caricaceae',
              48: 'Caryocaraceae', 49: 'Celastraceae', 50: 'Centroplacaceae', 51: 'Cercidiphyllaceae',
              52: 'Chloranthaceae', 53: 'Chrysobalanaceae', 54: 'Clethraceae', 55: 'Clusiaceae',
              56: 'Combretaceae', 57: 'Connaraceae', 58: 'Coriariaceae', 59: 'Crassulaceae',
              60: 'Crossosomataceae', 61: 'Cucurbitaceae', 62: 'Dichapetalaceae', 63: 'Dilleniaceae',
              64: 'Dipterocarpaceae', 65: 'Ebenaceae', 66: 'Elaeocarpaceae', 67: 'Ericaceae',
              68: 'Erythroxylaceae', 69: 'Escalloniaceae', 70: 'Eucommiaceae', 71: 'Garryaceae',
              72: 'Gentianaceae', 73: 'Geraniaceae', 74: 'Gesneriaceae', 75: 'Gnetaceae',
              76: 'Grossulariaceae', 77: 'Gunneraceae', 78: 'Hamamelidaceae', 79: 'Hernandiaceae',
              80: 'Humiriaceae', 81: 'Hydrangeaceae', 82: 'Hypericaceae', 83: 'Icacinaceae',
              84: 'Irvingiaceae', 85: 'Iteaceae', 86: 'Ixonanthaceae', 87: 'Lamiaceae',
              88: 'Lardizabalaceae', 89: 'Lecythidaceae', 90: 'Liliaceae', 91: 'Linaceae',
              92: 'Loganiaceae', 93: 'Loranthaceae', 94: 'Lythraceae', 95: 'Magnoliaceae',
              96: 'Malpighiaceae', 97: 'Marantaceae', 98: 'Marcgraviaceae', 99: 'Melastomataceae',
              100: 'Melianthaceae', 101: 'Monimiaceae', 102: 'Moraceae', 103: 'Myricaceae',
              104: 'Myristicaceae', 105: 'Nitrariaceae', 106: 'Nothofagaceae', 107: 'Nyctaginaceae',
              108: 'Ochnaceae', 109: 'Olacaceae', 110: 'Oleaceae', 111: 'Onagraceae',
              112: 'Opiliaceae', 113: 'Orchidaceae', 114: 'Orobanchaceae', 115: 'Oxalidaceae',
              116: 'Pandaceae', 117: 'Papaveraceae', 118: 'Paracryphiaceae', 119: 'Passifloraceae',
              120: 'Pedaliaceae', 121: 'Penaeaceae', 122: 'Pentaphylacaceae', 123: 'Peridiscaceae',
              124: 'Phyllanthaceae', 125: 'Phytolaccaceae', 126: 'Picramniaceae', 127: 'Picrodendraceae',
              128: 'Piperaceae', 129: 'Pittosporaceae', 130: 'Platanaceae', 131: 'Polemoniaceae',
              132: 'Polygalaceae', 133: 'Polygonaceae', 134: 'Primulaceae', 135: 'Ranunculaceae',
              136: 'Rhabdodendraceae', 137: 'Rhizophoraceae', 138: 'Rubiaceae', 139: 'Rutaceae',
              140: 'Sabiaceae', 141: 'Santalaceae', 142: 'Sapotaceae', 143: 'Sarcolaenaceae',
              144: 'Saxifragaceae', 145: 'Schisandraceae', 146: 'Schoepfiaceae', 147: 'Scrophulariaceae',
              148: 'Simaroubaceae', 149: 'Siparunaceae', 150: 'Smilacaceae', 151: 'Solanaceae',
              152: 'Sphaerosepalaceae', 153: 'Stachyuraceae', 154: 'Staphyleaceae', 155: 'Stegnospermataceae',
              156: 'Stemonuraceae', 157: 'Styracaceae', 158: 'Symplocaceae', 159: 'Theaceae',
              160: 'Thymelaeaceae', 161: 'Trigoniaceae', 162: 'Trochodendraceae', 163: 'Urticaceae',
              164: 'Verbenaceae', 165: 'Violaceae', 166: 'Vitaceae', 167: 'Vochysiaceae',
              168: 'Winteraceae', 169: 'Zygophyllaceae', 170: 'Araceae'}
|
# Copy of the lookup table keyed by the same 0..170 indices.
# Idiom fix: the original built this with a manual range(171) loop and
# item-by-item assignment; a dict comprehension produces the identical dict.
dict_lu = {i: lookup_170[i] for i in range(171)}
|
pre-requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.22.4
|
| 2 |
+
opencv-python-headless==4.5.5.64
|
| 3 |
+
openmim==0.1.5
|
| 4 |
+
torch==1.11.0
|
| 5 |
+
torchvision==0.12.0
|
| 6 |
+
tensorflow==2.8
|