import gradio as gr
from transformers import pipeline
import requests
import pygbif
from bs4 import BeautifulSoup
import pandas as pd

# Habitat classifier: always returns its 5 best labels; callers trim to k.
classification_model = pipeline(
    "text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
# Fill-mask model: a large candidate pool (100) lets masking() skip species
# that are already present in the plot and still find a new one.
masking_model = pipeline(
    "fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)

# Lookup table from EUNIS 2020 codes to EUNIS-2021 habitat names.
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')

# Placeholder shown when FloraVEG has no example image.
image_not_found = gr.Image("https://img.freepik.com/premium-vector/file-folder-mascot-character-design-vector_166742-4413.jpg")

# Seconds to wait for FloraVEG before giving up on an example image.
FLORAVEG_TIMEOUT = 10

# User-facing message for an empty species list.
EMPTY_INPUT_MESSAGE = "Please enter at least one species name."


def return_image(task, label):
    """Fetch an example image for *label* from the FloraVEG website.

    Args:
        task: "classification" to look up a habitat overview page, or
            "masking" to look up a taxon overview page.
        label: habitat code or species name inserted into the FloraVEG URL.

    Returns:
        A ``gr.Image`` pointing at the first image on the page whose ``src``
        starts with the task-specific FloraVEG image prefix, or ``None`` when
        the page cannot be fetched, does not answer 200, or has no such image.

    Raises:
        ValueError: if *task* is neither "classification" nor "masking".
    """
    if task == "classification":
        floraveg_url = f"https://floraveg.eu/habitat/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    elif task == "masking":
        floraveg_url = f"https://floraveg.eu/taxon/overview/{label}"
        floraveg_tag = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'classification' or 'masking'.")

    # The image is decorative: a slow or unreachable FloraVEG must neither hang
    # nor crash the prediction, so bound the wait and fall back to "no image".
    try:
        response = requests.get(floraveg_url, timeout=FLORAVEG_TIMEOUT)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(floraveg_tag))
    if img_tag is None:
        return None
    return gr.Image(value=img_tag['src'])


def gbif_normalization(text):
    """Normalize a comma-separated species list against the GBIF backbone.

    Each non-empty, stripped name is matched with
    ``pygbif.species.name_backbone``. When the match result has a ``usage``
    entry with a ``canonicalName``, that name replaces the typed one;
    otherwise the typed name is kept. Empty entries (e.g. from a trailing
    comma) are dropped.

    Returns:
        The names joined with ", " and lower-cased ("" if none were given).
    """
    normalized = []
    for species in (name.strip() for name in text.split(',')):
        if not species:
            continue
        try:
            match = pygbif.species.name_backbone(species, taxonRank="SPECIES")
        except requests.RequestException:
            # NOTE(review): assumes pygbif surfaces network/HTTP failures as
            # requests exceptions; normalization is best-effort, so keep the
            # name as typed rather than failing the whole prediction.
            match = {}
        canonical = (match.get("usage") or {}).get("canonicalName")
        normalized.append(canonical or species)
    return ", ".join(normalized).lower()


def classification(text, k):
    """Predict the most likely EUNIS habitat types for a vegetation plot.

    Args:
        text: comma-separated species names.
        k: number of habitat types to report (the UI slider sends 1-5; the
            classifier yields at most 5 labels).

    Returns:
        A ``(message, image)`` pair for the prediction textbox and image.
    """
    k = max(1, int(k))  # slider values may arrive as floats; slicing needs an int
    text = gbif_normalization(text)
    if not text:
        return EMPTY_INPUT_MESSAGE, image_not_found

    result = classification_model(text)
    habitat_labels = [res['label'] for res in result[0][:k]]
    top_label = habitat_labels[0]

    names = eunis_habitats.loc[eunis_habitats['EUNIS 2020 code'] == top_label, 'EUNIS-2021 habitat name']
    habitat_name = None if names.empty else names.iloc[0]

    if len(habitat_labels) == 1:
        text = f"This vegetation plot probably belongs to the habitat type {top_label}."
        subject = "This habitat type"
    else:
        # Two labels read "A or B"; three or more read "A, B, or C".
        separator = " or " if len(habitat_labels) == 2 else ", or "
        text = f"This vegetation plot probably belongs to the habitat type {', '.join(habitat_labels[:-1])}{separator}{habitat_labels[-1]}."
        subject = f"The most likely habitat type (i.e., {top_label})"

    if habitat_name is None:
        # The predicted code may be missing from the EUNIS table; say so
        # instead of crashing on an empty lookup.
        text += f"\n{subject} has no name in the EUNIS habitat table."
    else:
        text += f"\n{subject} is named '{habitat_name}'."

    image = return_image("classification", top_label)
    if image is not None:
        text += "\nBelow is an example of this habitat type taken from the website FloraVEG."
    else:
        text += "\nNo image found for this habitat type."
        image = image_not_found
    return text, image


def _first_unseen_prediction(masked_text, seen):
    """Return the best fill-mask prediction whose token is not in *seen*, or None."""
    # The fill-mask pipeline returns its top_k candidates ordered by descending
    # score, so the first unseen candidate is the best unseen one. Running the
    # model once here replaces one forward pass per rejected candidate.
    for prediction in masking_model(masked_text):
        if prediction['token_str'] not in seen:
            return prediction
    return None


def masking(text, k):
    """Greedily find up to *k* species missing from a vegetation plot.

    In each round, a ``[MASK]`` is tried at every position of the current
    species list. The highest-scoring species not already in the list is
    inserted at its position. Later rounds then see the completed list.

    Args:
        text: comma-separated species names.
        k: number of missing species to find (the UI slider sends 1-5).

    Returns:
        A ``(message, image)`` pair for the prediction textbox and image.
    """
    k = max(1, int(k))  # slider values may arrive as floats; range() needs an int
    text = gbif_normalization(text)
    if not text:
        return EMPTY_INPUT_MESSAGE, image_not_found

    text_split = text.split(', ')
    best_predictions = []
    best_sentence = None
    for _ in range(k):
        # text_split already contains the species inserted in earlier rounds.
        seen = set(text_split)
        best = None  # (prediction, insertion position)
        for i in range(len(text_split) + 1):
            masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
            prediction = _first_unseen_prediction(masked_text, seen)
            # Strict ">" keeps the earliest position on score ties.
            if prediction is not None and (best is None or prediction['score'] > best[0]['score']):
                best = (prediction, i)
        if best is None:
            # Every candidate at every position is already in the plot.
            break
        prediction, position = best
        best_predictions.append(prediction['token_str'])
        text_split.insert(position, prediction['token_str'])
        best_sentence = prediction['sequence']

    if not best_predictions:
        return "No missing species could be found for this vegetation plot.", image_not_found

    # Positions are 0-based indices in the completed list.
    best_positions = [text_split.index(prediction) for prediction in best_predictions]
    best_predictions = [s.strip().capitalize() for s in best_predictions]
    best_sentence = ", ".join(s.strip().capitalize() for s in best_sentence.split(","))

    if len(best_predictions) == 1:
        text = f"The most likely missing species is '{best_predictions[0]}' (position {best_positions[0]})."
    elif len(best_predictions) == 2:
        text = f"The most likely missing species are '{best_predictions[0]}' and '{best_predictions[1]}' (positions {best_positions[0]} and {best_positions[1]})."
    else:
        text = f"The most likely missing species are " + ', '.join(f"'{s}'" for s in best_predictions[:-1]) + f", and '{best_predictions[-1]}' (positions {', '.join(map(str, best_positions[:-1]))}, and {best_positions[-1]})."
    text += f"\nThe completed vegetation plot is thus '{best_sentence}'."

    image = return_image("masking", best_predictions[0])
    if image is not None:
        text += f"\nBelow is an image of the first missing species (i.e., {best_predictions[0]}) taken from the website FloraVEG."
    else:
        text += f"\nNo image found for the first missing species (i.e., {best_predictions[0]})."
        image = image_not_found
    return text, image


with gr.Blocks() as demo:
    gr.Markdown("""

Pl@ntBERT

""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""

Habitat identification of vegetation plots!

""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of habitat types to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""
An example of input
""")
        gr.Examples(
            examples=[
                ["phragmites australis, lemna minor, typha latifolia", 3],
                ["fagus sylvatica, prenanthes purpurea, abies alba, cardamine bulbifera, orthilia secunda, oxalis acetosella, rubus idaeus", 1],
            ],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""

Missing vascular plant species retrieval!

""")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""
An example of input
""")
        gr.Examples(
            examples=[
                ["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1],
                ["trapa natans, lemna minor, phragmites australis, sparganium erectum", 1],
            ],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    # Event listeners must be attached inside the Blocks context.
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

demo.launch()