Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Dict, Any | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import json | |
| import os | |
| import requests | |
| from io import BytesIO | |
| from pyproj import Transformer | |
| import onnxruntime as ort | |
| from cryptography.fernet import Fernet | |
| from fastapi.responses import HTMLResponse | |
app = FastAPI()

# Fully permissive CORS: any origin, any HTTP method and any request header is
# accepted, and credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation states that allow_credentials=True
# should not be combined with "*" wildcards for origins/methods/headers; confirm
# whether credentialed cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Model load: runs once, when this module is imported.
# The species embeddings, the id -> species table and the ONNX image encoder are
# stored Fernet-encrypted; the key comes from the MODEL_KEY environment variable.
key = os.getenv("MODEL_KEY")
if not key:
    # Fail with an actionable message instead of an opaque error from Fernet(None).
    raise RuntimeError(
        "MODEL_KEY environment variable is not set; cannot decrypt model files"
    )
cipher = Fernet(key)


def _decrypt_file(path):
    """Return the decrypted contents of the Fernet-encrypted file at *path*."""
    with open(path, "rb") as fh:
        return cipher.decrypt(fh.read())


# Each decrypted payload is consumed immediately, so the raw (and, for the ONNX
# model, potentially large) byte strings are not kept alive as module globals.
species_features = np.load(BytesIO(_decrypt_file("species_features.bin")))
id2spec = json.loads(_decrypt_file("id2spec.bin"))
image_encoder = ort.InferenceSession(_decrypt_file("image_encoder.bin"))

# JSON text is UTF-8 by specification; do not depend on the locale's default.
with open("spec2key.json", "r", encoding="utf-8") as f:
    spec2key = json.load(f)

# Converts lon/lat (EPSG:4326) into EPSG:25832 map coordinates; always_xy=True
# fixes the axis order to (x, y), i.e. (lon, lat) in and (easting, northing) out.
transformer = Transformer.from_crs("EPSG:4326", "EPSG:25832", always_xy=True)

