# Hugging Face Space application file (Space status: Running; file size: 8,278 bytes).
import gradio as gr
from transformers import pipeline
import requests
import pygbif
from bs4 import BeautifulSoup
import pandas as pd
# Local checkpoint directories; each is passed as both the model and the tokenizer.
_CLASSIFICATION_MODEL_DIR = "models/text_classification_model"
_MASKING_MODEL_DIR = "models/fill_mask_model"

# Text-classification pipeline configured to return the 5 best labels.
classification_model = pipeline(
    "text-classification",
    model=_CLASSIFICATION_MODEL_DIR,
    tokenizer=_CLASSIFICATION_MODEL_DIR,
    top_k=5,
)

# Fill-mask pipeline configured to return the 100 best candidates per mask.
masking_model = pipeline(
    "fill-mask",
    model=_MASKING_MODEL_DIR,
    tokenizer=_MASKING_MODEL_DIR,
    top_k=100,
)

# Lookup table read from the bundled spreadsheet (habitat codes and names).
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')

# Placeholder picture shown whenever no illustration can be retrieved.
image_not_found = gr.Image("https://img.freepik.com/premium-vector/file-folder-mascot-character-design-vector_166742-4413.jpg")
def return_image(task, label, timeout=10):
    """Fetch an illustrative image from FloraVeg for a habitat or a taxon.

    Args:
        task: "classification" to query the habitat overview page, or
            "masking" to query the taxon overview page.
        label: Habitat code or species name appended to the overview URL.
        timeout: Seconds to wait for the FloraVeg server before giving up.

    Returns:
        A ``gr.Image`` built from the first ``<img>`` whose ``src`` starts
        with the expected FloraVeg image prefix, or ``None`` when the page
        cannot be fetched, does not answer with HTTP 200, or holds no such
        image.

    Raises:
        ValueError: If ``task`` is neither "classification" nor "masking".
    """
    if task == "classification":
        floraveg_url = f"https://floraveg.eu/habitat/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    elif task == "masking":
        floraveg_url = f"https://floraveg.eu/taxon/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    else:
        # Fail with an explicit message instead of an UnboundLocalError below.
        raise ValueError(f"Unknown task: {task!r}")

    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # The image is optional: a network failure must not break the
        # prediction, so report "no image" to the caller.
        return None

    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.text, 'html.parser')
    # `x and ...` skips <img> tags that carry no src attribute at all.
    img_tag = soup.find('img', src=lambda x: x and x.startswith(floraveg_tag))
    if not img_tag:
        return None
    return gr.Image(value=img_tag['src'])
def gbif_normalization(text):
    """Normalize a comma-separated list of species names through GBIF.

    Each non-empty name is looked up with ``pygbif.species.name_backbone``
    (restricted to rank SPECIES). When the result carries a ``usage`` entry
    with a ``canonicalName``, that name replaces the user's spelling;
    otherwise the name is kept as typed.

    Args:
        text: Comma-separated species names. Surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are
            dropped.

    Returns:
        The normalized names joined with ", " and lower-cased. An empty
        string when the input holds no name.
    """
    species_gbif = []
    for species in (text or "").split(','):
        species = species.strip()
        if not species:
            # Skip blanks so they are neither sent to GBIF nor passed on
            # to the models as empty list entries.
            continue
        gbif_match_result = pygbif.species.name_backbone(species, taxonRank="SPECIES")
        usage = gbif_match_result.get("usage") or {}
        # Fall back to the name as typed when GBIF has no canonical name.
        species_gbif.append(usage.get("canonicalName") or species)
    return ", ".join(species_gbif).lower()
