| from fastapi import FastAPI |
| from pydantic import BaseModel, Field |
| import numpy as np |
| import onnxruntime as ort |
| from typing_extensions import Annotated |
| import gradio as gr |
| from cryptography.fernet import Fernet |
| import os |
| import pickle as pkl |
|
|
| |
# Fernet key used to decrypt the bundled model and metadata files.
key = os.getenv("ONNX_KEY")
if not key:
    # Fail fast with an actionable message; Fernet(None) would otherwise
    # raise an opaque TypeError from inside the cryptography package.
    raise RuntimeError(
        "Environment variable ONNX_KEY must be set to the Fernet key "
        "used to decrypt the model files."
    )
cipher = Fernet(key)
|
|
# API/model version; embedded in the page title and in every EQRResult.
VERSION = "0.0.3"

# Human-readable (Danish) title and description used by FastAPI and Gradio.
TITLE = "DVPI beregnings API (version {})".format(VERSION)
DESCRIPTION = (
    "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. "
    "Beregningen er baseret på en model som efterligner DVPI beregningsmetoden "
    "og er dermed ikke eksakt, usikkerheden er i gennemsnit **±0.017 EQR-enheder** "
    "og **R<sup>2</sup>=0.98** når den sammenlignes med den originale. "
    "Kan der ikke beregnes en værdi, returneres EQR=0 og DVPI=0."
)

# Public base URL, used to build the links in the documentation tab.
URL = "https://kennethtm-dvpi.hf.space"
|
|
| |
def _decrypt_file(path: str) -> bytes:
    """Read ``path`` and return its contents decrypted with ``cipher``."""
    with open(path, "rb") as fh:
        encrypted = fh.read()
    return cipher.decrypt(encrypted)


# ONNX model, decrypted and loaded once at import; predict() reuses it.
ort_session = ort.InferenceSession(_decrypt_file("model_v3.bin"))

# Lookup tables that accompany the model.
# NOTE: pickle.loads can execute arbitrary code. It is only acceptable here
# because the file ships with the app and Fernet decryption authenticates it
# (a tampered token fails to decrypt). Never use this path for user data.
metadata = pkl.loads(_decrypt_file("metadata_v3.bin"))

latinname2stancode = metadata["latinname2stancode"]  # Latin name -> Stancode
valid_taxacodes = metadata["valid_taxacodes"]  # taxa accepted by the model
normalizer_1 = metadata["normalizer_1"]  # first Stancode remapping pass
normalizer_2 = metadata["normalizer_2"]  # second remapping pass (None = drop)
taxacode2idx = metadata["taxacode2idx"]  # taxon -> model input column
|
|
| |
def preprocess_species(
    species: dict[int, float],
    *,
    first_map=None,
    second_map=None,
    valid_codes=None,
) -> dict[int, float]:
    """Normalise raw Stancodes and sum cover per model taxon.

    Three steps are applied in order:

    1. ``first_map`` (default: module-level ``normalizer_1``): codes absent
       from it are discarded; present codes are renamed to the mapped code.
    2. ``second_map`` (default: ``normalizer_2``): a code mapped to ``None``
       is discarded, a code mapped to another code is renamed, and a code
       absent from the map is kept unchanged.
    3. Only codes contained in ``valid_codes`` (default: ``valid_taxacodes``)
       are returned.

    Whenever several codes end up on the same target code, their covers are
    summed. The keyword-only overrides allow injecting the lookup tables
    (e.g. in tests); normal callers pass only ``species``.

    Args:
        species: Mapping from Stancode to cover (%).

    Returns:
        Mapping from model taxon code to summed cover.
    """
    if first_map is None:
        first_map = normalizer_1
    if second_map is None:
        second_map = normalizer_2
    if valid_codes is None:
        valid_codes = valid_taxacodes

    # Pass 1: rename via first_map; codes it does not know are dropped.
    intermediate: dict[int, float] = {}
    for code, cover in species.items():
        if code in first_map:
            target = first_map[code]
            intermediate[target] = intermediate.get(target, 0) + cover

    # Pass 2: rename via second_map. Unmapped codes are kept as-is and codes
    # mapped to None are dropped. Always accumulate: assigning here would
    # overwrite cover already merged into this code from another source code,
    # making the result depend on dict iteration order.
    final: dict[int, float] = {}
    for code, cover in intermediate.items():
        if code in second_map:
            target = second_map[code]
            if target is None:
                continue
        else:
            target = code
        final[target] = final.get(target, 0) + cover

    # Pass 3: keep only taxa the model has an input column for.
    return {code: cover for code, cover in final.items() if code in valid_codes}
