philip-singer commited on
Commit
1e2e161
·
verified ·
1 Parent(s): 8714a60

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +112 -0
predict.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ import joblib
6
+ from flask import Blueprint, request, jsonify
7
+ import os
8
+
9
+ predict_bp = Blueprint('predict', __name__)
10
+
11
+ # Lazy load model and preprocessors
12
+ model_bundle = None
13
+ model = None
14
+ preprocessor = None
15
+ nn_model = None
16
+ beach_db = None
17
+ wind_directions = None
18
+
19
+ class ImprovedTrashPredictorMLP(nn.Module):
20
+ def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.3):
21
+ super().__init__()
22
+ layers = []
23
+ in_features = input_size
24
+ for hidden_size in hidden_sizes:
25
+ layers.append(nn.Linear(in_features, hidden_size))
26
+ layers.append(nn.BatchNorm1d(hidden_size))
27
+ layers.append(nn.ReLU())
28
+ layers.append(nn.Dropout(dropout_rate))
29
+ in_features = hidden_size
30
+ layers.append(nn.Linear(in_features, output_size))
31
+ self.model = nn.Sequential(*layers)
32
+ def forward(self, x):
33
+ return self.model(x)
34
+
35
+ def load_bundle(bundle_path):
36
+ return {
37
+ "preprocessor": joblib.load(os.path.join(bundle_path, "preprocessor.pkl")),
38
+ "beach_db": joblib.load(os.path.join(bundle_path, "beach_db.pkl")),
39
+ "nn_model": joblib.load(os.path.join(bundle_path, "nn_model.pkl")),
40
+ "wind_directions": joblib.load(os.path.join(bundle_path, "wind_directions.pkl")),
41
+ }
42
+
43
+ def lazy_load():
44
+ global model_bundle, model, preprocessor, nn_model, beach_db, wind_directions
45
+ if model_bundle is None:
46
+ bundle_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../saved_models/final_bundle_20250703_112021'))
47
+ model_bundle = load_bundle(bundle_path)
48
+ preprocessor = model_bundle["preprocessor"]
49
+ beach_db = model_bundle["beach_db"]
50
+ nn_model = model_bundle["nn_model"]
51
+ wind_directions = model_bundle["wind_directions"]
52
+ # Model init
53
+ input_size = preprocessor.transformers_[0][1].steps[0][1].get_feature_names_out().shape[0] + len(preprocessor.transformers_[1][2])
54
+ model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../saved_models/final_model_20250703_112021.pth'))
55
+ model_instance = ImprovedTrashPredictorMLP(
56
+ input_size=input_size,
57
+ hidden_sizes=[256, 128, 64, 32],
58
+ output_size=1,
59
+ dropout_rate=0.3
60
+ )
61
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
62
+ model_instance.eval()
63
+ model = model_instance
64
+
65
+ @predict_bp.route('/predict', methods=['POST'])
66
+ def predict():
67
+ lazy_load()
68
+ data = request.get_json()
69
+ latitude = float(data.get('latitude'))
70
+ longitude = float(data.get('longitude'))
71
+ wind_dir = data.get('wind_direction', wind_directions[0]).upper()
72
+ wind_str = float(data.get('wind_strength', 5))
73
+
74
+ # Find nearest beach
75
+ query_point = np.array([[latitude, longitude]])
76
+ _, indices = nn_model.kneighbors(query_point)
77
+ beach_index = indices[0][0]
78
+ nearest_beach = beach_db.iloc[beach_index]
79
+
80
+ # Create input
81
+ input_data = pd.DataFrame({
82
+ 'Orientation': [nearest_beach['Orientation']],
83
+ 'Sediment': [nearest_beach['Sediment']],
84
+ 'Longitude': [longitude],
85
+ 'Latitude': [latitude],
86
+ 'Wind direction': [wind_dir],
87
+ 'Wind strength': [wind_str]
88
+ })
89
+
90
+ # Preprocess and predict
91
+ processed_input = preprocessor.transform(input_data)
92
+ input_tensor = torch.tensor(processed_input, dtype=torch.float32)
93
+ with torch.no_grad():
94
+ prediction = model(input_tensor).item()
95
+
96
+ # Get nearest beach details for the response (swap lat/lon for frontend map)
97
+ nearest_beach_details = {
98
+ 'latitude': nearest_beach['Longitude'], # swap!
99
+ 'longitude': nearest_beach['Latitude'], # swap!
100
+ 'orientation': nearest_beach['Orientation'],
101
+ 'sediment': nearest_beach['Sediment']
102
+ }
103
+
104
+ return jsonify({
105
+ 'user_latitude': latitude, # Correct: real latitude
106
+ 'user_longitude': longitude, # Correct: real longitude
107
+ 'wind_direction': wind_dir,
108
+ 'wind_strength': wind_str,
109
+ 'prediction': prediction,
110
+ 'nearest_beach': nearest_beach_details,
111
+ 'success': True
112
+ })