philip-singer commited on
Commit
8c89017
·
verified ·
1 Parent(s): 24f7b69

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +111 -111
predict.py CHANGED
@@ -1,112 +1,112 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import pandas as pd
5
- import joblib
6
- from flask import Blueprint, request, jsonify
7
- import os
8
-
9
- predict_bp = Blueprint('predict', __name__)
10
-
11
- # Lazy load model and preprocessors
12
- model_bundle = None
13
- model = None
14
- preprocessor = None
15
- nn_model = None
16
- beach_db = None
17
- wind_directions = None
18
-
19
- class ImprovedTrashPredictorMLP(nn.Module):
20
- def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.3):
21
- super().__init__()
22
- layers = []
23
- in_features = input_size
24
- for hidden_size in hidden_sizes:
25
- layers.append(nn.Linear(in_features, hidden_size))
26
- layers.append(nn.BatchNorm1d(hidden_size))
27
- layers.append(nn.ReLU())
28
- layers.append(nn.Dropout(dropout_rate))
29
- in_features = hidden_size
30
- layers.append(nn.Linear(in_features, output_size))
31
- self.model = nn.Sequential(*layers)
32
- def forward(self, x):
33
- return self.model(x)
34
-
35
- def load_bundle(bundle_path):
36
- return {
37
- "preprocessor": joblib.load(os.path.join(bundle_path, "preprocessor.pkl")),
38
- "beach_db": joblib.load(os.path.join(bundle_path, "beach_db.pkl")),
39
- "nn_model": joblib.load(os.path.join(bundle_path, "nn_model.pkl")),
40
- "wind_directions": joblib.load(os.path.join(bundle_path, "wind_directions.pkl")),
41
- }
42
-
43
- def lazy_load():
44
- global model_bundle, model, preprocessor, nn_model, beach_db, wind_directions
45
- if model_bundle is None:
46
- bundle_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../saved_models/final_bundle_20250703_112021'))
47
- model_bundle = load_bundle(bundle_path)
48
- preprocessor = model_bundle["preprocessor"]
49
- beach_db = model_bundle["beach_db"]
50
- nn_model = model_bundle["nn_model"]
51
- wind_directions = model_bundle["wind_directions"]
52
- # Model init
53
- input_size = preprocessor.transformers_[0][1].steps[0][1].get_feature_names_out().shape[0] + len(preprocessor.transformers_[1][2])
54
- model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../saved_models/final_model_20250703_112021.pth'))
55
- model_instance = ImprovedTrashPredictorMLP(
56
- input_size=input_size,
57
- hidden_sizes=[256, 128, 64, 32],
58
- output_size=1,
59
- dropout_rate=0.3
60
- )
61
- model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
62
- model_instance.eval()
63
- model = model_instance
64
-
65
- @predict_bp.route('/predict', methods=['POST'])
66
- def predict():
67
- lazy_load()
68
- data = request.get_json()
69
- latitude = float(data.get('latitude'))
70
- longitude = float(data.get('longitude'))
71
- wind_dir = data.get('wind_direction', wind_directions[0]).upper()
72
- wind_str = float(data.get('wind_strength', 5))
73
-
74
- # Find nearest beach
75
- query_point = np.array([[latitude, longitude]])
76
- _, indices = nn_model.kneighbors(query_point)
77
- beach_index = indices[0][0]
78
- nearest_beach = beach_db.iloc[beach_index]
79
-
80
- # Create input
81
- input_data = pd.DataFrame({
82
- 'Orientation': [nearest_beach['Orientation']],
83
- 'Sediment': [nearest_beach['Sediment']],
84
- 'Longitude': [longitude],
85
- 'Latitude': [latitude],
86
- 'Wind direction': [wind_dir],
87
- 'Wind strength': [wind_str]
88
- })
89
-
90
- # Preprocess and predict
91
- processed_input = preprocessor.transform(input_data)
92
- input_tensor = torch.tensor(processed_input, dtype=torch.float32)
93
- with torch.no_grad():
94
- prediction = model(input_tensor).item()
95
-
96
- # Get nearest beach details for the response (swap lat/lon for frontend map)
97
- nearest_beach_details = {
98
- 'latitude': nearest_beach['Longitude'], # swap!
99
- 'longitude': nearest_beach['Latitude'], # swap!
100
- 'orientation': nearest_beach['Orientation'],
101
- 'sediment': nearest_beach['Sediment']
102
- }
103
-
104
- return jsonify({
105
- 'user_latitude': latitude, # Correct: real latitude
106
- 'user_longitude': longitude, # Correct: real longitude
107
- 'wind_direction': wind_dir,
108
- 'wind_strength': wind_str,
109
- 'prediction': prediction,
110
- 'nearest_beach': nearest_beach_details,
111
- 'success': True
112
  })
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ import joblib
6
+ from flask import Blueprint, request, jsonify
7
+ import os
8
+
9
+ predict_bp = Blueprint('predict', __name__)
10
+
11
+ # Lazy load model and preprocessors
12
+ model_bundle = None
13
+ model = None
14
+ preprocessor = None
15
+ nn_model = None
16
+ beach_db = None
17
+ wind_directions = None
18
+
19
+ class ImprovedTrashPredictorMLP(nn.Module):
20
+ def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.3):
21
+ super().__init__()
22
+ layers = []
23
+ in_features = input_size
24
+ for hidden_size in hidden_sizes:
25
+ layers.append(nn.Linear(in_features, hidden_size))
26
+ layers.append(nn.BatchNorm1d(hidden_size))
27
+ layers.append(nn.ReLU())
28
+ layers.append(nn.Dropout(dropout_rate))
29
+ in_features = hidden_size
30
+ layers.append(nn.Linear(in_features, output_size))
31
+ self.model = nn.Sequential(*layers)
32
+ def forward(self, x):
33
+ return self.model(x)
34
+
35
+ def load_bundle(bundle_path):
36
+ return {
37
+ "preprocessor": joblib.load(os.path.join(bundle_path, "preprocessor.pkl")),
38
+ "beach_db": joblib.load(os.path.join(bundle_path, "beach_db.pkl")),
39
+ "nn_model": joblib.load(os.path.join(bundle_path, "nn_model.pkl")),
40
+ "wind_directions": joblib.load(os.path.join(bundle_path, "wind_directions.pkl")),
41
+ }
42
+
43
+ def lazy_load():
44
+ global model_bundle, model, preprocessor, nn_model, beach_db, wind_directions
45
+ if model_bundle is None:
46
+ bundle_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'saved_models/final_bundle_20250703_112021'))
47
+ model_bundle = load_bundle(bundle_path)
48
+ preprocessor = model_bundle["preprocessor"]
49
+ beach_db = model_bundle["beach_db"]
50
+ nn_model = model_bundle["nn_model"]
51
+ wind_directions = model_bundle["wind_directions"]
52
+ # Model init
53
+ input_size = preprocessor.transformers_[0][1].steps[0][1].get_feature_names_out().shape[0] + len(preprocessor.transformers_[1][2])
54
+ model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'saved_models/final_model_20250703_112021.pth'))
55
+ model_instance = ImprovedTrashPredictorMLP(
56
+ input_size=input_size,
57
+ hidden_sizes=[256, 128, 64, 32],
58
+ output_size=1,
59
+ dropout_rate=0.3
60
+ )
61
+ model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
62
+ model_instance.eval()
63
+ model = model_instance
64
+
65
+ @predict_bp.route('/predict', methods=['POST'])
66
+ def predict():
67
+ lazy_load()
68
+ data = request.get_json()
69
+ latitude = float(data.get('latitude'))
70
+ longitude = float(data.get('longitude'))
71
+ wind_dir = data.get('wind_direction', wind_directions[0]).upper()
72
+ wind_str = float(data.get('wind_strength', 5))
73
+
74
+ # Find nearest beach
75
+ query_point = np.array([[latitude, longitude]])
76
+ _, indices = nn_model.kneighbors(query_point)
77
+ beach_index = indices[0][0]
78
+ nearest_beach = beach_db.iloc[beach_index]
79
+
80
+ # Create input
81
+ input_data = pd.DataFrame({
82
+ 'Orientation': [nearest_beach['Orientation']],
83
+ 'Sediment': [nearest_beach['Sediment']],
84
+ 'Longitude': [longitude],
85
+ 'Latitude': [latitude],
86
+ 'Wind direction': [wind_dir],
87
+ 'Wind strength': [wind_str]
88
+ })
89
+
90
+ # Preprocess and predict
91
+ processed_input = preprocessor.transform(input_data)
92
+ input_tensor = torch.tensor(processed_input, dtype=torch.float32)
93
+ with torch.no_grad():
94
+ prediction = model(input_tensor).item()
95
+
96
+ # Get nearest beach details for the response (swap lat/lon for frontend map)
97
+ nearest_beach_details = {
98
+ 'latitude': nearest_beach['Longitude'], # swap!
99
+ 'longitude': nearest_beach['Latitude'], # swap!
100
+ 'orientation': nearest_beach['Orientation'],
101
+ 'sediment': nearest_beach['Sediment']
102
+ }
103
+
104
+ return jsonify({
105
+ 'user_latitude': latitude, # Correct: real latitude
106
+ 'user_longitude': longitude, # Correct: real longitude
107
+ 'wind_direction': wind_dir,
108
+ 'wind_strength': wind_str,
109
+ 'prediction': prediction,
110
+ 'nearest_beach': nearest_beach_details,
111
+ 'success': True
112
  })