|
|
|
|
# Request body for POST /dvpi: cover in percent per species, keyed by Stancode.
class SpeciesCover(BaseModel):
    # Stancode -> cover (%); pydantic rejects any value outside [0, 100].
    species: dict[int, Annotated[float, Field(ge=0, le=100)]]

    # Example payload exposed through the generated JSON schema (OpenAPI docs).
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "species": {
                    6458: 25.0,
                    4158: 15.5,
                    7208: 10.0
                }
            }]
        }
    }
|
|
# Response body for POST /dvpi.
class EQRResult(BaseModel):
    EQR: float  # predicted EQR; 0 when no value could be computed
    DVPI: int  # DVPI class 1-5 derived from EQR; 0 when no value could be computed
    version: str = VERSION  # API/model version that produced this result
|
|
| |
# FastAPI application hosting the /dvpi endpoint. Passing VERSION makes the
# OpenAPI page report the real API version instead of FastAPI's default.
app = FastAPI(title=TITLE, description=DESCRIPTION, version=VERSION)
|
|
def eqr_to_dvpi(eqr: float) -> int:
    """Convert an EQR value to its DVPI class (1-5).

    Class boundaries are half-open: EQR below 0.20 -> 1, [0.20, 0.35) -> 2,
    [0.35, 0.50) -> 3, [0.50, 0.70) -> 4, anything else -> 5.
    """
    upper_bounds = (0.20, 0.35, 0.50, 0.70)
    for dvpi_class, upper in enumerate(upper_bounds, start=1):
        if eqr < upper:
            return dvpi_class
    # eqr >= 0.70, or a value that compares false against every bound.
    return len(upper_bounds) + 1
|
|
|
|
| |
@app.post("/dvpi")
def predict(cover_data: SpeciesCover) -> EQRResult:
    """Predict EQR and DVPI from species cover data"""
    # Map the submitted Stancodes onto the model's taxa and sum their cover.
    species_preproc = preprocess_species(cover_data.species)

    # Single-row model input with one column per valid taxon.
    input_vector = np.zeros((1, len(valid_taxacodes)))
    for taxon, cover in species_preproc.items():
        input_vector[0, taxacode2idx[taxon]] = cover

    # No recognised species with cover: return the documented sentinel.
    if np.sum(input_vector) == 0:
        return EQRResult(EQR=0, DVPI=0)

    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: input_vector.astype(np.float32)}
    # The session returns two outputs; the EQR estimate is read from the second.
    _, output_2 = ort_session.run(None, ort_inputs)

    eqr = float(output_2[0][0])
    # A NaN/inf prediction cannot be serialised as JSON and has no DVPI class;
    # report it as "could not compute" (EQR=0, DVPI=0) like the case above.
    if not np.isfinite(eqr):
        return EQRResult(EQR=0, DVPI=0)
    # EQR is a ratio bounded to [0, 1]; previously only the upper bound was
    # clamped, so small negative model outputs were returned as-is.
    eqr = min(max(eqr, 0.0), 1.0)
    dvpi = eqr_to_dvpi(eqr)

    return EQRResult(EQR=round(eqr, 3), DVPI=dvpi)
