Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import plotly.graph_objects as go | |
| from geopy.geocoders import Nominatim | |
| import pandas as pd | |
| from datetime import datetime | |
| import holidays | |
| import numpy as np | |
| from sklearn.preprocessing import MinMaxScaler | |
| import pickle | |
| import xgboost as xgb | |
# Streamlit page configuration: browser-tab title "Taxi" and a sidebar that is
# expanded on first load. The default (centered) layout is used; passing
# layout="wide" here would stretch the page to the full browser width.
st.set_page_config(page_title="Taxi", initial_sidebar_state="expanded")
# Load the pickled XGBoost model from disk.
def get_model(path="models/model_xgb.pkl"):
    """Load and return the pickled model stored at *path*.

    The file is opened in a ``with`` block so the handle is closed right after
    unpickling instead of being left open until garbage collection.

    Args:
        path: Location of the pickle file. Defaults to the app's bundled
            ``models/model_xgb.pkl`` (resolved against the working directory).

    Returns:
        The unpickled object. Callers in this app pass an ``xgboost.DMatrix``
        to its ``predict`` method, so it is presumably an XGBoost Booster —
        TODO confirm against the training code.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files
    # shipped with the app, never user-supplied paths.
    # NOTE(review): this re-reads the model on every prediction; a Streamlit
    # resource cache around this function would avoid the repeated disk read.
    with open(path, "rb") as model_file:
        return pickle.load(model_file)
# Run the XGBoost model on already-scaled input and return its raw prediction.
def make_prediction(data):
    """Predict with the pickled XGBoost model.

    Args:
        data: Feature matrix whose columns are, in order, the feature names
            assembled below (in ``main`` this is the scaler's output).

    Returns:
        Whatever the model's ``predict`` returns for the wrapped
        ``xgb.DMatrix``.
    """
    booster = get_model()
    # Only weekdays 1..6 and odd geo clusters are present — presumably the
    # features kept during training/feature selection; TODO confirm.
    weekday_columns = [f"pickup_day_of_week_{day}" for day in range(1, 7)]
    cluster_columns = [f"geo_cluster_{cluster}" for cluster in (1, 3, 5, 7, 9)]
    feature_names = [
        "vendor_id",
        "passenger_count",
        "pickup_longitude",
        "pickup_latitude",
        "dropoff_longitude",
        "dropoff_latitude",
        "store_and_fwd_flag",
        "pickup_hour",
        "pickup_holiday",
        "total_distance",
        "total_travel_time",
        "number_of_steps",
        "haversine_distance",
        "temperature",
        *weekday_columns,
        *cluster_columns,
    ]
    matrix = xgb.DMatrix(data, feature_names=feature_names)
    return booster.predict(matrix)
# Geocode a free-form address into (longitude, latitude).
def get_coordinates(address, user_agent="my_app"):
    """Resolve *address* with geopy's Nominatim geocoder.

    Args:
        address: Free-form address string, e.g. "New York, 11 Wall Street".
        user_agent: Identifier sent to the Nominatim service. Nominatim's usage
            policy asks for a value identifying the application; the old
            hard-coded value is kept as the default for compatibility.

    Returns:
        ``(longitude, latitude)`` — longitude first, matching how ``main``
        unpacks the result.

    Raises:
        ValueError: if the geocoder finds no match. ``geocode`` returns
            ``None`` in that case; failing here gives a clear message instead
            of an ``AttributeError`` on ``None.longitude``.
    """
    geolocator = Nominatim(user_agent=user_agent)
    location = geolocator.geocode(address)
    if location is None:
        raise ValueError(f"Could not find coordinates for address: {address!r}")
    return (location.longitude, location.latitude)
# Build a Plotly map with the pickup and drop-off points joined by a line.
def show_map(lon_from, lat_from, lon_to, lat_to, zoom=9, height=600, width=1200):
    """Return a Plotly figure showing the two trip endpoints on an OSM map.

    Args:
        lon_from, lat_from: Pickup coordinates in degrees.
        lon_to, lat_to: Drop-off coordinates in degrees.
        zoom: Initial map zoom level.
        height, width: Figure size in pixels.

    Returns:
        A ``plotly.graph_objects.Figure``. It is not displayed here; the caller
        renders it (e.g. with ``st.plotly_chart``).
    """
    lons = [lon_from, lon_to]
    lats = [lat_from, lat_to]
    fig = go.Figure()
    # Pickup and drop-off markers.
    fig.add_trace(go.Scattermapbox(
        mode="markers",
        lon=lons,
        lat=lats,
        marker=go.scattermapbox.Marker(size=25, color="red"),
    ))
    # Straight segment between the endpoints (not the driving route).
    fig.add_trace(go.Scattermapbox(
        mode="lines",
        lon=lons,
        lat=lats,
        line=dict(width=2, color="green"),
    ))
    # Center the view on the midpoint of the two endpoints.
    fig.update_layout(
        mapbox={
            "style": "open-street-map",
            "center": {"lon": (lon_from + lon_to) / 2, "lat": (lat_from + lat_to) / 2},
            "zoom": zoom,
        },
        showlegend=False,
        height=height,
        width=width,
    )
    return fig
# Query the public OSRM demo server for the driving route between two points.
def get_total_distance(start_longitude, start_latitude, end_longitude, end_latitude, timeout=10):
    """Return route length, duration and step count for the driving route.

    Args:
        start_longitude, start_latitude: Route start in degrees.
        end_longitude, end_latitude: Route end in degrees.
        timeout: Seconds to wait for the OSRM server before giving up.

    Returns:
        ``(total_distance, total_travel_time, number_of_steps)``: distance in
        meters, duration in seconds, and the number of maneuver steps in the
        first leg of the first (best) route.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: network failure or timeout.
        ValueError: the server answered but reported no usable route.
    """
    # OSRM expects "lon,lat;lon,lat" in the path.
    url = (
        "http://router.project-osrm.org/route/v1/driving/"
        f"{start_longitude},{start_latitude};{end_longitude},{end_latitude}"
    )
    # steps=true is required: with OSRM's default (steps=false) every leg's
    # "steps" array is empty, which made number_of_steps always 0.
    response = requests.get(
        url,
        params={"overview": "false", "steps": "true"},
        timeout=timeout,
    )
    # Previously a non-200 reply silently returned None, which then broke the
    # caller's tuple unpacking with an unrelated TypeError.
    response.raise_for_status()
    data = response.json()
    if data.get("code") != "Ok" or not data.get("routes"):
        raise ValueError(
            f"OSRM returned no route: {data.get('code')} {data.get('message', '')}".strip()
        )
    route = data["routes"][0]
    total_distance = route["distance"]  # meters
    total_travel_time = route["duration"]  # seconds
    number_of_steps = len(route["legs"][0]["steps"])  # maneuvers in the first leg
    return total_distance, total_travel_time, number_of_steps
# Great-circle ("as the crow flies") distance between two points.
def get_haversine_distance(lat1, lng1, lat2, lng2):
    """Return the haversine distance in kilometers between two points.

    Arguments are latitude/longitude in degrees, in (lat, lng, lat, lng)
    order. Because only NumPy ufuncs are used, NumPy arrays of matching shape
    are processed element-wise.
    """
    earth_radius_km = 6371
    phi1, lam1, phi2, lam2 = (np.radians(angle) for angle in (lat1, lng1, lat2, lng2))
    half_dphi = (phi2 - phi1) * 0.5
    half_dlam = (lam2 - lam1) * 0.5
    hav = np.sin(half_dphi) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(half_dlam) ** 2
    return 2 * earth_radius_km * np.arcsin(np.sqrt(hav))
# Assemble the single-row feature frame the model expects for a trip.
def user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count, now=None):
    """Build the model's input row for a trip from the pickup to the drop-off point.

    Args:
        lon_from, lat_from: Pickup coordinates in degrees.
        lon_to, lat_to: Drop-off coordinates in degrees.
        passenger_count: Number of passengers.
        now: Pickup timestamp; defaults to ``datetime.now()``. The hour,
            holiday flag and weekday are all derived from this single value,
            so they cannot disagree when the call straddles midnight.

    Returns:
        A one-row ``pandas.DataFrame``. Keep the key order below: it mirrors
        the feature-name list used by ``make_prediction`` and presumably the
        column order the scaler/model were fitted on.
    """
    if now is None:
        # NOTE(review): this is the server's local clock; the model presumably
        # expects New York local time — confirm the deployment timezone.
        now = datetime.now()
    weekday_number = now.weekday()  # Monday == 0 ... Sunday == 6
    total_distance, total_travel_time, number_of_steps = get_total_distance(
        lon_from, lat_from, lon_to, lat_to
    )
    haversine_distance = get_haversine_distance(lat_from, lon_from, lat_to, lon_to)
    data = {
        "vendor_id": 1,  # constant placeholder, not derived from inputs
        "passenger_count": passenger_count,
        "pickup_longitude": lon_from,
        "pickup_latitude": lat_from,
        "dropoff_longitude": lon_to,
        "dropoff_latitude": lat_to,
        "store_and_fwd_flag": 0.0,
        "pickup_hour": now.hour,
        "pickup_holiday": 1 if now.date() in holidays.USA() else 0,
        "total_distance": total_distance,
        "total_travel_time": total_travel_time,
        "number_of_steps": number_of_steps,
        "haversine_distance": haversine_distance,
        "temperature": 15,  # constant placeholder, not a live reading
        # One-hot weekday; Monday (0) has no column of its own.
        **{f"pickup_day_of_week_{day}": int(weekday_number == day) for day in range(1, 7)},
        # Fixed geo-cluster encoding (placeholder, not computed from coordinates).
        "geo_cluster_1": 1,
        "geo_cluster_3": 0,
        "geo_cluster_5": 0,
        "geo_cluster_7": 0,
        "geo_cluster_9": 0,
    }
    return pd.DataFrame(data, index=[0])
# Scale the input data using a pre-trained, pickled scaler.
def min_max_scaler(data, path="models/min_max_scaler.pkl"):
    """Load the pickled scaler at *path* and return ``scaler.transform(data)``.

    The file is opened in a ``with`` block so the handle is closed right after
    unpickling.

    Args:
        data: Input to scale. In ``main`` this is the one-row DataFrame from
            ``user_input_features``.
        path: Location of the pickled scaler; defaults to the app's bundled
            ``models/min_max_scaler.pkl``. Presumably a fitted sklearn
            ``MinMaxScaler`` (per the function name) — TODO confirm.

    Returns:
        Whatever the scaler's ``transform`` returns.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as scaler_file:
        scaler = pickle.load(scaler_file)
    return scaler.transform(data)
# Main function: render the sidebar form and, when "Start" is clicked, show
# the route map and the predicted trip duration.
def main():
    """Streamlit entry point for the taxi trip-duration demo."""
    # Sidebar: title, image and the trip inputs.
    st.sidebar.markdown(''' # New York City Taxi Trip Duration''')
    st.sidebar.image("img/taxi_img.png")
    address_from = st.sidebar.text_input("Откуда:", value="New York, 11 Wall Street")  # "From:"
    address_to = st.sidebar.text_input("Куда:", value="New York, 740 Park Avenue")  # "To:"
    passenger_count = st.sidebar.slider("Количество пассажиров", 1, 4, 1)  # "Number of passengers"
    # The button's value is written to session_state on every rerun, so no
    # separate initialisation of the key is needed.
    st.session_state['btn_predict'] = st.sidebar.button('Start')
    if not st.session_state['btn_predict']:
        return

    lon_from, lat_from = get_coordinates(address_from)
    lon_to, lat_to = get_coordinates(address_to)
    st.plotly_chart(show_map(lon_from, lat_from, lon_to, lat_to))

    user_data = user_input_features(lon_from, lat_from, lon_to, lat_to, passenger_count)
    data_scaled = min_max_scaler(user_data)
    prediction = make_prediction(data_scaled)
    # The prediction is inverted with exp(x) - 1 (expm1, more accurate), i.e.
    # the target was log(1 + seconds). Take the single element explicitly:
    # float() on a 1-element array is deprecated since NumPy 1.25.
    trip_seconds = float(np.expm1(np.asarray(prediction)).ravel()[0])
    trip_duration = round(trip_seconds / 60)  # minutes
    # Result banner ("Trip duration will be: N min.").
    st.markdown(f"""
<div style='background-color: lightgreen; padding: 10px;'>
<h2 style='color: black; text-align: center;'>Длительность поездки составит: {trip_duration} мин.</h2>
</div>
""", unsafe_allow_html=True)


# Run the app when executed as a script (Streamlit runs it as __main__).
if __name__ == "__main__":
    main()