HiCoTraj / app.py
ginnyxxxxxxx's picture
map
0f3999b
raw
history blame
4.79 kB
import gradio as gr
import pandas as pd
import folium
import os
# ── Data paths ───────────────────────────────────────────────────────────────
# Every input CSV lives in a "data" folder next to this script, so the paths
# are resolved relative to this file rather than the working directory.
BASE = os.path.dirname(os.path.abspath(__file__))
STAY_POINTS, POI_PATH, DEMO_PATH = (
    os.path.join(BASE, "data", csv_name)
    for csv_name in (
        "stay_points_sampled.csv",
        "poi_sampled.csv",
        "demographics_sampled.csv",
    )
)

# Lookup tables from the integer codes used in the demographics table
# (keys start at 1) to human-readable labels.
AGE_MAP = dict(enumerate(
    ["<18", "18-24", "25-34", "35-44", "45-54", "55-64", "65+"],
    start=1,
))
SEX_MAP = dict(enumerate(["Male", "Female"], start=1))
EDU_MAP = dict(enumerate(
    ["No HS", "HS Grad", "Some College", "Bachelor's", "Graduate"],
    start=1,
))
INC_MAP = dict(enumerate(
    [
        "<$10k", "$10-15k", "$15-25k", "$25-35k", "$35-50k",
        "$50-75k", "$75-100k", "$100-125k", "$125-150k", ">$150k",
    ],
    start=1,
))
# ── Load ─────────────────────────────────────────────────────────────────────
# All data is read once at import time; the UI callbacks only filter these
# in-memory frames.
print("Loading data...")
sp = pd.read_csv(STAY_POINTS)
poi = pd.read_csv(POI_PATH)
demo = pd.read_csv(DEMO_PATH)

# Left join: every stay point is kept even if its poi_id has no POI row
# (those POI columns are then NaN).
sp = sp.merge(poi, on="poi_id", how="left")

# Parse both endpoints as timezone-aware UTC timestamps ...
for _ts_col in ("start_datetime", "end_datetime"):
    sp[_ts_col] = pd.to_datetime(sp[_ts_col], utc=True)
del _ts_col

# ... then derive the stay length in minutes, rounded to one decimal.
sp["duration_min"] = (
    sp["end_datetime"]
    .sub(sp["start_datetime"])
    .dt.total_seconds()
    .div(60)
    .round(1)
)

sample_agents = sorted(sp["agent_id"].unique().tolist())
print(f"Ready. {len(sample_agents)} agents loaded.")
# ── Helpers ───────────────────────────────────────────────────────────────────
def build_map(agent_sp):
    """Render one agent's stay points as an interactive Folium map.

    Each stay point becomes a red circle marker. Its popup shows the POI name,
    the start time and the stay length. Consecutive points, in row order, are
    joined by a blue polyline. Rows whose latitude or longitude is missing are
    skipped because they cannot be placed on the map.

    Args:
        agent_sp: DataFrame of a single agent's stay points with
            ``latitude``, ``longitude``, ``name``, ``start_datetime`` and
            ``duration_min`` columns. It is presumably ordered by start time
            (``on_select`` sorts it that way).

    Returns:
        The map's HTML string (Folium's ``_repr_html_`` output), or a
        500px-tall placeholder ``<div>`` when no row has coordinates.
    """
    import html  # local import: only used to escape popup text

    def _missing(value):
        # pd.isna returns a plain bool only for scalar inputs.
        return pd.api.types.is_scalar(value) and pd.isna(value)

    located = agent_sp.dropna(subset=["latitude", "longitude"])
    if located.empty:
        # No coordinates means no map centre can be computed (mean -> NaN).
        return "<div style='height:500px'>No stay points to display.</div>"

    center = [located["latitude"].mean(), located["longitude"].mean()]
    m = folium.Map(location=center, zoom_start=12, tiles="CartoDB positron")

    coords = list(zip(located["latitude"], located["longitude"]))
    if len(coords) > 1:  # a path needs at least two vertices
        folium.PolyLine(coords, color="#4f86c6", weight=2, opacity=0.6).add_to(m)

    for _, row in located.iterrows():
        name, start, duration = row["name"], row["start_datetime"], row["duration_min"]
        # The popup string is HTML (<b>, <br>), so the free-text POI name is
        # escaped. Missing values (e.g. a stay point with no POI match, or no
        # end time) are shown as "?" instead of raising.
        name_html = "?" if _missing(name) else html.escape(str(name))
        start_txt = "?" if _missing(start) else start.strftime("%a %m/%d %H:%M")
        duration_txt = "?" if _missing(duration) else int(duration)
        folium.CircleMarker(
            location=[row["latitude"], row["longitude"]],
            radius=6,
            color="#e05c5c",
            fill=True,
            fill_opacity=0.8,
            popup=folium.Popup(
                f"<b>{name_html}</b><br>{start_txt}<br>{duration_txt} min",
                max_width=200,
            ),
        ).add_to(m)

    # Size the rendered root element for the 500px-tall gr.HTML panel.
    root = m.get_root()
    root.width = "100%"
    root.height = "500px"
    return m._repr_html_()