|
|
| |
def add_entry(species, cover, current_dict) -> tuple[dict, dict]:
    """Add or overwrite one species' cover in the Gradio species list.

    Args:
        species: Latin name chosen in the dropdown (``None`` if nothing is selected).
        cover: Cover in percent from the number field (``None`` if left empty).
        current_dict: The session's current ``{latin name: cover}`` mapping.

    Returns:
        The updated mapping twice: once for the ``gr.State`` and once for the
        JSON display component.
    """
    # Ignore clicks without a species or a cover; an entry keyed by None is
    # not a species name and cannot be translated to a Stancode later.
    if not species or cover is None:
        return current_dict, current_dict
    # Build a new dict rather than mutating the stored state object in place.
    updated = {**(current_dict or {}), species: cover}
    return updated, updated
|
|
def gradio_predict(cover_data: dict):
    """Compute EQR/DVPI for the species list built in the Gradio UI.

    Args:
        cover_data: Mapping from Latin species name to cover (%).

    Returns:
        ``predict``'s result as a plain dict, or ``{}`` when there is nothing
        to compute.
    """
    if not cover_data:
        return {}

    # Translate Latin names to Stancodes. Entries with an unknown/missing name
    # or a missing cover would raise KeyError / ValidationError, so skip them.
    cover_data_code = {
        latinname2stancode[name]: cover
        for name, cover in cover_data.items()
        if cover is not None and name in latinname2stancode
    }
    if not cover_data_code:
        return {}

    data = SpeciesCover(species=cover_data_code)
    result = predict(data)

    return result.model_dump()
| |
# Gradio front end: an interactive calculator tab and an API usage tab.
with gr.Blocks() as io:

    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Beregner"):

        gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Når API'et bruges anvendes arternes [Stancode](https://dce.au.dk/overvaagning/stancode/stancodelister) (SC1064) - se 'Dokumentation' for eksempel på brug.")

        # Per-session {latin name: cover} mapping filled by the "Tilføj" button.
        current_dict = gr.State({})

        with gr.Row():
            species_choices = sorted(latinname2stancode)
            species_input = gr.Dropdown(choices=species_choices, label="Vælg art")
            cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)

        with gr.Row():
            add_btn = gr.Button("Tilføj")
            reset_btn = gr.Button("Nulstil")

        list_display = gr.JSON(label="Artsliste")

        calc_btn = gr.Button("Beregn")
        results = gr.JSON(label="Resultater")

        def reset_dict():
            """Clear the stored species list, its display and the last result."""
            return {}, {}, {}

        # UI-only handlers; show_api=False keeps them off Gradio's API page.
        add_btn.click(
            add_entry,
            inputs=[species_input, cover_input, current_dict],
            outputs=[current_dict, list_display],
            show_api=False
        )

        reset_btn.click(
            reset_dict,
            inputs=[],
            outputs=[current_dict, list_display, results],
            show_api=False
        )

        calc_btn.click(
            gradio_predict,
            inputs=[current_dict],
            outputs=results,
            show_api=False
        )

        gr.Markdown("App og model af Kenneth Thorø Martinsen (kenneth2810@gmail.com).")

    with gr.Tab(label="Dokumentation"):

        gr.Markdown("## Eksempel på brug af API")
        gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")
        gr.Markdown("### Python")
        gr.Code(f"""
import requests
import json

data = {{
    "species": {{
        6458: 25.0,
        4158: 15.5,
        7208: 10.0
    }}
}}

response = requests.post("{URL}/dvpi", json=data)
print(response.json())
""")

        gr.Markdown("### R")
        # Numeric names must be back-quoted in R: `list(6458 = 25.0)` is a
        # syntax error, so the example would not run as originally written.
        gr.Code(f"""
library(httr)
library(jsonlite)

data <- list(species = list(
  `6458` = 25.0,
  `4158` = 15.5,
  `7208` = 10.0
))

response <- POST("{URL}/dvpi",
                 body = toJSON(data, auto_unbox = TRUE),
                 content_type("application/json"))

print(fromJSON(rawToChar(response$content)))
""")
| |
# Serve the Gradio UI at the site root on top of the FastAPI app defined above.
app = gr.mount_gradio_app(app=app, blocks=io, path="/")
|
|