File size: 8,278 Bytes
a5316e5
 
d563836
d34103c
d563836
544f914
a5316e5
8da738a
d34103c
544f914
 
6176ef8
29a9417
 
52efc69
20d36fe
52efc69
 
 
 
906d487
52efc69
ccf126e
 
 
52efc69
ccf126e
 
20d36fe
ccf126e
 
b1a0d53
 
 
 
 
d34103c
 
 
b1a0d53
 
 
 
 
 
 
c82104c
b1a0d53
6f59e3c
aa09a05
60270f6
8da738a
47d0212
 
dae2ae1
47d0212
e728893
dae2ae1
47d0212
e728893
2c6ecb2
 
20d36fe
 
 
29a9417
2c6ecb2
6176ef8
4e59324
b1a0d53
aa09a05
24390e2
4e59324
 
 
 
 
3218b1a
4e59324
24390e2
4e59324
 
 
 
 
d34103c
4e59324
 
 
 
 
aa09a05
4e59324
 
aa09a05
4e59324
 
 
3218b1a
4e59324
 
29a9417
3218b1a
 
 
29a9417
 
3218b1a
4e59324
25e534d
dae2ae1
25e534d
dae2ae1
25e534d
47d0212
52efc69
20d36fe
6268f11
20d36fe
6268f11
29a9417
6176ef8
 
5282aca
7145ecb
1e09a50
f30d0ea
5282aca
d3e6e3b
b8df8bd
53e71ae
142304a
47d0212
53e71ae
b5d9907
8e84211
142304a
d3c40d6
b1677f2
 
 
 
 
 
 
 
 
53e71ae
5282aca
d3e6e3b
5282aca
15ecc3d
 
8effe15
3cb6c3b
b5d9907
8e84211
142304a
d3c40d6
3d8e96d
 
 
 
 
 
 
 
 
 
a5316e5
736419a
4e59324
a5316e5
8954378
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
from transformers import pipeline
import requests
import pygbif
from bs4 import BeautifulSoup
import pandas as pd

# Local Pl@ntBERT checkpoints; each directory supplies both the weights and the tokenizer.
_CLASSIFICATION_MODEL_DIR = "models/text_classification_model"
_FILL_MASK_MODEL_DIR = "models/fill_mask_model"

# Habitat classifier, configured to return the 5 highest-scoring labels per input.
classification_model = pipeline(
    "text-classification",
    model=_CLASSIFICATION_MODEL_DIR,
    tokenizer=_CLASSIFICATION_MODEL_DIR,
    top_k=5,
)
# Fill-mask model, configured to return the 100 highest-scoring candidates per [MASK].
masking_model = pipeline(
    "fill-mask",
    model=_FILL_MASK_MODEL_DIR,
    tokenizer=_FILL_MASK_MODEL_DIR,
    top_k=100,
)

# Table used to map a predicted "EUNIS 2020 code" to its "EUNIS-2021 habitat name".
eunis_habitats = pd.read_excel("data/eunis_habitats.xlsx")

# Placeholder picture returned by the handlers when FloraVEG yields no image.
_PLACEHOLDER_IMAGE_URL = "https://img.freepik.com/premium-vector/file-folder-mascot-character-design-vector_166742-4413.jpg"
image_not_found = gr.Image(_PLACEHOLDER_IMAGE_URL)

def return_image(task, label, timeout=10):
    """Scrape an illustrative image for a habitat or a species from FloraVEG.

    The FloraVEG overview page for *label* is downloaded. The first ``<img>``
    whose ``src`` starts with the image prefix for *task* is used.

    Args:
        task: ``"classification"`` (habitat overview page, *label* is a
            habitat code) or ``"masking"`` (taxon overview page, *label* is
            a species name).
        label: Value appended to the FloraVEG overview URL.
        timeout: Seconds to wait for FloraVEG before giving up.

    Returns:
        A ``gr.Image`` pointing at the scraped image URL. Returns ``None``
        when the request fails, the status is not 200, or no matching image
        is on the page.

    Raises:
        ValueError: If *task* is not one of the two supported values.
    """
    if task == "classification":
        floraveg_url = f"https://floraveg.eu/habitat/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    elif task == "masking":
        floraveg_url = f"https://floraveg.eu/taxon/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    else:
        # Without this branch the code below hits an UnboundLocalError.
        raise ValueError(
            f"Unknown task {task!r}; expected 'classification' or 'masking'."
        )

    try:
        # Bounded wait: a stalled FloraVEG server must not hang the UI handler.
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # The image is optional; None makes the callers show the placeholder.
        return None
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.text, "html.parser")
    img_tag = soup.find("img", src=lambda src: src and src.startswith(floraveg_tag))
    if img_tag is None:
        return None
    return gr.Image(value=img_tag["src"])