def classification(text, k):
    """Predict the most likely habitat types of a vegetation plot.

    Args:
        text: Comma-separated species names describing the plot.
        k: Number of habitat types to report; coerced to an integer of at
            least 1 because it comes from a slider widget.

    Returns:
        A ``(message, image)`` tuple: the message lists the predicted habitat
        codes and, when the code is found in the habitat table, the name of
        the most likely one; the image illustrates that habitat, or is the
        "not found" placeholder when no image could be retrieved.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    result = classification_model(text)
    habitat_labels = [res['label'] for res in result[0][:k]]
    top_label = habitat_labels[0]

    # The code may be absent from the table; avoid an IndexError in that case.
    matching_names = eunis_habitats[eunis_habitats['EUNIS 2020 code'] == top_label]['EUNIS-2021 habitat name'].values
    habitat_name = matching_names[0] if len(matching_names) > 0 else None

    # Choose the wording from the number of labels actually returned, which
    # may be smaller than k.
    n_labels = len(habitat_labels)
    if n_labels == 1:
        text = f"This vegetation plot probably belongs to the habitat type {top_label}."
        if habitat_name is not None:
            text += f"\nThis habitat type is named '{habitat_name}'."
    else:
        if n_labels == 2:
            text = f"This vegetation plot probably belongs to the habitat type {', '.join(habitat_labels[:-1])} or {habitat_labels[-1]}."
        else:
            text = f"This vegetation plot probably belongs to the habitat type {', '.join(habitat_labels[:-1])}, or {habitat_labels[-1]}."
        if habitat_name is not None:
            text += f"\nThe most likely habitat type (i.e., {top_label}) is named '{habitat_name}'."

    image = return_image("classification", top_label)
    if image is not None:
        text += "\nBelow is an example of this habitat type taken from the website FloraVEG."
    else:
        text += "\nNo image found for this habitat type."
        image = image_not_found
    return text, image
def masking(text, k):
    """Suggest the species most likely missing from a vegetation plot.

    For each of the ``k`` rounds, a ``[MASK]`` is tried at every position of
    the species list, the best-scoring candidate that is not already in the
    plot is kept, and it is inserted at its position before the next round.

    Args:
        text: Comma-separated species names describing the plot.
        k: Number of missing species to find; coerced to an integer because
            it comes from a slider widget.

    Returns:
        A ``(message, image)`` tuple: the message names the predicted species
        with their 0-based positions in the completed list and shows the
        completed plot; the image shows the first predicted species, or is
        the "not found" placeholder when no image could be retrieved.
    """
    k = int(k)
    text = gbif_normalization(text)
    text_split = text.split(', ')
    best_predictions = []
    best_sentence = None

    for _ in range(k):
        max_score = 0
        round_best = None
        for i in range(len(text_split) + 1):
            masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
            # Run the model once per position; the candidate list does not
            # change while we skip species that are already present.
            candidates = masking_model(masked_text)
            prediction = next(
                (
                    candidate
                    for candidate in candidates
                    if candidate['token_str'] not in text_split
                    and candidate['token_str'] not in best_predictions
                ),
                None,
            )
            if prediction is None:
                # Every returned candidate is already in the plot.
                continue
            if prediction['score'] > max_score:
                max_score = prediction['score']
                round_best = (prediction['token_str'], i, prediction['sequence'])
        if round_best is None:
            # No position produced a new species; stop with what we have.
            break
        species, position, best_sentence = round_best
        best_predictions.append(species)
        text_split.insert(position, species)

    if not best_predictions:
        return "No missing species could be predicted for this vegetation plot.", image_not_found

    # Positions are indexes into the completed species list.
    best_positions = [text_split.index(prediction) for prediction in best_predictions]
    best_predictions = [s.strip().capitalize() for s in best_predictions]
    best_sentence = ", ".join([s.strip().capitalize() for s in best_sentence.split(",")])

    # Choose the wording from the number of species actually found.
    n_found = len(best_predictions)
    if n_found == 1:
        text = f"The most likely missing species is '{best_predictions[0]}' (position {best_positions[0]})."
    elif n_found == 2:
        text = f"The most likely missing species are '{best_predictions[0]}' and '{best_predictions[1]}' (positions {best_positions[0]} and {best_positions[1]})."
    else:
        text = f"The most likely missing species are " + ', '.join(f"'{s}'" for s in best_predictions[:-1]) + f", and '{best_predictions[-1]}' (positions {', '.join(map(str, best_positions[:-1]))}, and {best_positions[-1]})."
    text += f"\nThe completed vegetation plot is thus '{best_sentence}'."

    image = return_image("masking", best_predictions[0])
    if image is not None:
        text += f"\nBelow is an image of the first missing species (i.e., {best_predictions[0]}) taken from the website FloraVEG."
    else:
        text += f"\nNo image found for the first missing species (i.e., {best_predictions[0]})."
        image = image_not_found
    return text, image
# Two-tab Gradio interface: habitat classification and missing-species retrieval.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file; confirm that the buttons and examples belong at tab level.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Habitat identification of vegetation plots!</h3>""")
        with gr.Row():
            # Left column: inputs. Right column: outputs.
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of habitat types to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[
                ["phragmites australis, lemna minor, typha latifolia", 3],
                ["fagus sylvatica, prenanthes purpurea, abies alba, cardamine bulbifera, orthilia secunda, oxalis acetosella, rubus idaeus", 1]
            ],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True)

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Missing vascular plant species retrieval!</h3>""")
        with gr.Row():
            # Left column: inputs. Right column: outputs.
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[
                ["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1],
                ["trapa natans, lemna minor, phragmites australis, sparganium erectum", 1]
            ],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True
        )

    # Event listeners are registered while the Blocks context is still open.
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

demo.launch()