# Antcar — "Create app.py" (commit 42bfcc1, verified) — 2.39 kB
# [file-viewer links: raw | history | blame]
import gradio as gr
import tensorflow as tf
import numpy as np
import os
# 1. Load the Model
# The .keras file is expected to sit in the same directory as this script.
# Resolve it relative to this file, not the current working directory, so the
# app also starts when launched from elsewhere (e.g. `python path/to/app.py`).
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "hk_transit_flow_net.keras"
)
model = tf.keras.models.load_model(MODEL_PATH)
# Day name -> integer code fed to the model's `day_of_week` input, with
# Sunday=0 through Saturday=6. NOTE: this is NOT datetime.weekday() (which uses
# Monday=0); it presumably mirrors the training-time encoding — confirm before
# changing. Insertion order matters: the UI dropdown lists the keys in order.
DAY_MAP = {
    day: index
    for index, day in enumerate(
        ("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
    )
}
# 2. Define the prediction function
def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
    """Predict a bus travel time and format it for display.

    Builds a single-example feature dict for the global Keras ``model`` and
    renders its first output value as ``"<m> min <s> sec (<total>s)"``.

    Args:
        distance_meters: Path distance in meters (from a Gradio ``Number``;
            may be ``None`` if the field was cleared).
        num_stops: Number of stops on the path (Gradio ``Number``; may be
            ``None``).
        hour: Hour of day, 0-23 (Gradio ``Slider``; may arrive as a float).
        day_name: A key of ``DAY_MAP`` such as ``"Monday"``.
        route_id: Route identifier; blank or ``None`` is sent as ``"UNKNOWN"``.

    Returns:
        The formatted ETA string, or a string starting with ``"Error: "``
        explaining why no prediction was made. It never raises, so the Gradio
        text output always shows something.
    """
    try:
        # 1. Validate and prepare inputs.
        # Feature names, (1, 1) shapes and dtypes must match what the model
        # was trained with.
        if distance_meters is None or num_stops is None:
            return "Error: distance and number of stops are required"
        if day_name not in DAY_MAP:
            return f"Error: unknown day of week {day_name!r}"
        hour = int(hour)
        if not 0 <= hour <= 23:
            return f"Error: hour must be between 0 and 23, got {hour}"
        # Strip whitespace so a padded ID is not treated as a different route;
        # an empty/blank ID falls back to the "UNKNOWN" bucket.
        route_id = str(route_id or "").strip() or "UNKNOWN"

        inputs = {
            'distance': np.array([[float(distance_meters)]]),
            'num_stops': np.array([[float(num_stops)]]),
            'hour': np.array([[hour]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[route_id]], dtype=tf.string),
        }

        # 2. Run Prediction
        prediction = model.predict(inputs, verbose=0)
        seconds = float(prediction[0][0])
        if not np.isfinite(seconds):
            return f"Error: model returned a non-finite prediction ({seconds})"

        # 3. Format Output
        # Round once to the displayed precision and split THAT value, so the
        # min/sec parts always agree with the total in parentheses (e.g. 59.96s
        # is shown as "1 min 0 sec (60.0s)", not "0 min 59 sec (60.0s)").
        # A regression output can dip below zero; an ETA cannot, so clamp it
        # (max with 0.0 first also turns -0.0 into 0.0).
        shown = max(0.0, round(seconds, 1))
        minutes, rem_seconds = divmod(shown, 60)
        return f"{int(minutes)} min {int(rem_seconds)} sec ({shown:.1f}s)"
    except Exception as e:  # UI boundary: surface any failure as text
        return f"Error: {e}"
# 3. Build the Interface
# The inputs list is positional: its order must match predict_eta's parameters
# (distance_meters, num_stops, hour, day_name, route_id).
iface = gr.Interface(
    fn=predict_eta,
    inputs=[
        gr.Number(label="Distance (meters)", value=5000),
        gr.Number(label="Number of Stops", value=10),
        gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9),
        gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday"),
        gr.Textbox(label="Route ID (Optional)", placeholder="e.g. 968+1+Yuen Long+Tin Hau", value="UNKNOWN"),
    ],
    outputs="text",
    # Bus emoji (U+1F68C); it had been mis-decoded into "๐ŸšŒ".
    title="HK-TransitFlow-Net 🚌",
    description="""
**Hong Kong Bus ETA Predictor**
This model uses Deep Learning to predict bus travel time based on distance, stops, and time context.
* **Distance:** Physical distance of the path in meters.
* **Route ID:** Internal ID (e.g., `968+1+...`). If unknown, leave as UNKNOWN.
* **Note:** Trained on KMB & CTB data.
""",
    theme="soft",
)
# 4. Launch
def _launch():
    """Start the Gradio app defined by the module-level ``iface``."""
    iface.launch()


# Serve the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()