def gbif_normalization(text):
    """Normalize a comma-separated species list against the GBIF backbone.

    Each comma-separated entry is stripped and matched with
    ``pygbif.species.name_backbone`` at species rank. If the response
    contains a ``usage`` record with a ``canonicalName``, that name replaces
    the entry; otherwise the entry is kept as typed.

    Blank entries (e.g. from a trailing comma) are dropped rather than sent
    to GBIF. If a GBIF request fails, that entry also falls back to the
    typed name, so normalization stays best-effort.

    Args:
        text: Comma-separated species names.

    Returns:
        The normalized names joined with ``", "`` and lower-cased.
    """
    normalized = []
    for species in (part.strip() for part in text.split(",")):
        if not species:
            # Avoid a pointless GBIF lookup and a ", ," gap in the output.
            continue
        try:
            match = pygbif.species.name_backbone(species, taxonRank="SPECIES")
        except requests.RequestException:
            # Best-effort: a transient GBIF failure keeps the user's spelling.
            match = {}
        usage = match.get("usage") or {}
        normalized.append(usage.get("canonicalName") or species)
    return ", ".join(normalized).lower()

def classification(text, k):
    """Predict the most likely EUNIS habitat types for a vegetation plot.

    Args:
        text: Comma-separated species names. They are normalized with
            ``gbif_normalization`` before being fed to the classifier.
        k: Number of habitat labels to report. The classifier is built with
            ``top_k=5``, so at most 5 are available. The value may arrive as
            a float from the slider, so it is cast to ``int``.

    Returns:
        ``(message, image)``:
        - ``message`` describes the predicted habitat code(s) and the EUNIS
          name of the top one.
        - ``image`` is a FloraVEG picture of the top habitat, or the
          ``image_not_found`` placeholder.
    """
    k = int(k)  # slicing below requires an int
    result = classification_model(gbif_normalization(text))
    habitat_labels = [res["label"] for res in result[0][:k]]
    top_label = habitat_labels[0]

    # The top code may be absent from the table; don't crash on IndexError.
    names = eunis_habitats.loc[
        eunis_habitats["EUNIS 2020 code"] == top_label, "EUNIS-2021 habitat name"
    ]
    habitat_name = names.iloc[0] if not names.empty else None

    # Phrase by the number of labels actually returned, not the requested k.
    if len(habitat_labels) == 1:
        text = f"This vegetation plot probably belongs to the habitat type {top_label}."
    elif len(habitat_labels) == 2:
        text = f"This vegetation plot probably belongs to the habitat type {habitat_labels[0]} or {habitat_labels[1]}."
    else:
        text = f"This vegetation plot probably belongs to the habitat type {', '.join(habitat_labels[:-1])}, or {habitat_labels[-1]}."

    if habitat_name is None:
        text += f"\nNo EUNIS habitat name was found for {top_label}."
    elif len(habitat_labels) == 1:
        text += f"\nThis habitat type is named '{habitat_name}'."
    else:
        text += f"\nThe most likely habitat type (i.e., {top_label}) is named '{habitat_name}'."

    image = return_image("classification", top_label)
    if image is not None:
        text += "\nBelow is an example of this habitat type taken from the website FloraVEG."
    else:
        text += "\nNo image found for this habitat type."
        image = image_not_found
    return text, image

def _best_masked_insertion(species_list):
    """Return the single best species to insert into *species_list*, or None.

    A ``[MASK]`` token is placed in every gap of the list: before the first
    element, between each pair, and after the last. For each gap, the
    highest-ranked fill-mask candidate not already in the list is kept. The
    gap whose kept candidate has the strictly highest positive score wins.

    Args:
        species_list: Species names of the current plot.

    Returns:
        ``(species, position, sequence)``:
        - ``species`` is the winning candidate.
        - ``position`` is the gap index.
        - ``sequence`` is the model's filled-in text.

        Returns ``None`` when no gap yields a usable candidate.
    """
    present = set(species_list)
    max_score = 0
    best = None
    for position in range(len(species_list) + 1):
        masked_text = ", ".join(species_list[:position] + ["[MASK]"] + species_list[position:])
        # One forward pass per gap: the candidates come back ranked, so walk
        # them rather than re-running the model for every rejected entry.
        candidates = masking_model(masked_text)
        candidate = next((c for c in candidates if c["token_str"] not in present), None)
        if candidate is None:
            # Every returned candidate is already in the plot.
            continue
        if candidate["score"] > max_score:
            max_score = candidate["score"]
            best = (candidate["token_str"], position, candidate["sequence"])
    return best


