Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Dict, Any | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import json | |
| import os | |
| import requests | |
| from io import BytesIO | |
| from pyproj import Transformer | |
| import onnxruntime as ort | |
| from cryptography.fernet import Fernet | |
| from fastapi.responses import HTMLResponse | |
app = FastAPI()

# Any origin may call the API. Credentials are disabled: nothing in this module
# reads cookies or auth headers, and combining allow_credentials=True with a
# wildcard origin makes Starlette reflect the caller's Origin, i.e. it would
# permit credentialed cross-site requests from every website.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Model load: model.bin is a Fernet-encrypted ONNX model whose key is read from
# the CLASSIF_MODEL environment variable.
key = os.getenv("CLASSIF_MODEL")
if not key:
    # Fernet(None) would fail with an opaque TypeError; fail fast instead.
    raise RuntimeError(
        "Environment variable CLASSIF_MODEL is not set; it must hold the Fernet "
        "key used to decrypt model.bin"
    )
cipher = Fernet(key)
with open("model.bin", "rb") as f:
    bin_data = f.read()
data = cipher.decrypt(bin_data)
model = ort.InferenceSession(data)

# Maps model output name -> {class index as a string -> class label};
# read by format_predictions. JSON is UTF-8 by spec, so do not rely on the locale.
with open("idx_to_target.json", "r", encoding="utf-8") as f:
    idx_to_target = json.load(f)

# WGS84 lon/lat (EPSG:4326) -> ETRS89 / UTM zone 32N (EPSG:25832), the CRS used
# for the WMS requests in get_image. always_xy=True: input order is (lon, lat).
transformer = Transformer.from_crs("EPSG:4326", "EPSG:25832", always_xy=True)

# Side length in pixels of the square model input; also the longer side of the
# image requested by get_image.
IMAGE_SIZE = 384
def normalize_image(image,
                    mean=(0.485, 0.456, 0.406, 0.5, 0.5),
                    std=(0.229, 0.224, 0.225, 0.5, 0.5)):
    """Scale an HWC image to [0, 1] and standardize each channel.

    Channel ``i`` becomes ``(x / 255 - mean[i]) / std[i]``. The first three
    defaults are the usual ImageNet RGB statistics; the last two (0.5 / 0.5)
    apply to the extra channels get_image stacks after RGB (hillshade, mask).

    Args:
        image: array of shape ``(height, width, channels)`` with values in
            0..255; ``channels`` must not exceed ``len(mean)``.
        mean: per-channel means in [0, 1] units.
        std: per-channel standard deviations in [0, 1] units.

    Returns:
        A new float32 array of the same shape; ``image`` is not modified.
    """
    scaled = (image / 255.0).astype("float32")
    channels = scaled.shape[2]
    # float32 statistics keep the arithmetic in float32, exactly like the
    # per-channel scalar version, but in a single broadcast operation.
    mean_arr = np.asarray(mean[:channels], dtype="float32")
    std_arr = np.asarray(std[:channels], dtype="float32")
    return (scaled - mean_arr) / std_arr
def pad_if_needed(image, target_size):
    """Center an HWC image on a ``target_size`` x ``target_size`` zero canvas.

    Axes shorter than ``target_size`` are zero-padded; when the padding is odd
    the extra pixel goes before the image (top/left). Axes longer than
    ``target_size`` are center-cropped, so the result always has shape
    ``(target_size, target_size, channels)``.

    Args:
        image: array of shape ``(height, width, channels)``.
        target_size: side length of the square output.

    Returns:
        A new array with the same dtype and channel count as ``image``.
    """
    height, width, channels = image.shape
    canvas = np.zeros((target_size, target_size, channels), dtype=image.dtype)

    def span(size):
        # (source start, destination start, length) along one axis.
        if size > target_size:
            return (size - target_size) // 2, 0, target_size
        # ceil((target_size - size) / 2): keep this rounding so padded images
        # sit exactly where they always have (presumably what the model saw).
        return 0, (target_size - size + 1) // 2, size

    src_y, dst_y, len_y = span(height)
    src_x, dst_x, len_x = span(width)
    canvas[dst_y:dst_y + len_y, dst_x:dst_x + len_x] = \
        image[src_y:src_y + len_y, src_x:src_x + len_x]
    return canvas
