File size: 3,144 Bytes
9ad3152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import io
import requests
import numpy as np
from PIL import Image
from pathlib import Path
import onnxruntime as ort

from .preprocessing import preprocess_image


class HeightPredictor:
    """Predict a height value from Google Maps satellite imagery with an ONNX model.

    The ONNX model runs on the CPU execution provider. Imagery is fetched from
    the Google Maps Static API using the key read from ``GOOGLE_MAPS_API_KEY``.
    """

    # Static Maps endpoint. Query parameters are passed separately so that
    # ``requests`` URL-encodes them instead of raw string concatenation.
    _STATIC_MAP_URL = 'https://maps.googleapis.com/maps/api/staticmap'

    def __init__(self, model_path: str):
        """Read the API key and load the ONNX model.

        Args:
            model_path: Filesystem path to the ``.onnx`` model file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If ``GOOGLE_MAPS_API_KEY`` is unset or empty.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'ONNX model not found at {model_path}')

        # Validate cheap configuration before paying for the model load.
        api_key = os.environ.get('GOOGLE_MAPS_API_KEY', '')
        if not api_key:
            raise RuntimeError('Google Maps API key environment variable is not set')
        self.api_key = api_key

        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = True

        self.session = ort.InferenceSession(
            model_path.as_posix(),
            sess_options,
            providers=['CPUExecutionProvider'],
        )

        # Only the first model input and first output are used; the model is
        # presumably single-input / single-output (TODO confirm).
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _redact(self, message: str) -> str:
        """Return *message* with the API key replaced by ``***``."""
        key = self.api_key
        # Guard: ''.replace('', ...) would insert the placeholder everywhere.
        return message.replace(key, '***') if key else message

    def download_satellite_image(
        self,
        latitude,
        longitude,
        zoom: int = 20,
        size: str = '800x800',
    ):
        """Fetch a satellite image centred on the given coordinates.

        Args:
            latitude: Latitude of the map centre.
            longitude: Longitude of the map centre.
            zoom: Static Maps zoom level.
            size: Requested image size as ``'<width>x<height>'``.
                NOTE(review): the Static Maps docs cap ``size`` at 640x640 for
                standard requests; confirm 800x800 is actually honoured.

        Returns:
            The downloaded image converted to RGB (a ``PIL.Image.Image``).

        Raises:
            requests.RequestException: On network/timeout failures. Its message
                includes the request URL, which contains the API key.
            RuntimeError: If the API answers with a non-200 status.
        """
        params = {
            'center': f'{latitude},{longitude}',
            'zoom': zoom,
            'size': size,
            'maptype': 'satellite',
            'key': self.api_key,
        }

        response = requests.get(self._STATIC_MAP_URL, params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(
                f'Failed to download image - {response.status_code}: {response.text}'
            )

        return Image.open(io.BytesIO(response.content)).convert('RGB')

    def predict_height(self, image: Image.Image) -> float:
        """Run the model on *image* and return the predicted height.

        Args:
            image: Input image, preprocessed via
                ``preprocess_image(image, img_size=224, fraction=0.7)``.

        Returns:
            The model's output as a float, clamped to be >= 0.0.

        Raises:
            ValueError: If the model output does not hold exactly one value, or
                if that value is NaN or infinite.
        """
        inp = preprocess_image(image, img_size=224, fraction=0.7).astype('float32')

        (raw,) = self.session.run([self.output_name], {self.input_name: inp})

        # .item() raises ValueError unless the output holds exactly one element.
        height = float(np.squeeze(raw).item())
        # max(0.0, nan) evaluates to 0.0, which would silently report a broken
        # prediction as a valid height; reject non-finite values explicitly.
        if not np.isfinite(height):
            raise ValueError(f'Model produced a non-finite height: {height}')
        return max(0.0, height)

    def predict_from_coordinates(self, latitude: float, longitude: float) -> dict:
        """Download imagery for a location and predict its height.

        Any ``Exception`` raised while downloading or predicting is reported in
        the returned dict instead of propagating.

        Returns:
            A dict with ``status`` (``'success'`` or ``'error'``), the echoed
            ``latitude`` and ``longitude``, and ``predicted_height`` (rounded to
            2 decimals, or ``None`` on error). On error it also has an ``error``
            message with the API key redacted.
        """
        try:
            image = self.download_satellite_image(latitude, longitude)
            height = self.predict_height(image)
        except Exception as e:
            # The broad catch is deliberate: this is the boundary that turns
            # failures into an error payload. The message is scrubbed because
            # requests' exceptions embed the request URL, which carries the key.
            return {
                'status': 'error',
                'latitude': latitude,
                'longitude': longitude,
                'predicted_height': None,
                'error': self._redact(str(e)),
            }

        return {
            'status': 'success',
            'latitude': latitude,
            'longitude': longitude,
            'predicted_height': round(height, 2),
        }


# Lazily-created, module-wide HeightPredictor instance returned by get_predictor().
_predictor = None


def get_predictor() -> HeightPredictor:
    """Return the shared HeightPredictor, creating it on first call.

    The model is loaded from ``models/height_predictor.onnx`` in the directory
    one level above this module's directory. Later calls return the same
    instance.

    Raises:
        Any exception from ``HeightPredictor``'s constructor, e.g.
        ``FileNotFoundError`` when the model file is missing.
    """
    global _predictor

    # NOTE(review): this check is not guarded by a lock, so concurrent first
    # calls may each construct a predictor (loading the model more than once).
    if _predictor is None:
        model_path = Path(__file__).parent.parent / 'models' / 'height_predictor.onnx'
        # HeightPredictor validates that the file exists and raises the same
        # FileNotFoundError, so no separate check is needed here.
        _predictor = HeightPredictor(str(model_path))

    return _predictor