Spaces:
Running
Running
cesar.leblanc
commited on
Commit
·
4e59324
1
Parent(s):
142304a
app.py
CHANGED
|
@@ -79,37 +79,50 @@ def classification(text, k):
|
|
| 79 |
image_output = return_habitat_image(habitat_labels[0])
|
| 80 |
return text, image_output
|
| 81 |
|
| 82 |
-
def masking(text):
|
| 83 |
text = gbif_normalization(text)
|
| 84 |
text_split = text.split(', ')
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
image = return_species_image(best_prediction)
|
| 114 |
return text, image
|
| 115 |
|
|
@@ -122,26 +135,27 @@ with gr.Blocks() as demo:
|
|
| 122 |
with gr.Row():
|
| 123 |
with gr.Column():
|
| 124 |
species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
|
| 125 |
-
k_classification = gr.Slider(1, 5, value=1, label="Top-k", info="Choose the number of top habitats to display.")
|
| 126 |
with gr.Column():
|
| 127 |
-
|
| 128 |
-
|
| 129 |
button_classification = gr.Button("Classify")
|
| 130 |
gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
|
| 131 |
-
gr.Examples([["sparganium erectum, calystegia sepium, persicaria amphibia", 1]], [species_classification, k_classification], [
|
| 132 |
|
| 133 |
with gr.Tab("Missing species finding"):
|
| 134 |
gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
|
| 135 |
with gr.Row():
|
| 136 |
species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
|
|
|
|
| 137 |
with gr.Column():
|
| 138 |
-
|
| 139 |
-
|
| 140 |
button_masking = gr.Button("Find")
|
| 141 |
gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
|
| 142 |
-
gr.Examples([["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]], [species_masking], [
|
| 143 |
|
| 144 |
-
button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[
|
| 145 |
-
button_masking.click(masking, inputs=[species_masking], outputs=[
|
| 146 |
|
| 147 |
demo.launch()
|
|
|
|
| 79 |
image_output = return_habitat_image(habitat_labels[0])
|
| 80 |
return text, image_output
|
| 81 |
|
| 82 |
+
def masking(text, k):
|
| 83 |
text = gbif_normalization(text)
|
| 84 |
text_split = text.split(', ')
|
| 85 |
|
| 86 |
+
best_predictions = []
|
| 87 |
+
best_positions = []
|
| 88 |
+
best_sentences = []
|
| 89 |
+
|
| 90 |
+
for _ in range(k):
|
| 91 |
+
max_score = 0
|
| 92 |
+
best_prediction = None
|
| 93 |
+
best_position = None
|
| 94 |
+
best_sentence = None
|
| 95 |
|
| 96 |
+
for i in range(len(text_split) + 1):
|
| 97 |
+
masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
|
| 98 |
+
|
| 99 |
+
j = 0
|
| 100 |
+
while True:
|
| 101 |
+
prediction = mask_model(masked_text)[j]
|
| 102 |
+
species = prediction['token_str']
|
| 103 |
+
if species in text_split or species in best_predictions:
|
| 104 |
+
j += 1
|
| 105 |
+
else:
|
| 106 |
+
break
|
| 107 |
|
| 108 |
+
score = prediction['score']
|
| 109 |
+
sentence = prediction['sequence']
|
| 110 |
|
| 111 |
+
if score > max_score:
|
| 112 |
+
max_score = score
|
| 113 |
+
best_prediction = species
|
| 114 |
+
best_position = i
|
| 115 |
+
best_sentence = sentence
|
| 116 |
+
|
| 117 |
+
best_predictions.append(best_prediction)
|
| 118 |
+
best_positions.append(best_position)
|
| 119 |
+
best_sentences.append(best_sentence)
|
| 120 |
+
text_split.insert(best_position, best_prediction)
|
| 121 |
+
if k == 1:
|
| 122 |
+
text = f"The most likely missing species is {best_predictions[0]} (position {best_positions[0]})."
|
| 123 |
+
else:
|
| 124 |
+
text = f"The most likely missing species are {', '.join(best_predictions[:-1])} and {best_predictions[-1]} (positions {', '.join(map(str, best_positions[:-1]))} and {best_positions[-1]})."
|
| 125 |
+
text += f"\nThe new vegetation plot is {best_sentences[-1]}. (see image of the most likely species below)."
|
| 126 |
image = return_species_image(best_prediction)
|
| 127 |
return text, image
|
| 128 |
|
|
|
|
| 135 |
with gr.Row():
|
| 136 |
with gr.Column():
|
| 137 |
species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
|
| 138 |
+
k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top habitats to display.")
|
| 139 |
with gr.Column():
|
| 140 |
+
text_classification = gr.Textbox()
|
| 141 |
+
image_classification = gr.Image()
|
| 142 |
button_classification = gr.Button("Classify")
|
| 143 |
gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
|
| 144 |
+
gr.Examples([["sparganium erectum, calystegia sepium, persicaria amphibia", 1]], [species_classification, k_classification], [text_classification, image_classification], classification, True)
|
| 145 |
|
| 146 |
with gr.Tab("Missing species finding"):
|
| 147 |
gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
|
| 148 |
with gr.Row():
|
| 149 |
species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
|
| 150 |
+
k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top missing species to find.")
|
| 151 |
with gr.Column():
|
| 152 |
+
text_masking = gr.Textbox()
|
| 153 |
+
image_masking = gr.Image()
|
| 154 |
button_masking = gr.Button("Find")
|
| 155 |
gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
|
| 156 |
+
gr.Examples([["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", 1]], [species_masking, k_masking], [text_masking, image_masking], masking, True)
|
| 157 |
|
| 158 |
+
button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[textclassification, image_classification])
|
| 159 |
+
button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])
|
| 160 |
|
| 161 |
demo.launch()
|