# ML-predictor / predict.py
# Last updated by philip-singer: "Update predict.py" (commit 8c89017, verified)
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import joblib
from flask import Blueprint, request, jsonify
import os
predict_bp = Blueprint('predict', __name__)

# Model artifacts are not loaded at import time: these module globals stay
# None until lazy_load() populates them (predict() calls it on every request,
# and it only does work the first time).
model_bundle = model = preprocessor = nn_model = beach_db = wind_directions = None
class ImprovedTrashPredictorMLP(nn.Module):
    """Feed-forward MLP: one Linear -> BatchNorm1d -> ReLU -> Dropout stage per
    entry of ``hidden_sizes``, followed by a final Linear to ``output_size``.

    All layers live in a single ``nn.Sequential`` stored as ``self.model``.
    The order of submodules is what determines the ``state_dict`` keys
    (``model.<index>.<param>``), so it must stay in sync with any saved
    checkpoint loaded into this class.
    """

    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.3):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        stack = []
        # Each consecutive (fan_in, fan_out) pair becomes one hidden stage.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        # Output projection from the last hidden width (or input_size when
        # there are no hidden layers).
        stack.append(nn.Linear(widths[-1], output_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run *x* through the layer stack and return the raw output."""
        return self.model(x)
def load_bundle(bundle_path):
    """Read the pickled preprocessing artifacts from the *bundle_path* directory.

    Returns a dict with the keys ``"preprocessor"``, ``"beach_db"``,
    ``"nn_model"`` and ``"wind_directions"``, loaded in that order with
    ``joblib.load`` from the matching ``<key>.pkl`` file in the directory.
    """
    # NOTE(review): joblib.load unpickles, which can execute arbitrary code;
    # only point this at bundles from a trusted source.
    artifact_files = (
        ("preprocessor", "preprocessor.pkl"),
        ("beach_db", "beach_db.pkl"),
        ("nn_model", "nn_model.pkl"),
        ("wind_directions", "wind_directions.pkl"),
    )
    return {
        key: joblib.load(os.path.join(bundle_path, filename))
        for key, filename in artifact_files
    }
def lazy_load():
    """Load the model artifacts on first call and cache them in module globals.

    Reads the preprocessing bundle (via ``load_bundle``) and the trained MLP
    weights from ``saved_models/`` next to this file, maps the weights onto
    the CPU, and puts the network in eval mode. Later calls return
    immediately.

    The globals are assigned only after every load step has succeeded, with
    ``model_bundle`` (the "already loaded" flag) last. A failure such as a
    missing checkpoint therefore propagates to the caller and the next call
    retries. It can no longer leave ``model_bundle`` set while ``model`` is
    still ``None``.
    """
    global model_bundle, model, preprocessor, nn_model, beach_db, wind_directions
    if model_bundle is not None:
        return

    base_dir = os.path.dirname(__file__)
    bundle_path = os.path.abspath(os.path.join(base_dir, 'saved_models/final_bundle_20250703_112021'))
    model_path = os.path.abspath(os.path.join(base_dir, 'saved_models/final_model_20250703_112021.pth'))

    bundle = load_bundle(bundle_path)
    bundle_preprocessor = bundle["preprocessor"]

    # Network input width = number of output features of the first
    # transformer's first pipeline step + number of columns routed to the
    # second transformer.
    # NOTE(review): presumably a categorical encoder plus pass-through numeric
    # columns — confirm against the training script.
    input_size = (
        bundle_preprocessor.transformers_[0][1].steps[0][1].get_feature_names_out().shape[0]
        + len(bundle_preprocessor.transformers_[1][2])
    )

    # The architecture must match the checkpoint exactly; load_state_dict
    # raises if parameter keys or shapes do not line up.
    model_instance = ImprovedTrashPredictorMLP(
        input_size=input_size,
        hidden_sizes=[256, 128, 64, 32],
        output_size=1,
        dropout_rate=0.3,
    )
    # NOTE(review): torch.load unpickles the file; since only a state_dict is
    # needed, consider weights_only=True if the installed torch supports it.
    model_instance.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model_instance.eval()  # inference mode: no dropout, BatchNorm uses running stats

    # Publish only after all loads succeeded; model_bundle last because it is
    # the flag checked at the top of this function.
    preprocessor = bundle_preprocessor
    beach_db = bundle["beach_db"]
    nn_model = bundle["nn_model"]
    wind_directions = bundle["wind_directions"]
    model = model_instance
    model_bundle = bundle
def _parse_float(data, key, default=None):
    """Return ``data[key]`` as a finite float.

    When the key is absent or its value is ``None``, *default* is used. If no
    default is given, the field is required.

    Raises:
        ValueError: with a client-facing message if the field is required but
            missing, cannot be converted with ``float()``, or is NaN/infinite.
    """
    raw = data.get(key)
    if raw is None:
        if default is None:
            raise ValueError(f"'{key}' is required")
        return float(default)
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"'{key}' must be a number") from None
    if not np.isfinite(value):
        raise ValueError(f"'{key}' must be a finite number")
    return value


def _to_builtin(value):
    """Convert a NumPy scalar (e.g. a value read from a pandas row) to a Python scalar.

    Flask's default JSON provider cannot serialize types such as
    ``numpy.int64``. Non-NumPy values are returned unchanged.
    """
    return value.item() if isinstance(value, np.generic) else value


@predict_bp.route('/predict', methods=['POST'])
def predict():
    """Run the trash predictor for a location and wind conditions.

    Expects a JSON object body with:
        latitude, longitude (required): numbers (numeric strings accepted).
        wind_direction (optional): string, upper-cased; defaults to
            ``wind_directions[0]`` from the loaded bundle.
        wind_strength (optional): number; defaults to 5.

    The nearest known beach (from ``nn_model.kneighbors``) supplies the
    'Orientation' and 'Sediment' features.

    Returns:
        200 with the parsed inputs, the model's scalar ``prediction``, the
        nearest beach's details, and ``success: True``.
        400 with ``success: False`` and an ``error`` message when the body is
        not a JSON object or a field is missing or invalid.
    """
    lazy_load()

    # silent=True: a missing or unparsable JSON body yields None, which is
    # answered with a 400 below instead of an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400

    try:
        latitude = _parse_float(data, 'latitude')
        longitude = _parse_float(data, 'longitude')
        wind_str = _parse_float(data, 'wind_strength', default=5)
        wind_dir = data.get('wind_direction')
        if wind_dir is None:
            wind_dir = wind_directions[0]
        elif not isinstance(wind_dir, str):
            raise ValueError("'wind_direction' must be a string")
        wind_dir = wind_dir.upper()
    except ValueError as err:
        return jsonify({'success': False, 'error': str(err)}), 400

    # Find the nearest beach; the query point is ordered [latitude, longitude].
    query_point = np.array([[latitude, longitude]])
    _, indices = nn_model.kneighbors(query_point)
    nearest_beach = beach_db.iloc[indices[0][0]]

    # Single-row frame using the column names the preprocessor is applied to.
    input_data = pd.DataFrame({
        'Orientation': [nearest_beach['Orientation']],
        'Sediment': [nearest_beach['Sediment']],
        'Longitude': [longitude],
        'Latitude': [latitude],
        'Wind direction': [wind_dir],
        'Wind strength': [wind_str],
    })

    processed_input = preprocessor.transform(input_data)
    input_tensor = torch.tensor(processed_input, dtype=torch.float32)
    with torch.no_grad():
        prediction = model(input_tensor).item()

    # The beach's 'Latitude'/'Longitude' columns are deliberately swapped for
    # the frontend map.
    # NOTE(review): if those columns are swapped in beach_db itself, the
    # kneighbors query and the 'Latitude'/'Longitude' features above may be
    # mismatched too — confirm against how nn_model and the preprocessor were fit.
    nearest_beach_details = {
        'latitude': _to_builtin(nearest_beach['Longitude']),  # swap!
        'longitude': _to_builtin(nearest_beach['Latitude']),  # swap!
        'orientation': _to_builtin(nearest_beach['Orientation']),
        'sediment': _to_builtin(nearest_beach['Sediment']),
    }
    return jsonify({
        'user_latitude': latitude,    # the user's real latitude (not swapped)
        'user_longitude': longitude,  # the user's real longitude (not swapped)
        'wind_direction': wind_dir,
        'wind_strength': wind_str,
        'prediction': prediction,
        'nearest_beach': nearest_beach_details,
        'success': True,
    })