Spaces:
Sleeping
Sleeping
File size: 3,144 Bytes
9ad3152 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import io
import logging
import os
from pathlib import Path

import numpy as np
import onnxruntime as ort
import requests
from PIL import Image

from .preprocessing import preprocess_image
class HeightPredictor:
    """Predict a height value from a satellite image using an ONNX model.

    Images are fetched from the Google Maps Static API (``maptype=satellite``)
    and run through an ONNX model on the CPU execution provider.
    """

    # Logger named after this module; records failures turned into error dicts.
    _log = logging.getLogger(__name__)

    # Google Maps Static API endpoint used by download_satellite_image.
    _STATIC_MAP_URL = 'https://maps.googleapis.com/maps/api/staticmap'

    def __init__(self, model_path: str):
        """Load the ONNX model and read the Google Maps API key.

        Args:
            model_path: Filesystem path to the ``.onnx`` model file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If ``GOOGLE_MAPS_API_KEY`` is unset or empty.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'ONNX model not found at {model_path}')

        # Validate configuration before the (comparatively expensive) model
        # load so a missing key fails fast.
        api_key = os.environ.get('GOOGLE_MAPS_API_KEY', '')
        if not api_key:
            raise RuntimeError('Google Maps API key environment variable is not set')
        self.api_key = api_key

        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = True
        self.session = ort.InferenceSession(
            model_path.as_posix(),
            sess_options,
            providers=['CPUExecutionProvider'],
        )
        # Only the first input and first output are used; presumably the model
        # is single-input / single-output — TODO confirm against the export.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def download_satellite_image(
        self,
        latitude,
        longitude,
        zoom: int = 20,
        size: str = '800x800',
    ) -> 'Image.Image':
        """Fetch a satellite map image centred on the given coordinates.

        Args:
            latitude: Latitude of the map centre.
            longitude: Longitude of the map centre.
            zoom: Static Maps zoom level.
            size: Requested image size as ``'<width>x<height>'``.
                NOTE(review): Google documents 640x640 as the standard-plan
                maximum; larger requests may be clamped — confirm.

        Returns:
            The downloaded image converted to RGB mode.

        Raises:
            RuntimeError: If the API responds with a non-200 status.
            requests.RequestException: On network failure or timeout. Its
                message may contain the request URL, including the API key.
        """
        # Let requests URL-encode each value rather than splicing raw caller
        # input into the query string.
        params = {
            'center': f'{latitude},{longitude}',
            'zoom': zoom,
            'size': size,
            'maptype': 'satellite',
            'key': self.api_key,
        }
        response = requests.get(self._STATIC_MAP_URL, params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(
                f'Failed to download image - {response.status_code}: {response.text}'
            )
        # Force 3-channel RGB whatever mode the returned image file uses.
        return Image.open(io.BytesIO(response.content)).convert('RGB')

    def predict_height(self, image: 'Image.Image') -> float:
        """Run the model on one image and return a non-negative height.

        The image is prepared by ``preprocess_image(image, img_size=224,
        fraction=0.7)``; the exact transform lives in ``.preprocessing``.

        Raises:
            ValueError: If the model output is not a single finite value.
        """
        inp = preprocess_image(image, img_size=224, fraction=0.7).astype('float32')
        outputs = self.session.run([self.output_name], {self.input_name: inp})
        return self._coerce_height(outputs[0])

    @staticmethod
    def _coerce_height(raw) -> float:
        """Convert a single-element model output to a finite, clamped float."""
        # .item() raises ValueError when the output holds more than one element.
        height = float(np.squeeze(raw).item())
        # NaN would otherwise be silently reported as 0.0 by max(), and inf
        # cannot be serialized as strict JSON.
        if not np.isfinite(height):
            raise ValueError(f'Model produced a non-finite height: {height}')
        # Negative regressions are clamped to zero.
        return max(0.0, height)

    def _redact(self, text: str) -> str:
        """Return ``text`` with every occurrence of the API key masked."""
        key = getattr(self, 'api_key', '')
        # Guard: replacing '' would insert the mask between every character.
        return text.replace(key, '***') if key else text

    def predict_from_coordinates(self, latitude: float, longitude: float) -> dict:
        """Download imagery for a location and predict its height.

        Never raises: any failure is reported in the returned dict.

        Returns:
            A dict with ``status`` (``'success'`` or ``'error'``), the input
            ``latitude``/``longitude`` and ``predicted_height`` (rounded to two
            decimals, or ``None`` on error). Error results also carry
            ``error``, the exception message with the API key redacted.
        """
        try:
            image = self.download_satellite_image(latitude, longitude)
            height = self.predict_height(image)
        except Exception as e:  # API boundary: convert any failure to a result
            # requests errors may embed the full request URL (including the
            # key); redact before the text reaches logs or API clients.
            message = self._redact(str(e))
            self._log.warning(
                'Height prediction failed for (%s, %s): %s',
                latitude,
                longitude,
                message,
            )
            return {
                'status': 'error',
                'latitude': latitude,
                'longitude': longitude,
                'predicted_height': None,
                'error': message,
            }
        return {
            'status': 'success',
            'latitude': latitude,
            'longitude': longitude,
            'predicted_height': round(height, 2),
        }
# Process-wide HeightPredictor, created lazily by get_predictor().
_predictor = None


def get_predictor() -> HeightPredictor:
    """Return the shared HeightPredictor, building it on first use.

    The model is loaded from ``models/height_predictor.onnx`` located two
    directory levels above this file. No locking is performed, so concurrent
    first calls may each construct a predictor.

    Raises:
        FileNotFoundError: If the ONNX model file is missing.
    """
    global _predictor
    if _predictor is not None:
        return _predictor

    models_dir = Path(__file__).parent.parent / 'models'
    onnx_file = models_dir / 'height_predictor.onnx'
    if not onnx_file.exists():
        raise FileNotFoundError(f'ONNX model not found at {onnx_file}')

    _predictor = HeightPredictor(str(onnx_file))
    return _predictor
|