Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import pandas as pd | |
| import joblib | |
| from flask import Blueprint, request, jsonify | |
| import os | |
# Blueprint that groups this module's prediction view.
predict_bp = Blueprint('predict', __name__)

# Lazily loaded inference state. Each name starts as None and is populated by
# lazy_load() on first use, so the artifacts are not read at import time.
model_bundle = None     # dict returned by load_bundle()
model = None            # ImprovedTrashPredictorMLP, switched to eval mode
preprocessor = None     # fitted transformer; .transform() builds model features
nn_model = None         # neighbor index; .kneighbors() finds the nearest beach
beach_db = None         # table of known beaches, row-indexed with .iloc
wind_directions = None  # wind direction labels; the first is the request default
class ImprovedTrashPredictorMLP(nn.Module):
    """Feed-forward regressor made of repeated hidden blocks and a linear head.

    Each hidden block is ``Linear -> BatchNorm1d -> ReLU -> Dropout``. The whole
    stack is stored in ``self.model`` as an ``nn.Sequential``. The attribute
    name and the layer order must stay fixed, because saved ``state_dict`` keys
    (e.g. ``model.0.weight``) are derived from them.

    Args:
        input_size: Number of input features per row.
        hidden_sizes: Widths of the hidden layers, applied in order.
        output_size: Width of the final Linear layer.
        dropout_rate: Dropout probability after every hidden activation.
    """

    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.3):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        stack = []
        # Consecutive width pairs give (fan_in, fan_out) for each hidden block.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ))
        # Output head reads from the last hidden width (or the input if there are no hidden layers).
        stack.append(nn.Linear(widths[-1], output_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the raw outputs."""
        return self.model(x)
def load_bundle(bundle_path):
    """Load the pickled inference artifacts stored in the ``bundle_path`` directory.

    Each artifact is read with ``joblib.load`` from a fixed file name, in the
    order listed below. NOTE: ``joblib.load`` unpickles its input, so
    ``bundle_path`` must only ever point at trusted, locally produced files.

    Returns:
        dict mapping ``"preprocessor"``, ``"beach_db"``, ``"nn_model"`` and
        ``"wind_directions"`` to the loaded objects.
    """
    artifact_files = (
        ("preprocessor", "preprocessor.pkl"),
        ("beach_db", "beach_db.pkl"),
        ("nn_model", "nn_model.pkl"),
        ("wind_directions", "wind_directions.pkl"),
    )
    return {
        key: joblib.load(os.path.join(bundle_path, file_name))
        for key, file_name in artifact_files
    }
def lazy_load():
    """Load the inference artifacts and network weights once, on first use.

    Populates the module globals ``model_bundle``, ``preprocessor``,
    ``beach_db``, ``nn_model``, ``wind_directions`` and ``model``. Once
    ``model`` is set, further calls return immediately.

    Everything is built in locals and published to the globals only after the
    weights have loaded, with ``model`` assigned last. If any step raises
    (missing file, incompatible checkpoint), ``model`` stays ``None`` and the
    next call retries. This avoids leaving ``model_bundle`` set with no model,
    which made every later request fail on ``model(...)``. No lock is taken,
    so concurrent first calls may each perform the load.
    """
    global model_bundle, model, preprocessor, nn_model, beach_db, wind_directions
    if model is not None:
        return

    here = os.path.dirname(__file__)
    bundle_path = os.path.abspath(os.path.join(here, 'saved_models/final_bundle_20250703_112021'))
    model_path = os.path.abspath(os.path.join(here, 'saved_models/final_model_20250703_112021.pth'))

    bundle = load_bundle(bundle_path)
    # NOTE(review): only tensors are needed here; if the installed torch
    # supports it (>= 1.13), weights_only=True avoids unpickling arbitrary
    # objects from the checkpoint.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Take the input width from the checkpoint's first Linear layer (index 0
    # of ImprovedTrashPredictorMLP.model; Linear weights are (out, in)) rather
    # than reconstructing it from the preprocessor's internals. The previous
    # expression (transformers_[0][1].steps[0][1].get_feature_names_out()
    # plus the second transformer's column count) only matches the trained
    # width when pipeline step 0 emits the final features; otherwise
    # load_state_dict failed with a size mismatch.
    input_size = state_dict['model.0.weight'].shape[1]

    network = ImprovedTrashPredictorMLP(
        input_size=input_size,
        hidden_sizes=[256, 128, 64, 32],
        output_size=1,
        dropout_rate=0.3,
    )
    network.load_state_dict(state_dict)
    network.eval()  # inference mode: BatchNorm uses running stats, Dropout is off

    model_bundle = bundle
    preprocessor = bundle["preprocessor"]
    beach_db = bundle["beach_db"]
    nn_model = bundle["nn_model"]
    wind_directions = bundle["wind_directions"]
    model = network  # assigned last: a non-None model means loading completed
def _error_response(message, status=400):
    """Build the JSON error response returned for invalid requests."""
    return jsonify({'success': False, 'error': message}), status


def _parse_float(payload, key, default=None):
    """Read ``payload[key]`` as a finite float.

    An absent or null value falls back to ``default``. Raises ValueError with
    a client-facing message when the value is required but missing, is not
    numeric, or is NaN/infinite.
    """
    raw = payload.get(key)
    if raw is None:
        if default is None:
            raise ValueError(f"'{key}' is required")
        raw = default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        raise ValueError(f"'{key}' must be a number") from None
    if not np.isfinite(value):
        raise ValueError(f"'{key}' must be a finite number")
    return value


def _to_native(value):
    """Convert NumPy scalars (e.g. ``np.int64``) to built-in types so ``jsonify`` accepts them."""
    return value.item() if isinstance(value, np.generic) else value


# NOTE(review): no ``@predict_bp.route`` decorator is visible on this view;
# confirm it is registered with the blueprint elsewhere.
def predict():
    """Score a location with the trash model and describe the nearest known beach.

    Expects a JSON object with required numeric ``latitude`` and ``longitude``,
    plus optional ``wind_direction`` (string, defaults to ``wind_directions[0]``)
    and ``wind_strength`` (numeric, defaults to 5).

    Returns a JSON response echoing the inputs, the model output under
    ``'prediction'``, details of the nearest beach, and ``'success': True``.
    Malformed input yields HTTP 400 with ``{'success': False, 'error': ...}``
    instead of an unhandled exception.
    """
    lazy_load()

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _error_response('Request body must be a JSON object')

    try:
        latitude = _parse_float(data, 'latitude')
        longitude = _parse_float(data, 'longitude')
        wind_str = _parse_float(data, 'wind_strength', default=5)
    except ValueError as err:
        return _error_response(str(err))

    wind_dir = data.get('wind_direction')
    if wind_dir is None:
        wind_dir = wind_directions[0]
    elif not isinstance(wind_dir, str):
        return _error_response("'wind_direction' must be a string")
    wind_dir = wind_dir.upper()

    # Nearest beach: the first neighbor index returned for the query point.
    query_point = np.array([[latitude, longitude]])
    _, indices = nn_model.kneighbors(query_point)
    nearest_beach = beach_db.iloc[indices[0][0]]

    # NOTE(review): the query and the model input use the user's latitude as
    # 'Latitude', yet the response below swaps beach_db's 'Latitude' and
    # 'Longitude' columns. If those columns really are swapped in beach_db,
    # nn_model and the preprocessor may also have been fit on swapped
    # coordinates; confirm against the training pipeline.
    input_data = pd.DataFrame({
        'Orientation': [nearest_beach['Orientation']],
        'Sediment': [nearest_beach['Sediment']],
        'Longitude': [longitude],
        'Latitude': [latitude],
        'Wind direction': [wind_dir],
        'Wind strength': [wind_str],
    })

    processed_input = preprocessor.transform(input_data)
    input_tensor = torch.tensor(processed_input, dtype=torch.float32)
    with torch.no_grad():
        prediction = model(input_tensor).item()

    # Lat/lon deliberately swapped for the frontend map.
    nearest_beach_details = {
        'latitude': _to_native(nearest_beach['Longitude']),
        'longitude': _to_native(nearest_beach['Latitude']),
        'orientation': _to_native(nearest_beach['Orientation']),
        'sediment': _to_native(nearest_beach['Sediment']),
    }
    return jsonify({
        'user_latitude': latitude,
        'user_longitude': longitude,
        'wind_direction': wind_dir,
        'wind_strength': wind_str,
        'prediction': prediction,
        'nearest_beach': nearest_beach_details,
        'success': True,
    })