Spaces:
Building
Building
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
# 1. Load the Model
# The .keras file is expected to sit in the same directory as this script.
# Resolve the path relative to this file rather than the current working
# directory, so the app still finds the model when launched from elsewhere
# (e.g. `python some/dir/app.py`).
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "hk_transit_flow_net.keras"
)
model = tf.keras.models.load_model(MODEL_PATH)
# Map day names to integer codes: each day's code is its position in the tuple
# below (Sunday=0 ... Saturday=6). Insertion order is preserved, and the UI
# dropdown lists days in this order. The numbering presumably matches the
# day_of_week encoding used when the model was trained; confirm before changing.
DAY_MAP = {
    day: index
    for index, day in enumerate(
        ("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
    )
}
def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
    """Predict a bus travel time with the loaded model and format it for display.

    Args:
        distance_meters: Physical path distance in meters (numeric).
        num_stops: Number of stops along the path (numeric).
        hour: Hour of day; must convert to an int in 0-23.
        day_name: A key of DAY_MAP, e.g. "Monday".
        route_id: Route identifier string. None, empty or whitespace-only
            values fall back to "UNKNOWN"; surrounding whitespace is removed.

    Returns:
        A string such as "12 min 5 sec (725.3s)". If the inputs are invalid or
        prediction fails, it instead returns a string starting with "Error: ".
        Errors are returned rather than raised so they show up in the text
        output box of the interface.
    """
    try:
        # 1. Validate and prepare inputs.
        # Give readable messages for missing or out-of-range values instead of
        # raw float()/KeyError text (a cleared gr.Number yields None).
        if distance_meters is None or num_stops is None or hour is None:
            raise ValueError("distance, number of stops and hour are all required")
        hour = int(hour)
        if not 0 <= hour <= 23:
            raise ValueError(f"hour must be between 0 and 23, got {hour}")
        if day_name not in DAY_MAP:
            raise ValueError(f"unknown day of week: {day_name!r}")
        # Use the stripped ID so padded input maps to the same string the
        # model saw in training; blank input falls back to "UNKNOWN".
        route_id = str(route_id or "").strip() or "UNKNOWN"

        # Shapes/dtypes must match training: one sample, one column per feature.
        inputs = {
            'distance': np.array([[float(distance_meters)]]),
            'num_stops': np.array([[float(num_stops)]]),
            'hour': np.array([[hour]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[route_id]], dtype=tf.string),
        }

        # 2. Run Prediction
        prediction = model.predict(inputs, verbose=0)
        seconds = float(prediction[0][0])
        if not np.isfinite(seconds):
            raise ValueError(f"model returned a non-finite prediction ({seconds})")
        # Nothing bounds the model output from below. A negative value would
        # otherwise render as e.g. "-1 min 30 sec", so clamp to zero.
        seconds = max(seconds, 0.0)

        # 3. Format Output (whole minutes + truncated remaining seconds).
        minutes, rem_seconds = divmod(int(seconds), 60)
        return f"{minutes} min {rem_seconds} sec ({seconds:.1f}s)"
    except Exception as e:
        # UI boundary: report any failure in the output box instead of crashing.
        return f"Error: {str(e)}"
# 3. Build the Interface
# Gradio passes the input components' values to predict_eta positionally, so
# this list must stay in the order of its parameters:
# distance_meters, num_stops, hour, day_name, route_id.
iface = gr.Interface(
    fn=predict_eta,
    inputs=[
        gr.Number(label="Distance (meters)", value=5000),
        gr.Number(label="Number of Stops", value=10),
        gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9),
        # Dropdown order follows DAY_MAP's insertion order (Sunday first).
        gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday"),
        gr.Textbox(
            label="Route ID (Optional)",
            placeholder="e.g. 968+1+Yuen Long+Tin Hau",
            value="UNKNOWN",
        ),
    ],
    outputs="text",
    # The title emoji was mis-encoded in the source ("๐", the cp874 decoding
    # of an emoji's first UTF-8 byte). A bus emoji is assumed here; confirm.
    title="HK-TransitFlow-Net 🚌",
    # Blank lines separate the bold heading from the paragraph and list, so the
    # heading is not merged into the following sentence.
    description="""
**Hong Kong Bus ETA Predictor**

This model uses Deep Learning to predict bus travel time based on distance, stops, and time context.

* **Distance:** Physical distance of the path in meters.
* **Route ID:** Internal ID (e.g., `968+1+...`). If unknown, leave as UNKNOWN.
* **Note:** Trained on KMB & CTB data.
""",
    theme="soft",
)

# 4. Launch the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()