Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import requests | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| import onnxruntime as ort | |
| from .preprocessing import preprocess_image | |
class HeightPredictor:
    """Estimate a height value for a location from a satellite image.

    Loads an ONNX model once (CPU execution provider), downloads satellite
    imagery from the Google Static Maps API and reduces the model's first
    output to a single non-negative float.
    """

    def __init__(self, model_path: str):
        """Load the ONNX model and read the Google Maps API key.

        Args:
            model_path: Filesystem path to the ``.onnx`` model file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If the ``GOOGLE_MAPS_API_KEY`` environment
                variable is unset or empty.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'ONNX model not found at {model_path}')
        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = True
        self.session = ort.InferenceSession(
            model_path.as_posix(),
            sess_options,
            providers=['CPUExecutionProvider'],
        )
        # Only the first declared input and first declared output are used.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        api_key = os.environ.get('GOOGLE_MAPS_API_KEY', '')
        if not api_key:
            raise RuntimeError('Google Maps API key environment variable is not set')
        self.api_key = api_key

    def _redact(self, message: str) -> str:
        """Return *message* with the API key replaced by ``***``.

        Some ``requests`` exceptions (e.g. connection errors wrapping urllib3's
        ``MaxRetryError``) include the request URL, whose query string carries
        ``key=<API key>``. Masking keeps the secret out of returned errors.
        """
        key = getattr(self, 'api_key', '')
        if not key:
            # str.replace('', ...) would insert the mask between every char.
            return message
        return message.replace(key, '***')

    def download_satellite_image(
        self,
        latitude,
        longitude,
        zoom: int = 20,
        size: str = '800x800',
    ):
        """Download a satellite tile centred on the given coordinates.

        Args:
            latitude: Latitude of the map centre.
            longitude: Longitude of the map centre.
            zoom: Static Maps zoom level.
            size: Requested image size as ``'<width>x<height>'``.
                NOTE(review): the Static Maps docs cap standard requests at
                640x640 — confirm '800x800' is actually honoured.

        Returns:
            The downloaded image, converted to RGB mode.

        Raises:
            RuntimeError: If the API answers with a non-200 status.
            requests.RequestException: On network failure or timeout (10 s).
        """
        # Let requests percent-encode every value instead of splicing them
        # into the query string by hand.
        params = {
            'center': f'{latitude},{longitude}',
            'zoom': zoom,
            'size': size,
            'maptype': 'satellite',
            'key': self.api_key,
        }
        response = requests.get(
            'https://maps.googleapis.com/maps/api/staticmap',
            params=params,
            timeout=10,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f'Failed to download image - {response.status_code}: '
                f'{self._redact(response.text)}'
            )
        return Image.open(io.BytesIO(response.content)).convert('RGB')

    def predict_height(self, image: 'Image.Image') -> float:
        """Run the model on one image and return a non-negative height.

        Args:
            image: Input image, passed to ``preprocess_image`` with
                ``img_size=224`` and ``fraction=0.7``.

        Returns:
            The model's single output value as a float, clamped to >= 0.
        """
        # NOTE(review): assumes preprocess_image returns a NumPy array already
        # shaped for the model input (e.g. with a batch axis) — confirm.
        inp = preprocess_image(image, img_size=224, fraction=0.7).astype('float32')
        outputs = self.session.run([self.output_name], {self.input_name: inp})
        # .item() raises ValueError unless the output holds exactly one value.
        height = float(np.squeeze(outputs[0]).item())
        # NOTE(review): max(0.0, nan) returns 0.0, so a NaN model output is
        # reported as 0.0 — consider rejecting non-finite values.
        return max(0.0, height)

    def predict_from_coordinates(self, latitude: float, longitude: float) -> dict:
        """Download imagery for a location and predict its height.

        Any ``Exception`` raised while downloading or predicting is reported
        in the returned dict rather than propagated.

        Returns:
            A dict with ``status`` (``'success'`` or ``'error'``), the input
            ``latitude`` and ``longitude``, and ``predicted_height`` (rounded
            to 2 decimals, or ``None`` on error). On error it also contains
            ``error``, the exception message with the API key masked.
        """
        try:
            image = self.download_satellite_image(latitude, longitude)
            height = self.predict_height(image)
        except Exception as e:
            # Boundary handler: the message may embed the request URL, which
            # contains the API key, so redact it before returning.
            return {
                'status': 'error',
                'latitude': latitude,
                'longitude': longitude,
                'predicted_height': None,
                'error': self._redact(str(e)),
            }
        return {
            'status': 'success',
            'latitude': latitude,
            'longitude': longitude,
            'predicted_height': round(height, 2),
        }
# Module-level cache for the lazily constructed predictor.
_predictor = None


def get_predictor() -> HeightPredictor:
    """Return the shared HeightPredictor, building it on the first call.

    The model file is ``models/height_predictor.onnx`` resolved relative to
    the parent of this module's directory.

    Raises:
        FileNotFoundError: If the ONNX model file does not exist.
        RuntimeError: From HeightPredictor when the API key is not set.
    """
    # NOTE(review): no locking here, so two concurrent first calls could each
    # construct a predictor; the last assignment wins.
    global _predictor
    if _predictor is not None:
        return _predictor
    onnx_file = Path(__file__).parent.parent.joinpath('models', 'height_predictor.onnx')
    if not onnx_file.exists():
        raise FileNotFoundError(f'ONNX model not found at {onnx_file}')
    _predictor = HeightPredictor(str(onnx_file))
    return _predictor