adema5051's picture
Upload 25 files
9ad3152 verified
import os
import io
import requests
import numpy as np
from PIL import Image
from pathlib import Path
import onnxruntime as ort
from .preprocessing import preprocess_image
class HeightPredictor:
    """Predict a height value from a Google Maps satellite tile with an ONNX model.

    Each instance owns one ONNX Runtime inference session (CPU provider only)
    and the Google Maps API key read from the ``GOOGLE_MAPS_API_KEY``
    environment variable.
    """

    # Google Static Maps endpoint. Query parameters are passed separately via
    # ``params=`` so that requests URL-encodes them.
    _STATIC_MAP_URL = 'https://maps.googleapis.com/maps/api/staticmap'

    def __init__(self, model_path: str):
        """Validate configuration and load the ONNX model.

        Args:
            model_path: Filesystem path to the ``.onnx`` model file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If ``GOOGLE_MAPS_API_KEY`` is unset or blank.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'ONNX model not found at {model_path}')

        # Check the API key before building the inference session, so a
        # misconfigured environment fails fast instead of after a model load.
        # strip() guards against a trailing newline/space in the env value,
        # which would otherwise be sent to Google as part of the key.
        api_key = os.environ.get('GOOGLE_MAPS_API_KEY', '').strip()
        if not api_key:
            raise RuntimeError('Google Maps API key environment variable is not set')
        self.api_key = api_key

        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = True
        self.session = ort.InferenceSession(
            model_path.as_posix(),
            sess_options,
            providers=['CPUExecutionProvider'],
        )
        # Only the first input and output are used; presumably the model has
        # exactly one of each (confirm if the export changes).
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def download_satellite_image(
        self,
        latitude: float,
        longitude: float,
        zoom: int = 20,
        size: str = '800x800',
    ) -> Image.Image:
        """Fetch a satellite tile centred on the given coordinates.

        Args:
            latitude: Latitude of the map centre.
            longitude: Longitude of the map centre.
            zoom: Google Static Maps zoom level.
            size: Requested image size as ``'<width>x<height>'``.

        Returns:
            The downloaded image, converted to RGB.

        Raises:
            RuntimeError: If the Static Maps API responds with a non-200 status.
            requests.RequestException: On network failure or the 10 s timeout.
                The message may contain the request URL, including the key.
        """
        params = {
            'center': f'{latitude},{longitude}',
            'zoom': zoom,
            'size': size,
            'maptype': 'satellite',
            'key': self.api_key,
        }
        response = requests.get(self._STATIC_MAP_URL, params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(
                f'Failed to download image - {response.status_code}: {response.text}'
            )
        # convert() returns a fully loaded copy, so the source image can be
        # closed as soon as we leave the with-block.
        with Image.open(io.BytesIO(response.content)) as raw:
            return raw.convert('RGB')

    def predict_height(self, image: Image.Image) -> float:
        """Run the model on ``image`` and return a non-negative height.

        The image goes through ``preprocess_image(image, img_size=224,
        fraction=0.7)``. The exact meaning of those arguments is defined in
        the preprocessing module; presumably they are a resize target and a
        crop fraction (confirm there). The result is cast to float32 before
        inference.

        Returns:
            The scalar model output, clamped below at 0.0. Units are those the
            model was trained with (not visible here).

        Raises:
            ValueError: If the model output is not a single value, or is NaN/inf.
        """
        inp = preprocess_image(image, img_size=224, fraction=0.7)
        inp = inp.astype('float32')
        outputs = self.session.run(
            [self.output_name],
            {self.input_name: inp},
        )
        # .item() raises ValueError if the squeezed output has more than one element.
        height = float(np.squeeze(outputs[0]).item())
        # max(0.0, nan) evaluates to 0.0, which would disguise a model failure
        # as a valid zero height, so reject non-finite outputs explicitly.
        if not np.isfinite(height):
            raise ValueError(f'Model produced a non-finite height: {height}')
        return max(0.0, height)

    def predict_from_coordinates(self, latitude: float, longitude: float) -> dict:
        """Download the tile at the given coordinates and predict its height.

        Exceptions raised while downloading or predicting are caught and
        reported in the returned dict instead of propagating.

        Returns:
            On success: ``status='success'``, the input ``latitude`` and
            ``longitude``, and ``predicted_height`` rounded to 2 decimals.
            On failure: ``status='error'``, the input coordinates,
            ``predicted_height=None`` and ``error`` holding the exception
            message with the API key masked.
        """
        try:
            image = self.download_satellite_image(latitude, longitude)
            height = self.predict_height(image)
        except Exception as e:  # deliberate boundary: report the error, don't raise
            return {
                'status': 'error',
                'latitude': latitude,
                'longitude': longitude,
                'predicted_height': None,
                'error': self._redact(str(e)),
            }
        return {
            'status': 'success',
            'latitude': latitude,
            'longitude': longitude,
            'predicted_height': round(height, 2),
        }

    def _redact(self, message: str) -> str:
        """Return ``message`` with the Google Maps API key masked.

        Error messages from requests/urllib3 can include the full request URL,
        and therefore ``key=...``. These messages are handed back to callers
        of predict_from_coordinates, so the secret must be scrubbed first.
        """
        key = getattr(self, 'api_key', '')
        # An empty key would make str.replace insert '***' between every character.
        return message.replace(key, '***') if key else message
# Lazily created, module-wide HeightPredictor instance (None until first use).
_predictor = None


def get_predictor() -> HeightPredictor:
    """Return the shared HeightPredictor, constructing it on the first call.

    The model is loaded from ``models/height_predictor.onnx`` under the parent
    of this module's directory (``Path(__file__).parent.parent``). Subsequent
    calls return the same cached instance.

    Raises:
        FileNotFoundError: If the ONNX model file does not exist.
        RuntimeError: Propagated from HeightPredictor when the Google Maps API
            key is not configured.
    """
    # NOTE(review): initialisation is not guarded by a lock. Concurrent first
    # calls could each build a predictor; the last one assigned wins.
    global _predictor
    if _predictor is not None:
        return _predictor

    onnx_file = Path(__file__).parent.parent.joinpath('models', 'height_predictor.onnx')
    if not onnx_file.exists():
        raise FileNotFoundError(f'ONNX model not found at {onnx_file}')

    _predictor = HeightPredictor(str(onnx_file))
    return _predictor