Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline | |
| import requests | |
| import pygbif | |
| from bs4 import BeautifulSoup | |
| import pandas as pd | |
# Local fine-tuned checkpoints, loaded once when the module is imported.
# The classifier is configured to report its 5 best labels and the
# fill-mask model its 100 best token candidates per [MASK].
classification_model = pipeline(
    task="text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
masking_model = pipeline(
    task="fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)

# EUNIS habitat lookup table; `classification` reads its
# "EUNIS 2020 code" and "EUNIS-2021 habitat name" columns.
eunis_habitats = pd.read_excel(io="data/eunis_habitats.xlsx")
def return_image(task, label, timeout=10):
    """Fetch an illustrative FloraVEG image for a habitat or a taxon.

    Downloads the FloraVEG overview page for *label* and returns the first
    ``<img>`` whose ``src`` starts with the image prefix used for that kind
    of page.

    Args:
        task: ``"classification"`` to look up a habitat page (label is a
            habitat code) or ``"masking"`` to look up a taxon page (label is
            a species name, capitalized for the URL).
        label: Habitat code or species name to look up.
        timeout: Seconds to wait for FloraVEG before giving up.

    Returns:
        A ``gr.Image`` pointing at the remote image, or ``None`` when the page
        cannot be fetched, answers with a non-200 status, or contains no
        matching image.

    Raises:
        ValueError: If *task* is neither ``"classification"`` nor ``"masking"``.
    """
    if task == "classification":
        floraveg_url = f"https://floraveg.eu/habitat/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    elif task == "masking":
        floraveg_url = f"https://floraveg.eu/taxon/overview/{label.capitalize()}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    else:
        # Any other value would leave floraveg_url unbound; fail with a clear message.
        raise ValueError(
            f"Unknown task {task!r}; expected 'classification' or 'masking'."
        )

    try:
        # A timeout keeps an unresponsive FloraVEG from hanging the UI request.
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # The image is optional decoration: a network failure must not break
        # the prediction text that accompanies it.
        return None
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.text, "html.parser")
    img_tag = soup.find("img", src=lambda src: src and src.startswith(floraveg_tag))
    if not img_tag:
        return None
    return gr.Image(value=img_tag["src"])
def gbif_normalization(text):
    """Normalize a comma-separated species list against the GBIF backbone.

    Each non-empty entry is stripped and matched with
    ``pygbif.species.name_backbone`` at species rank. When the response
    contains a ``usage`` record, its ``canonicalName`` replaces the entry.
    Otherwise, or if the GBIF request fails, the entry is kept as typed.

    Args:
        text: Comma-separated species names, e.g.
            ``"Fagus sylvatica, Abies alba"``.

    Returns:
        The names joined with ``", "`` and lower-cased. Blank entries (for
        example from a trailing comma) are dropped, so blank input yields
        ``""``.
    """
    species_gbif = []
    for entry in text.split(","):
        species = entry.strip()
        if not species:
            # Skip the GBIF query and keep "" out of the model input.
            continue
        try:
            gbif_match_result = pygbif.species.name_backbone(species, taxonRank="SPECIES")
        except requests.RequestException:
            # Normalization is best-effort: keep the user's spelling if GBIF
            # cannot be reached.
            gbif_match_result = {}
        usage = gbif_match_result.get("usage") or {}
        species_gbif.append(usage.get("canonicalName") or species)
    return ", ".join(species_gbif).lower()
def classification(text, k):
    """Predict the most likely EUNIS habitat type(s) of a vegetation plot.

    Args:
        text: Comma-separated species names; normalized via
            ``gbif_normalization`` before classification.
        k: Number of top habitat codes to report. It is coerced to ``int``
            (and at least 1) because the Gradio slider may deliver a float.

    Returns:
        ``(message, image)``. *message* lists the top habitat codes and names
        the most likely one. *image* is a FloraVEG ``gr.Image`` of that
        habitat, or ``None``.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    result = classification_model(text)
    # result[0] holds the ranked {"label", "score"} dicts for our single input.
    habitat_labels = [res["label"] for res in result[0][:k]]
    if not habitat_labels:
        return "No habitat type could be predicted for this vegetation plot.", None
    best_label = habitat_labels[0]

    # An unknown code yields an empty selection (indexing it would raise IndexError).
    names = eunis_habitats.loc[
        eunis_habitats["EUNIS 2020 code"] == best_label, "EUNIS-2021 habitat name"
    ]
    habitat_name = names.iloc[0] if not names.empty else None

    # Phrase by the number of labels actually returned, not by the requested k.
    if len(habitat_labels) == 1:
        text = f"This vegetation plot probably belongs to the habitat type {best_label}."
        name_subject = "This habitat type"
    else:
        if len(habitat_labels) == 2:
            alternatives = f"{habitat_labels[0]} or {habitat_labels[1]}"
        else:
            alternatives = f"{', '.join(habitat_labels[:-1])}, or {habitat_labels[-1]}"
        text = f"This vegetation plot probably belongs to the habitat type {alternatives}."
        name_subject = f"The most likely habitat type (i.e., {best_label})"

    if habitat_name is not None:
        text += f"\n{name_subject} is named '{habitat_name}'."
    else:
        text += f"\n{name_subject} was not found in the EUNIS habitat table."

    image_output = return_image("classification", best_label)
    if image_output is not None:
        text += "\nBelow is an example of this habitat type taken from the website FloraVEG."
    else:
        text += "\nNo image found for this habitat type."
    return text, image_output
def _enumerate_with_and(items):
    """Join *items* as ``"a"``, ``"a and b"`` or ``"a, b, and c"``."""
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return f"{items[0]} and {items[1]}"
    return f"{', '.join(items[:-1])}, and {items[-1]}"


def _best_insertion(plot_species):
    """Return ``(species, position)`` for the best new species, or ``None``.

    Tries a ``[MASK]`` at every insertion point of *plot_species* and keeps
    the highest-scoring prediction whose ``token_str`` is not already listed.
    """
    present = set(plot_species)
    best = None  # (score, species, position)
    for position in range(len(plot_species) + 1):
        masked_text = ", ".join(plot_species[:position] + ["[MASK]"] + plot_species[position:])
        # One model call per position. Take the first (highest-ranked) candidate
        # that is not yet in the plot. Only top_k candidates are returned, so
        # all of them may already be present; then this position is skipped.
        candidate = next(
            (p for p in masking_model(masked_text) if p["token_str"] not in present),
            None,
        )
        if candidate is None:
            continue
        if best is None or candidate["score"] > best[0]:
            best = (candidate["score"], candidate["token_str"], position)
    if best is None:
        return None
    return best[1], best[2]


def masking(text, k):
    """Find the species most likely missing from a vegetation plot.

    Species are added greedily, one per round. Each round places a
    ``[MASK]`` at every position of the current list and keeps the
    highest-scoring candidate not already in the plot. That candidate is
    inserted before the next round.

    Args:
        text: Comma-separated species names; normalized via
            ``gbif_normalization`` first.
        k: Number of species to add. It is coerced to ``int`` (and at least
            1) because the Gradio slider may deliver a float.

    Returns:
        ``(message, image)``. *message* names the added species, their
        0-based positions in the completed list, and the completed plot.
        *image* is a FloraVEG ``gr.Image`` of the first added species, or
        ``None``.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    # Ignore empty items (e.g. from blank input) rather than treating "" as a species.
    text_split = [species for species in text.split(", ") if species]
    best_predictions = []
    for _ in range(k):
        found = _best_insertion(text_split)
        if found is None:
            break
        species, position = found
        best_predictions.append(species)
        text_split.insert(position, species)

    if not best_predictions:
        return "No missing species could be found for this vegetation plot.", None

    best_positions = [text_split.index(prediction) for prediction in best_predictions]
    # Build the completed plot from the list actually assembled, which stays
    # valid even when the search stopped early.
    completed_plot = ", ".join(species.strip().capitalize() for species in text_split)
    names = [prediction.capitalize() for prediction in best_predictions]
    if len(names) == 1:
        text = f"The most likely missing species is {names[0]} (position {best_positions[0]})."
    else:
        positions = _enumerate_with_and([str(position) for position in best_positions])
        text = f"The most likely missing species are {_enumerate_with_and(names)} (positions {positions})."
    text += f"\nThe completed vegetation plot is thus '{completed_plot}'."
    image = return_image("masking", best_predictions[0])
    if image is not None:
        text += f"\nBelow is an image of the first missing species (i.e., {names[0]}) taken from the website FloraVEG."
    else:
        text += f"\nNo image found for the first missing species (i.e., {names[0]})."
    return text, image
# Gradio front end with one tab per task. Each tab collects a species list
# and a top-k value, and displays the text and optional image returned by
# `classification` or `masking`.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Habitat identification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of habitat types to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        # cache_examples=True has Gradio precompute `classification` on these
        # inputs, which needs the model and FloraVEG/GBIF access.
        gr.Examples(
            examples=[
                ["phragmites australis, lemna minor, typha latifolia", 3],
                ["fagus sylvatica, prenanthes purpurea, abies alba, cardamine bulbifera, orthilia secunda, oxalis acetosella, rubus idaeus", 1],
            ],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Missing vascular plant species retrieval!</h3>""")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        # Same example caching as above, here running `masking`.
        gr.Examples(
            examples=[
                ["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1],
                ["trapa natans, lemna minor, phragmites australis, sparganium erectum", 1],
            ],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    # Listeners are attached while the Blocks context is open.
    # NOTE(review): Gradio 4.x rejects event listeners attached outside a
    # Blocks context — confirm against the pinned Gradio version.
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

# Start the server only when run as a script, so importing this module just
# builds `demo`.
if __name__ == "__main__":
    demo.launch()