import io
import logging
import math
import os
from pathlib import Path
from urllib.parse import quote

import numpy as np
import onnxruntime as ort
import requests
from PIL import Image

from .preprocessing import preprocess_image

logger = logging.getLogger(__name__)

_STATIC_MAP_URL = 'https://maps.googleapis.com/maps/api/staticmap'


class HeightPredictor:
    """Predict a height value for a location from its satellite image.

    Images are fetched from the Google Maps Static API and passed through an
    ONNX model (first input / first output) on the CPU execution provider.
    """

    def __init__(self, model_path: str):
        """Load the ONNX model and read the Google Maps API key.

        Args:
            model_path: Filesystem path to the ``.onnx`` model file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If ``GOOGLE_MAPS_API_KEY`` is unset or empty.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'ONNX model not found at {model_path}')

        # Check the cheap precondition before paying for session creation.
        api_key = os.environ.get('GOOGLE_MAPS_API_KEY', '')
        if not api_key:
            raise RuntimeError('Google Maps API key environment variable is not set')
        self.api_key = api_key

        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = True
        self.session = ort.InferenceSession(
            model_path.as_posix(),
            sess_options,
            providers=['CPUExecutionProvider'],
        )
        # The model is driven through its first input and first output only.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _redact(self, text: str) -> str:
        """Return *text* with the API key (raw or URL-encoded) masked out."""
        key = getattr(self, 'api_key', '')
        if not key:
            return text
        for variant in {key, quote(key, safe='')}:
            text = text.replace(variant, '[REDACTED]')
        return text

    def download_satellite_image(
        self,
        latitude: float,
        longitude: float,
        zoom: int = 20,
        size: str = '800x800',
        timeout: float = 10,
    ) -> Image.Image:
        """Download a satellite image centred on the given coordinates.

        Args:
            latitude: Latitude of the image centre.
            longitude: Longitude of the image centre.
            zoom: Static Maps zoom level.
            size: Image size as ``'<width>x<height>'``.
            timeout: Seconds to wait for the HTTP response.

        Returns:
            The downloaded image converted to RGB.

        Raises:
            RuntimeError: If the API responds with a non-200 status.
            requests.RequestException: On network failure or timeout.
        """
        # Let requests encode the query string so caller-supplied values cannot
        # inject extra parameters.
        params = {
            'center': f'{latitude},{longitude}',
            'zoom': zoom,
            'size': size,
            'maptype': 'satellite',
            'key': self.api_key,
        }
        response = requests.get(_STATIC_MAP_URL, params=params, timeout=timeout)
        if response.status_code != 200:
            raise RuntimeError(
                f'Failed to download image - {response.status_code}: {response.text}'
            )
        return Image.open(io.BytesIO(response.content)).convert('RGB')

    def predict_height(self, image: Image.Image) -> float:
        """Run the model on *image* and return the predicted height.

        Negative predictions are clamped to ``0.0``.

        Raises:
            ValueError: If the model output is NaN or infinite.
        """
        inp = preprocess_image(image, img_size=224, fraction=0.7)
        # Cast to float32 (presumably the model's declared input dtype — TODO
        # confirm); copy=False skips the copy when it already is.
        inp = inp.astype('float32', copy=False)
        outputs = self.session.run([self.output_name], {self.input_name: inp})
        height = float(np.squeeze(outputs[0]).item())
        # max(0.0, nan) would silently yield 0.0, so reject non-finite output
        # explicitly rather than report a bogus height.
        if not math.isfinite(height):
            raise ValueError(f'Model produced a non-finite height: {height}')
        return max(0.0, height)

    def predict_from_coordinates(self, latitude: float, longitude: float) -> dict:
        """Download the image for a location and predict its height.

        Never raises: failures are reported in the returned dict.

        Returns:
            A dict with ``status`` (``'success'``/``'error'``), ``latitude``,
            ``longitude``, ``predicted_height`` (rounded to 2 decimals, or
            ``None`` on error) and, on error, an ``error`` message.
        """
        try:
            image = self.download_satellite_image(latitude, longitude)
            height = self.predict_height(image)
        except Exception as e:
            # Boundary handler: report failure as data. The message is scrubbed
            # because requests/urllib3 errors can embed the request URL, which
            # carries the API key in its query string.
            message = self._redact(str(e))
            logger.warning(
                'Height prediction failed for (%s, %s): %s',
                latitude,
                longitude,
                message,
            )
            return {
                'status': 'error',
                'latitude': latitude,
                'longitude': longitude,
                'predicted_height': None,
                'error': message,
            }
        return {
            'status': 'success',
            'latitude': latitude,
            'longitude': longitude,
            'predicted_height': round(height, 2),
        }


_predictor = None


def get_predictor() -> HeightPredictor:
    """Return the module-level HeightPredictor, creating it on first call.

    The model is loaded from ``models/height_predictor.onnx`` in the directory
    one level above this module's directory. No locking is performed, so
    concurrent first calls may each construct a predictor.

    Raises:
        FileNotFoundError: If the model file is missing.
        RuntimeError: If the Google Maps API key is not configured.
    """
    global _predictor
    if _predictor is None:
        model_path = Path(__file__).parent.parent / 'models' / 'height_predictor.onnx'
        # HeightPredictor.__init__ performs the existence check itself.
        _predictor = HeightPredictor(str(model_path))
    return _predictor