def softmax(x, axis=1):
    """Return the softmax of ``x`` along ``axis`` (default 1, the class axis).

    The per-slice maximum is subtracted before exponentiating so large logits
    do not overflow to ``inf``/``nan``. ``keepdims`` makes the normalization
    broadcast correctly for any batch size, not only a batch of one.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def _fetch_wms_image(url, params, mode, timeout):
    """GET one WMS map and decode it with PIL in ``mode``; None on HTTP error status."""
    try:
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.HTTPError as err:
        # Do not print ``err`` itself: its message contains the full request
        # URL, including the WMS username/password query parameters.
        print(f"WMS request to {url} failed with HTTP {err.response.status_code}")
        return None
    return Image.open(BytesIO(response.content)).convert(mode)


def _polygon_mask(xs, ys, xmin, ymin, roi_width, roi_height, width, height):
    """Rasterize the projected polygon into a width x height 'L' mask (255 inside)."""
    mask = Image.new('L', (width, height), 0)
    # Image rows grow downwards while northing grows upwards, hence 1 - y.
    points = [
        (int((x - xmin) / roi_width * width),
         int((1 - (y - ymin) / roi_height) * height))
        for x, y in zip(xs, ys)
    ]
    ImageDraw.Draw(mask).polygon(points, outline=255, fill=255)
    return mask


def get_image(coords, max_dim: int, timeout: float = 30.0) -> "np.ndarray | None":
    """Download imagery for a polygon and stack it into a 5-channel array.

    The polygon's vertices are projected to EPSG:25832 and their bounding box
    is requested from two Datafordeler WMS services, at a size whose longer side
    is ``max_dim`` pixels (bounding-box aspect ratio preserved):

    * orthophoto layer ``geodanmark_2023_12_5cm`` -> RGB channels,
    * terrain layer ``dhm_terraen_skyggekort`` (shadow/hillshade map) -> 1 channel.

    The polygon itself, rasterized as a 0/255 mask, is the fifth channel.

    Args:
        coords: sequence of ``(lon, lat)`` pairs in EPSG:4326.
        max_dim: pixel length of the longer image side.
        timeout: seconds to wait for each WMS request.

    Returns:
        uint8 array ``(height, width, 5)`` ordered RGB, hillshade, mask
        (assuming the WMS honours WIDTH/HEIGHT), or ``None`` if either WMS
        request returns an HTTP error status.

    Raises:
        ValueError: if the polygon's bounding box has zero width or height.
        requests.exceptions.RequestException: on connection errors/timeouts.
    """
    coords_utm = [transformer.transform(lon, lat) for lon, lat in coords]
    xs, ys = zip(*coords_utm)
    xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
    roi_width = xmax - xmin
    roi_height = ymax - ymin
    if roi_width <= 0 or roi_height <= 0:
        # A point or straight-line ring would divide by zero below or request
        # a 0-pixel image.
        raise ValueError("Polygon bounding box has zero width or height")

    # The longer side gets max_dim pixels; never request a 0-pixel side.
    aspect_ratio = roi_width / roi_height
    if aspect_ratio > 1:
        width, height = max_dim, max(1, int(max_dim / aspect_ratio))
    else:
        width, height = max(1, int(max_dim * aspect_ratio)), max_dim

    # Construct WMS parameters
    wms_params = {
        'username': os.getenv('WMSUSER'),
        'password': os.getenv('WMSPW'),
        'SERVICE': 'WMS',
        'VERSION': '1.3.0',
        'REQUEST': 'GetMap',
        'BBOX': f"{xmin},{ymin},{xmax},{ymax}",
        'CRS': 'EPSG:25832',
        'WIDTH': width,
        'HEIGHT': height,
        'LAYERS': "geodanmark_2023_12_5cm",
        'FORMAT': 'image/png',
        'STYLES': '',
        'DPI': 96,
        'MAP_RESOLUTION': 96,
        'FORMAT_OPTIONS': 'dpi:96'
    }

    # Download RGB orthophoto
    img = _fetch_wms_image(
        "https://services.datafordeler.dk/GeoDanmarkOrto/orto_foraar/1.0.0/WMS",
        wms_params, "RGB", timeout)
    if img is None:
        return None

    # Download terrain: same request against the DHM service and layer
    skygge_img = _fetch_wms_image(
        "https://services.datafordeler.dk/DHMNedboer/dhm/1.0.0/WMS",
        {**wms_params, "LAYERS": "dhm_terraen_skyggekort"}, "L", timeout)
    if skygge_img is None:
        return None

    mask = _polygon_mask(xs, ys, xmin, ymin, roi_width, roi_height, width, height)
    return np.concatenate([np.array(img),
                           np.array(skygge_img)[:, :, np.newaxis],
                           np.array(mask)[:, :, np.newaxis]],
                          axis=2)
def predict(image, image_size):
    """Classify one HWC image with the module-level ONNX session ``model``.

    The image is centred on an ``image_size`` x ``image_size`` canvas,
    normalized, reordered from HWC to CHW and given a leading batch axis of 1.
    Every model output is passed through ``softmax``.

    Returns:
        dict mapping each model output name to its softmax array, in the order
        reported by ``model.get_outputs()``. Callers read row 0 of each array.
    """
    chw = np.transpose(normalize_image(pad_if_needed(image, image_size)), (2, 0, 1))
    batch = chw[np.newaxis]
    feed = {model.get_inputs()[0].name: batch}
    names = [node.name for node in model.get_outputs()]
    # run(None, ...) returns every output, in the same order as get_outputs().
    raw_outputs = model.run(None, feed)
    return dict(zip(names, map(softmax, raw_outputs)))
# Display labels for the model's output names; format_predictions uses them as
# the row titles of the HTML result.
pretty_target_name = dict(
    hovednaturtype="Hovednaturtype",
    arealet_nbl="Paragraf 3",
    naturtilstand="Naturtilstand",
)
def format_predictions(predictions):
    """Render the top class of every model output as a compact HTML snippet.

    Args:
        predictions: mapping from model output name to a probability array;
            row 0 is used (``predict`` produces a batch of one).

    Returns:
        Concatenated ``<div>`` elements, one per output, each of the form
        ``<div>{label}: <i>{class}</i> ({confidence as %, 1 decimal})</div>``.
    """
    rows = []
    for target, probs in predictions.items():
        scores = probs[0]
        top_idx = int(np.argmax(scores))
        confidence = float(scores[top_idx])
        # idx_to_target.json keys class indices as strings.
        class_name = idx_to_target[target][str(top_idx)]
        # Naturtilstand labels are shown upper-case; all others capitalized.
        if target == "naturtilstand":
            class_name = class_name.upper()
        else:
            class_name = class_name.capitalize()
        # Fall back to the raw output name for outputs without a display label
        # instead of failing the whole request with a KeyError.
        target_name = pretty_target_name.get(target, target)
        rows.append(f"<div>{target_name}: <i>{class_name}</i> ({confidence:.1%})</div>")
    return "".join(rows)
class GeoJSONInput(BaseModel):
    # Request body for predict_endpoint. Only "is a dict" is validated here;
    # the endpoint reads geojson["geometry"]["coordinates"][0] as a sequence of
    # (lon, lat) pairs (presumably the exterior ring of a GeoJSON Polygon
    # Feature — confirm with the frontend).
    geojson: Dict[str, Any]
class ResultOutput(BaseModel):
    # Response body of predict_endpoint: "result" holds the HTML snippet built by
    # format_predictions (one <div> per model output).
    result: str
async def get_html():
    """Serve ``index.html`` from the working directory as an HTML response."""
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on ``app`` (e.g. as the "/" page).
    # Explicit UTF-8: the page may contain non-ASCII (e.g. Danish) text and the
    # platform default encoding is locale-dependent.
    with open("index.html", "r", encoding="utf-8") as f:
        content = f.read()
    return HTMLResponse(content=content)
async def predict_endpoint(geojson_input: GeoJSONInput) -> ResultOutput:
    """Classify the polygon in a GeoJSON Feature and return HTML-formatted results.

    Status codes:
        422 if the body has no ``geometry.coordinates[0]`` ring, or get_image
            rejects it with ValueError (e.g. malformed or degenerate ring);
        502 if the WMS imagery download fails (get_image returned None);
        500 for any other failure, with a generic message. Details are logged
            server-side only, because exception texts from ``requests`` embed the
            WMS URL including the username/password query parameters.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on ``app``.
    import logging

    from fastapi.concurrency import run_in_threadpool

    logger = logging.getLogger(__name__)

    try:
        coords = geojson_input.geojson['geometry']['coordinates'][0]
    except (KeyError, IndexError, TypeError) as err:
        raise HTTPException(
            status_code=422,
            detail="Expected a GeoJSON Feature with Polygon geometry",
        ) from err

    # get_image does blocking HTTP requests and predict runs CPU-bound inference;
    # run both in a worker thread so the event loop keeps serving other requests.
    # NOTE(review): assumes the module-level pyproj Transformer and ONNX session
    # may be used from worker threads — confirm the installed pyproj version.
    try:
        image = await run_in_threadpool(get_image, coords, IMAGE_SIZE)
    except ValueError as err:
        raise HTTPException(
            status_code=422,
            detail="Could not build an image for this polygon",
        ) from err
    except Exception as err:
        logger.exception("Imagery download failed")
        raise HTTPException(
            status_code=500,
            detail="Internal error while fetching imagery",
        ) from err
    if image is None:
        raise HTTPException(
            status_code=502,
            detail="Could not download imagery for this polygon",
        )

    try:
        predictions = await run_in_threadpool(predict, image, IMAGE_SIZE)
        result = format_predictions(predictions)
    except Exception as err:
        logger.exception("Prediction failed")
        raise HTTPException(
            status_code=500,
            detail="Internal error while running the classifier",
        ) from err
    return ResultOutput(result=result)