# Side length in pixels of the square canvas images are padded onto before inference.
IMAGE_SIZE = 384
def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale pixel values to [0, 1] and standardize the leading channels.

    Parameters:
        image: array whose last axis holds channels (e.g. ``(H, W, 3)``), with
            values in 0..255. The input array is not modified.
        mean: per-channel means; the first ``len(mean)`` channels are normalized.
        std: per-channel standard deviations, same length as ``mean``.

    Returns:
        New float32 array of the same shape. The first ``len(mean)`` channels are
        ``(pixel / 255 - mean) / std``; any further channels are only divided by 255.
    """
    image = (image / 255.0).astype("float32")
    # float32 statistics keep the arithmetic (and result) in float32.
    mean = np.asarray(mean, dtype="float32")
    std = np.asarray(std, dtype="float32")
    n_channels = mean.shape[0]
    # Broadcast over the channel axis instead of handling each channel by hand.
    image[..., :n_channels] = (image[..., :n_channels] - mean) / std
    return image
def pad_if_needed(image, target_size):
    """Center *image* on a black ``target_size`` x ``target_size`` canvas.

    Parameters:
        image: array of shape ``(H, W)`` or ``(H, W, C)``; values are stored into
            a uint8 canvas.
        target_size: side length of the square output canvas, in pixels.

    Returns:
        uint8 array of shape ``(target_size, target_size) + image.shape[2:]``
        with *image* centred and the remaining area zero. When the padding
        along an axis is odd, the extra pixel goes on the top/left side.

    Raises:
        ValueError: if the image is taller or wider than ``target_size``.
    """
    height, width = image.shape[:2]
    if height > target_size or width > target_size:
        raise ValueError(
            f"image of size {height}x{width} does not fit in a "
            f"{target_size}x{target_size} canvas"
        )
    # Ceil-divide the slack so an odd leftover pixel of padding lands top/left.
    y0 = (target_size - height + 1) // 2
    x0 = (target_size - width + 1) // 2
    background = np.zeros((target_size, target_size) + image.shape[2:], dtype="uint8")
    background[y0:y0 + height, x0:x0 + width] = image
    return background
def predict(image, image_size, top_k=20):
    """Rank species by similarity between an image embedding and species embeddings.

    The image is converted to RGB, centred on a black ``image_size`` square,
    normalized, reordered from (H, W, C) to (C, H, W) with a leading batch axis
    of size 1, and run through the module-level ONNX ``image_encoder``. The
    resulting embedding is scored against every row of ``species_features``
    with a dot product.

    Parameters:
        image: PIL image to classify.
        image_size: side length of the square canvas the image is padded onto.
        top_k: number of best-scoring species to return.

    Returns:
        dict mapping species name -> score, in descending order of similarity.
        Each raw dot product ``s`` is reported as ``(s + 1) / 2 * 100``, mapping
        [-1, 1] onto [0, 100] (presumably the embeddings are L2-normalized so
        ``s`` is a cosine similarity -- TODO confirm).
    """
    pixels = np.array(image.convert("RGB"))
    batch = normalize_image(pad_if_needed(pixels, image_size))
    # (H, W, C) -> (C, H, W), then prepend a batch axis of size 1.
    batch = np.transpose(batch, (2, 0, 1))[np.newaxis]
    embedding = image_encoder.run(None, {"input.1": batch})[0]
    scores = np.dot(embedding, species_features.T)[0]
    best_first = np.argsort(scores)[::-1][:top_k]
    return {
        id2spec[str(idx)]: (float(scores[idx]) + 1) / 2 * 100
        for idx in best_first
    }
def format_predictions(predictions):
    """Render species scores as an HTML snippet, one ``<br>``-separated line each.

    Parameters:
        predictions: mapping of species name -> score; iteration order is kept.

    Returns:
        str of lines ``"<species>: <score>%"`` with the score rounded to one
        decimal. Species present in the module-level ``spec2key`` mapping are
        wrapped in a link to ``https://www.gbif.org/species/<key>`` that opens
        in a new tab.
    """
    import html

    baseurl = "https://www.gbif.org/species/"
    lines = []
    for species, value in predictions.items():
        gbif_key = spec2key.get(species)
        value = round(value, 1)
        # The output is HTML (links and <br> separators), so escape the name and
        # link target to keep characters such as & or < from being parsed as markup.
        label = html.escape(str(species))
        if gbif_key is None:
            lines.append(f"{label}: {value}%")
        else:
            href = html.escape(f"{baseurl}{gbif_key}")
            lines.append(f'<a href="{href}" target="_blank">{label}</a>: {value}%')
    return "<br>".join(lines)
def get_image(coords, max_dim, *, timeout=30.0):
    """Fetch the spring orthophoto covering a polygon and black out everything outside it.

    The polygon's bounding box is projected to EPSG:25832 and requested from the
    Datafordeler GeoDanmarkOrto WMS (layer ``orto_foraar``). The request uses the
    WMSUSER/WMSPW environment variables as ``username``/``password`` query
    parameters.

    Parameters:
        coords: polygon ring as a sequence of ``[lon, lat]`` positions (EPSG:4326);
            any extra ordinates, such as elevation, are ignored.
        max_dim: pixel length of the longer side of the requested image; the
            other side follows the bounding box's aspect ratio.
        timeout: seconds to wait for the WMS server (passed to ``requests.get``).

    Returns:
        PIL RGB image in which pixels outside the polygon are black.

    Raises:
        HTTPException: 400 if the ring has fewer than three positions or a
            zero-width/zero-height bounding box; 500 if the WMS answers with a
            non-200 status or with a body that is not a decodable image.
    """
    if len(coords) < 3:
        raise HTTPException(status_code=400, detail="Polygon needs at least three positions")

    # Project to metric EPSG:25832 coordinates; pos[0], pos[1] are lon, lat.
    coords_utm = [transformer.transform(pos[0], pos[1]) for pos in coords]
    xs, ys = zip(*coords_utm)
    xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
    roi_width = xmax - xmin
    roi_height = ymax - ymin
    if roi_width <= 0 or roi_height <= 0:
        # Would otherwise divide by zero below.
        raise HTTPException(status_code=400, detail="Polygon has zero width or height")

    # The longer side gets max_dim pixels; max(1, ...) keeps a very thin polygon
    # from rounding the shorter side down to 0 px.
    aspect_ratio = roi_width / roi_height
    if aspect_ratio > 1:
        width = max_dim
        height = max(1, int(max_dim / aspect_ratio))
    else:
        width = max(1, int(max_dim * aspect_ratio))
        height = max_dim

    wms_params = {
        'username': os.getenv('WMSUSER'),
        'password': os.getenv('WMSPW'),
        'SERVICE': 'WMS',
        'VERSION': '1.3.0',
        'REQUEST': 'GetMap',
        'BBOX': f"{xmin},{ymin},{xmax},{ymax}",
        'CRS': 'EPSG:25832',
        'WIDTH': width,
        'HEIGHT': height,
        'LAYERS': 'orto_foraar',
        'STYLES': '',
        'FORMAT': 'image/png',
        'DPI': 96,
        'MAP_RESOLUTION': 96,
        'FORMAT_OPTIONS': 'dpi:96'
    }
    base_url = "https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar/1.0.0/WMS"
    # Without a timeout a stalled server would block this call indefinitely.
    response = requests.get(base_url, params=wms_params, timeout=timeout)
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Error fetching image: {response.status_code}")
    try:
        img = Image.open(BytesIO(response.content))
        img.load()
    except OSError as e:
        # e.g. a WMS service-exception document returned instead of a PNG.
        raise HTTPException(status_code=500, detail="WMS response is not a valid image") from e

    # Build the mask at the size actually returned so paste() cannot fail on a
    # size mismatch.
    out_width, out_height = img.size
    mask = Image.new('L', (out_width, out_height), 0)
    # Map projected coordinates to pixel coordinates. Pixel rows grow downward
    # while northing grows upward, hence the (1 - y) flip.
    polygon_px = [
        (
            int((x - xmin) / roi_width * out_width),
            int((1 - (y - ymin) / roi_height) * out_height),
        )
        for x, y in coords_utm
    ]
    ImageDraw.Draw(mask).polygon(polygon_px, outline=255, fill=255)
    masked_img = Image.new('RGB', img.size)
    # paste() converts img to RGB and copies only pixels where the mask is set.
    masked_img.paste(img, mask=mask)
    return masked_img
# Request body model for predict_endpoint: a single "geojson" field that must be
# a JSON object. Its inner structure is not validated here; predict_endpoint reads
# geojson["geometry"]["coordinates"][0] itself (i.e. presumably a GeoJSON Feature
# with a Polygon geometry -- confirm against the frontend).
# Comments rather than a docstring: pydantic would publish a docstring as the
# schema description.
class GeoJSONInput(BaseModel):
    geojson: Dict[str, Any]  # raw GeoJSON object, passed through unvalidated
async def get_html():
    """Serve the frontend page ``index.html`` as an HTML response.

    The file is read from the current working directory on every call.

    NOTE(review): no route decorator is attached to this handler in the file as
    shown; confirm how it is registered with ``app``.
    """
    html_file = "index.html"
    # Decode explicitly as UTF-8: the default encoding is locale-dependent and
    # could garble or reject non-ASCII text in the page.
    with open(html_file, "r", encoding="utf-8") as f:
        content = f.read()
    return HTMLResponse(content=content)
async def predict_endpoint(geojson_input: GeoJSONInput):
    """Classify the orthophoto area covered by a GeoJSON polygon.

    Accepts either a Feature (``{"geometry": {"coordinates": [...]}}``) or a bare
    geometry object (``{"coordinates": [...]}``). Only ``coordinates[0]`` (the
    exterior ring of a Polygon) is used.

    NOTE(review): no route decorator is attached to this handler in the file as
    shown; confirm how it is registered with ``app``.

    Returns:
        dict with ``predictions`` (species -> score) and ``predictions_formatted``
        (an HTML string with one line per species).

    Raises:
        HTTPException: 400 when the body has no ``coordinates[0]`` ring;
            HTTPExceptions raised by helpers are passed through unchanged; any
            other failure becomes a 500 whose detail is the exception text.
    """
    from fastapi.concurrency import run_in_threadpool

    geojson = geojson_input.geojson
    try:
        # A Feature wraps its geometry; a bare geometry carries coordinates itself.
        geometry = geojson["geometry"] if "geometry" in geojson else geojson
        coords = geometry["coordinates"][0]
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid GeoJSON polygon: {e!r}") from e

    def _run():
        # get_image performs a blocking HTTP request and predict runs model
        # inference; both happen here, off the event loop.
        image = get_image(coords, IMAGE_SIZE)
        predictions = predict(image, IMAGE_SIZE)
        return predictions, format_predictions(predictions)

    try:
        predictions, predictions_formatted = await run_in_threadpool(_run)
    except HTTPException:
        # Keep the status code and detail chosen by the helper.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predictions": predictions, "predictions_formatted": predictions_formatted}