def build_poi_sequence(agent_sp):
    """List an agent's stay points as one human-readable line per visit.

    Each line has the form
    ``"<Day MM/DD> <HH:MM>–<HH:MM> (<N> min) | <POI name> | <act_types>"``.
    Lines follow the frame's row order, which is start-time order when
    called from ``on_select``. Missing values, such as POI columns left NaN
    by the left merge or an absent end time, are shown as ``?`` instead of
    raising.

    Args:
        agent_sp: DataFrame of a single agent's stay points with
            ``start_datetime``, ``end_datetime``, ``duration_min``, ``name``
            and ``act_types`` columns.

    Returns:
        The lines joined with newlines, or ``""`` for an empty frame.
    """
    def _missing(value):
        # pd.isna returns a plain bool only for scalar inputs.
        return pd.api.types.is_scalar(value) and pd.isna(value)

    def _fmt_ts(ts, fmt):
        # NaT.strftime raises, so missing timestamps become "?".
        return "?" if _missing(ts) else ts.strftime(fmt)

    def _show(value):
        return "?" if _missing(value) else value

    lines = []
    for _, row in agent_sp.iterrows():
        start, end = row["start_datetime"], row["end_datetime"]
        duration = row["duration_min"]
        # int(NaN) raises ValueError; truncation otherwise matches int().
        minutes = "?" if _missing(duration) else int(duration)
        lines.append(
            f"{_fmt_ts(start, '%a %m/%d')} "
            f"{_fmt_ts(start, '%H:%M')}–{_fmt_ts(end, '%H:%M')} "
            f"({minutes} min) | {_show(row['name'])} | {_show(row['act_types'])}"
        )
    return "\n".join(lines)
def build_demo_text(row):
    """Format one demographics record as a single "Label: value | ..." line.

    Each coded field is translated through its lookup table. A code with no
    table entry is shown as the raw value.

    Args:
        row: Record (e.g. a DataFrame row) exposing ``age``, ``sex``,
            ``education`` and ``hh_income``.

    Returns:
        Text of the form "Age: ... | Sex: ... | Education: ... | Income: ...".
    """
    fields = (
        ("Age", "age", AGE_MAP),
        ("Sex", "sex", SEX_MAP),
        ("Education", "education", EDU_MAP),
        ("Income", "hh_income", INC_MAP),
    )
    parts = []
    for label, column, lookup in fields:
        code = row[column]
        parts.append(f"{label}: {lookup.get(code, code)}")
    return " | ".join(parts)
def on_select(agent_id):
    """Build every output panel for the agent chosen in the dropdown.

    Args:
        agent_id: The dropdown value. It is the agent id as a string (the
            dropdown choices are ``str`` of the ids), or ``None``/``""`` when
            nothing is selected.

    Returns:
        ``(map_html, poi_sequence_text, demographics_text)``, in the same order
        as the callback's ``outputs=[map_out, poi_out, demo_label]``.
    """
    # Guard against an empty selection: int() would raise on None or "".
    if agent_id is None or agent_id == "":
        return "<div style='height:500px'></div>", "", ""

    agent_id = int(agent_id)
    agent_sp = sp[sp["agent_id"] == agent_id].sort_values("start_datetime")

    # An agent may have stay points but no demographics row; .iloc[0] on an
    # empty frame would raise IndexError and break the whole callback.
    agent_demo = demo[demo["agent_id"] == agent_id]
    if agent_demo.empty:
        demo_text = "No demographic record for this agent."
    else:
        demo_text = build_demo_text(agent_demo.iloc[0])

    return build_map(agent_sp), build_poi_sequence(agent_sp), demo_text
# ── UI ────────────────────────────────────────────────────────────────────────
# Layout: agent picker and its ground-truth demographics in one row, then the
# trajectory map and the POI list. A single callback fills all three output
# panels, both when the page loads and whenever the selection changes.
_agent_choices = list(map(str, sample_agents))

with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
    gr.Markdown("## HiCoTraj: Trajectory Visualization")

    with gr.Row():
        agent_dd = gr.Dropdown(
            label="Select Agent",
            choices=_agent_choices,
            value=_agent_choices[0],
        )
        demo_label = gr.Textbox(label="Ground Truth Demographics", interactive=False)

    map_out = gr.HTML(label="Trajectory Map", value="<div style='height:500px'></div>")
    poi_out = gr.Textbox(label="POI Sequence", lines=20, interactive=False)

    _panel_outputs = [map_out, poi_out, demo_label]
    agent_dd.change(fn=on_select, inputs=agent_dd, outputs=_panel_outputs)
    app.load(fn=on_select, inputs=agent_dd, outputs=_panel_outputs)

if __name__ == "__main__":
    app.launch()