def masking(text, k):
    """Suggest up to *k* species missing from a vegetation plot.

    Species are added greedily. Each round inserts the best candidate from
    ``_best_masked_insertion`` into the plot before searching again, so later
    suggestions are conditioned on earlier ones.

    Args:
        text: Comma-separated species names. They are normalized with
            ``gbif_normalization`` first.
        k: Number of species to find. The value may arrive as a float from
            the slider, so it is cast to ``int``.

    Returns:
        ``(message, image)``:
        - ``message`` lists the suggested species, their 0-based indices in
          the completed plot, and the completed plot.
        - ``image`` is a FloraVEG picture of the first suggestion, or the
          ``image_not_found`` placeholder.
    """
    k = int(k)  # range() requires an int
    text_split = gbif_normalization(text).split(", ")

    best_predictions = []
    best_sentence = None
    for _ in range(k):
        insertion = _best_masked_insertion(text_split)
        if insertion is None:
            # No unseen candidate left: report what was found instead of crashing.
            break
        species, position, best_sentence = insertion
        best_predictions.append(species)
        text_split.insert(position, species)

    if not best_predictions:
        return "No missing species could be found for this vegetation plot.", image_not_found

    # 0-based indices of each suggestion within the completed plot.
    best_positions = [text_split.index(prediction) for prediction in best_predictions]
    best_predictions = [s.strip().capitalize() for s in best_predictions]
    best_sentence = ", ".join(s.strip().capitalize() for s in best_sentence.split(","))

    count = len(best_predictions)
    if count == 1:
        text = f"The most likely missing species is '{best_predictions[0]}' (position {best_positions[0]})."
    elif count == 2:
        text = f"The most likely missing species are '{best_predictions[0]}' and '{best_predictions[1]}' (positions {best_positions[0]} and {best_positions[1]})."
    else:
        text = (
            "The most likely missing species are "
            + ", ".join(f"'{s}'" for s in best_predictions[:-1])
            + f", and '{best_predictions[-1]}' (positions {', '.join(map(str, best_positions[:-1]))}, and {best_positions[-1]})."
        )
    text += f"\nThe completed vegetation plot is thus '{best_sentence}'."

    image = return_image("masking", best_predictions[0])
    if image is not None:
        text += f"\nBelow is an image of the first missing species (i.e., {best_predictions[0]}) taken from the website FloraVEG."
    else:
        text += f"\nNo image found for the first missing species (i.e., {best_predictions[0]})."
        image = image_not_found
    return text, image

# Example inputs displayed under each tab (cache_examples=True below).
_CLASSIFICATION_EXAMPLES = [
    ["phragmites australis, lemna minor, typha latifolia", 3],
    ["fagus sylvatica, prenanthes purpurea, abies alba, cardamine bulbifera, orthilia secunda, oxalis acetosella, rubus idaeus", 1],
]
_MASKING_EXAMPLES = [
    ["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1],
    ["trapa natans, lemna minor, phragmites australis, sparganium erectum", 1],
]

with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    # Tab 1: species list -> habitat type(s), handled by classification().
    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Habitat identification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(
                    lines=2,
                    label="Species",
                    placeholder="Enter a list of comma-separated binomial names here.",
                )
                k_classification = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=1,
                    step=1,
                    label="Top-k",
                    info="Choose the number of habitat types to display.",
                )
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=_CLASSIFICATION_EXAMPLES,
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    # Tab 2: species list -> likely missing species, handled by masking().
    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Missing vascular plant species retrieval!</h3>""")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(
                    lines=2,
                    label="Species",
                    placeholder="Enter a list of comma-separated binomial names here.",
                )
                k_masking = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=1,
                    step=1,
                    label="Top-k",
                    info="Choose the number of missing species to find.",
                )
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=_MASKING_EXAMPLES,
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    # Button wiring: same handlers and components as the examples above.
    button_classification.click(
        classification,
        inputs=[species_classification, k_classification],
        outputs=[text_classification, image_classification],
    )
    button_masking.click(
        masking,
        inputs=[species_masking, k_masking],
        outputs=[text_masking, image_masking],
    )

demo.launch()